/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ml;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.trino.array.ObjectBigArray;
import io.trino.array.SliceBigArray;
import io.trino.plugin.ml.FeatureVector;
import io.trino.plugin.ml.LearnState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.GroupedAccumulatorState;
import java.util.ArrayList;
import java.util.List;
import libsvm.svm_parameter;

/**
 * Factory for {@link LearnState} accumulator states.
 *
 * <p>Produces either a {@link SingleLearnState}, which holds one set of labels, feature
 * vectors and parameters, or a {@link GroupedLearnState}, which keeps one such set per
 * group id in big arrays.
 */
public class LearnStateFactory
implements AccumulatorStateFactory<LearnState> {
    private static final long ARRAY_LIST_SIZE = SizeOf.instanceSize(ArrayList.class);
    private static final long SVM_PARAMETERS_SIZE = SizeOf.instanceSize(svm_parameter.class);

    @Override
    public LearnState createSingleState() {
        return new SingleLearnState();
    }

    @Override
    public LearnState createGroupedState() {
        return new GroupedLearnState();
    }

    /**
     * Learn state for the non-grouped case: a single list of labels, a single list of
     * feature vectors, and a single parameters slice.
     */
    public static class SingleLearnState
    implements LearnState {
        private final List<Double> labels = new ArrayList<>();
        private final List<FeatureVector> featureVectors = new ArrayList<>();
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        // Next integer id handed out by enumerateLabel for a previously unseen label.
        private int nextLabel;
        private Slice parameters;
        // Memory reported by callers through addMemoryUsage.
        private long size;

        /**
         * Returns the memory reported via {@link #addMemoryUsage(long)} plus the shallow
         * size of the two backing {@link ArrayList}s. The parameters slice and the label
         * enumeration map are not included.
         */
        @Override
        public long getEstimatedSize() {
            return this.size + 2L * ARRAY_LIST_SIZE;
        }

        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return this.labelEnumeration;
        }

        /**
         * Returns the integer id assigned to {@code label}, assigning the next sequential
         * id (starting at 0) the first time the label is seen.
         */
        @Override
        public int enumerateLabel(String label) {
            if (!this.labelEnumeration.containsKey(label)) {
                this.labelEnumeration.put(label, this.nextLabel);
                this.nextLabel++;
            }
            return this.labelEnumeration.get(label);
        }

        @Override
        public List<Double> getLabels() {
            return this.labels;
        }

        @Override
        public List<FeatureVector> getFeatureVectors() {
            return this.featureVectors;
        }

        @Override
        public Slice getParameters() {
            return this.parameters;
        }

        @Override
        public void setParameters(Slice parameters) {
            this.parameters = parameters;
        }

        /** Adds {@code value} bytes to the memory reported by {@link #getEstimatedSize()}. */
        @Override
        public void addMemoryUsage(long value) {
            this.size += value;
        }
    }

    /**
     * Learn state for the grouped case. Labels, feature vectors and parameters are stored
     * per group id in big arrays; the group currently operated on is selected with
     * {@link #setGroupId(long)}.
     *
     * <p>Note that the label enumeration (and its id counter) is a single instance shared by
     * all groups, not one per group.
     */
    public static class GroupedLearnState
    implements GroupedAccumulatorState,
    LearnState {
        private final ObjectBigArray<List<Double>> labelsArray = new ObjectBigArray<>();
        private final ObjectBigArray<List<FeatureVector>> featureVectorsArray = new ObjectBigArray<>();
        private final SliceBigArray parametersArray = new SliceBigArray();
        // Shared across every group: ids are unique over the whole state, not per group.
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        private long groupId;
        private int nextLabel;
        // Memory reported via addMemoryUsage plus the per-group list/parameter charges below.
        private long size;

        @Override
        public void setGroupId(long groupId) {
            this.groupId = groupId;
        }

        @Override
        public void ensureCapacity(long size) {
            this.labelsArray.ensureCapacity(size);
            this.featureVectorsArray.ensureCapacity(size);
            this.parametersArray.ensureCapacity(size);
        }

        /**
         * Returns the tracked memory plus the size of all three per-group big arrays.
         * The parameters array is included alongside the labels and feature-vector arrays
         * since it is grown in lockstep with them by {@link #ensureCapacity(long)}.
         * The shared label enumeration map is not included.
         */
        @Override
        public long getEstimatedSize() {
            return this.size
                    + this.labelsArray.sizeOf()
                    + this.featureVectorsArray.sizeOf()
                    + this.parametersArray.sizeOf();
        }

        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return this.labelEnumeration;
        }

        /**
         * Returns the integer id assigned to {@code label}, assigning the next sequential
         * id (starting at 0) the first time the label is seen by any group.
         */
        @Override
        public int enumerateLabel(String label) {
            if (!this.labelEnumeration.containsKey(label)) {
                this.labelEnumeration.put(label, this.nextLabel);
                this.nextLabel++;
            }
            return this.labelEnumeration.get(label);
        }

        /**
         * Returns the labels list of the current group, lazily creating (and accounting
         * for) it on first access.
         */
        @Override
        public List<Double> getLabels() {
            List<Double> labels = this.labelsArray.get(this.groupId);
            if (labels == null) {
                labels = new ArrayList<>();
                this.size += ARRAY_LIST_SIZE;
                // NOTE(review): no svm_parameter is allocated here; this appears to reserve
                // room for one parsed parameter object per group, charged once when the
                // group is first touched — confirm against the training code.
                this.size += SVM_PARAMETERS_SIZE;
                this.labelsArray.set(this.groupId, labels);
            }
            return labels;
        }

        /**
         * Returns the feature-vector list of the current group, lazily creating (and
         * accounting for) it on first access.
         */
        @Override
        public List<FeatureVector> getFeatureVectors() {
            List<FeatureVector> featureVectors = this.featureVectorsArray.get(this.groupId);
            if (featureVectors == null) {
                featureVectors = new ArrayList<>();
                this.size += ARRAY_LIST_SIZE;
                this.featureVectorsArray.set(this.groupId, featureVectors);
            }
            return featureVectors;
        }

        @Override
        public Slice getParameters() {
            return this.parametersArray.get(this.groupId);
        }

        @Override
        public void setParameters(Slice parameters) {
            this.parametersArray.set(this.groupId, parameters);
        }

        /** Adds {@code value} bytes to the memory reported by {@link #getEstimatedSize()}. */
        @Override
        public void addMemoryUsage(long value) {
            this.size += value;
        }
    }
}

