/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.ml.FeatureVector;
import com.facebook.presto.ml.LearnState;
import com.facebook.presto.operator.aggregation.state.AbstractGroupedAccumulatorState;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
import com.facebook.presto.util.array.ObjectBigArray;
import com.facebook.presto.util.array.SliceBigArray;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import io.airlift.slice.Slice;
import java.util.ArrayList;
import java.util.List;
import libsvm.svm_parameter;
import org.openjdk.jol.info.ClassLayout;

/**
 * {@link AccumulatorStateFactory} for {@link LearnState}.
 *
 * <p>Builds the single-group and grouped state holders. Each holder collects labels, feature
 * vectors, a label-to-id enumeration and a serialized parameters {@link Slice}.
 */
public class LearnStateFactory
        implements AccumulatorStateFactory<LearnState> {
    // JOL-reported instance size of an ArrayList (object header + fields, not the backing array).
    private static final long ARRAY_LIST_SIZE = ClassLayout.parseClass(ArrayList.class).instanceSize();
    // JOL-reported instance size of libsvm's svm_parameter; GroupedLearnState charges it once per group.
    private static final long SVM_PARAMETERS_SIZE = ClassLayout.parseClass(svm_parameter.class).instanceSize();

    @Override
    public LearnState createSingleState() {
        return new SingleLearnState();
    }

    @Override
    public Class<? extends LearnState> getSingleStateClass() {
        return SingleLearnState.class;
    }

    @Override
    public LearnState createGroupedState() {
        return new GroupedLearnState();
    }

    @Override
    public Class<? extends LearnState> getGroupedStateClass() {
        return GroupedLearnState.class;
    }

    /**
     * State for a non-grouped aggregation.
     *
     * <p>Labels and feature vectors live in two lists owned by this object. Memory reported via
     * {@link #addMemoryUsage(long)} is accumulated and folded into {@link #getEstimatedSize()}.
     */
    public static class SingleLearnState
            implements LearnState {
        private final List<Double> labels = new ArrayList<>();
        private final List<FeatureVector> featureVectors = new ArrayList<>();
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        // Id assigned to the next label that enumerateLabel has not seen before.
        private int nextLabel;
        // Null until setParameters is called.
        private Slice parameters;
        // Bytes reported through addMemoryUsage.
        private long size;

        /**
         * Returns the bytes reported through {@link #addMemoryUsage(long)} plus the instance
         * size of the two backing lists.
         */
        @Override
        public long getEstimatedSize() {
            return size + 2 * ARRAY_LIST_SIZE;
        }

        /** Returns the live (mutable) label-to-id mapping. */
        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return labelEnumeration;
        }

        /**
         * Returns the integer id of {@code label}.
         *
         * <p>If the label is not yet known, it is first assigned the next sequential id, starting
         * at 0.
         *
         * @param label the label to look up or register
         * @return the id mapped to {@code label}
         */
        @Override
        public int enumerateLabel(String label) {
            Integer id = labelEnumeration.get(label);
            if (id == null) {
                id = nextLabel;
                // put before incrementing so a rejected put (BiMap value clash) does not burn an id
                labelEnumeration.put(label, id);
                nextLabel++;
            }
            return id;
        }

        /** Returns the live (mutable) list of collected labels. */
        @Override
        public List<Double> getLabels() {
            return labels;
        }

        /** Returns the live (mutable) list of collected feature vectors. */
        @Override
        public List<FeatureVector> getFeatureVectors() {
            return featureVectors;
        }

        /** Returns the parameters slice, or {@code null} if none has been set. */
        @Override
        public Slice getParameters() {
            return parameters;
        }

        @Override
        public void setParameters(Slice parameters) {
            this.parameters = parameters;
        }

        /** Adds {@code value} bytes to the running memory estimate. */
        @Override
        public void addMemoryUsage(long value) {
            size += value;
        }
    }

    /**
     * State for a grouped aggregation.
     *
     * <p>Per-group labels, feature vectors and parameters are stored in big arrays indexed by
     * the current group id. The label enumeration is shared by all groups of this state
     * instance.
     */
    public static class GroupedLearnState
            extends AbstractGroupedAccumulatorState
            implements LearnState {
        private final ObjectBigArray<List<Double>> labelsArray = new ObjectBigArray<>();
        private final ObjectBigArray<List<FeatureVector>> featureVectorsArray = new ObjectBigArray<>();
        private final SliceBigArray parametersArray = new SliceBigArray();
        private final BiMap<String, Integer> labelEnumeration = HashBiMap.create();
        // Id assigned to the next label that enumerateLabel has not seen before.
        private int nextLabel;
        // Bytes reported through addMemoryUsage plus lazily created per-group lists.
        private long size;

        /** Grows every per-group array so that it can hold {@code size} groups. */
        @Override
        public void ensureCapacity(long size) {
            labelsArray.ensureCapacity(size);
            featureVectorsArray.ensureCapacity(size);
            parametersArray.ensureCapacity(size);
        }

        /**
         * Returns the accumulated size plus the footprint of the labels and feature-vector
         * arrays.
         *
         * <p>NOTE(review): {@code parametersArray} is not included here. Parameters appear to be
         * approximated by the per-group {@code SVM_PARAMETERS_SIZE} charge in
         * {@link #getLabels()}; confirm this is intended.
         */
        @Override
        public long getEstimatedSize() {
            return size + labelsArray.sizeOf() + featureVectorsArray.sizeOf();
        }

        /** Returns the live (mutable) label-to-id mapping shared by all groups. */
        @Override
        public BiMap<String, Integer> getLabelEnumeration() {
            return labelEnumeration;
        }

        /**
         * Returns the integer id of {@code label}.
         *
         * <p>If the label is not yet known, it is first assigned the next sequential id, starting
         * at 0. Ids are shared across groups.
         *
         * @param label the label to look up or register
         * @return the id mapped to {@code label}
         */
        @Override
        public int enumerateLabel(String label) {
            Integer id = labelEnumeration.get(label);
            if (id == null) {
                id = nextLabel;
                // put before incrementing so a rejected put (BiMap value clash) does not burn an id
                labelEnumeration.put(label, id);
                nextLabel++;
            }
            return id;
        }

        /**
         * Returns the current group's label list, creating and accounting for it on first
         * access.
         */
        @Override
        public List<Double> getLabels() {
            List<Double> labels = labelsArray.get(getGroupId());
            if (labels == null) {
                labels = new ArrayList<>();
                size += ARRAY_LIST_SIZE;
                // NOTE(review): svm_parameter's size is charged once per group here. This is
                // presumably an estimate for the parameters the group will carry; confirm intent.
                size += SVM_PARAMETERS_SIZE;
                labelsArray.set(getGroupId(), labels);
            }
            return labels;
        }

        /**
         * Returns the current group's feature-vector list, creating and accounting for it on
         * first access.
         */
        @Override
        public List<FeatureVector> getFeatureVectors() {
            List<FeatureVector> featureVectors = featureVectorsArray.get(getGroupId());
            if (featureVectors == null) {
                featureVectors = new ArrayList<>();
                size += ARRAY_LIST_SIZE;
                featureVectorsArray.set(getGroupId(), featureVectors);
            }
            return featureVectors;
        }

        /**
         * Returns the current group's parameters slice (presumably {@code null} if never set;
         * depends on SliceBigArray).
         */
        @Override
        public Slice getParameters() {
            return parametersArray.get(getGroupId());
        }

        /** Stores {@code parameters} for the current group. */
        @Override
        public void setParameters(Slice parameters) {
            parametersArray.set(getGroupId(), parameters);
        }

        /** Adds {@code value} bytes to the running memory estimate. */
        @Override
        public void addMemoryUsage(long value) {
            size += value;
        }
    }
}

