/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.ml.ClassifierFeatureTransformer;
import com.facebook.presto.ml.Dataset;
import com.facebook.presto.ml.FeatureUnitNormalizer;
import com.facebook.presto.ml.FeatureVector;
import com.facebook.presto.ml.Model;
import com.facebook.presto.ml.ModelUtils;
import com.facebook.presto.ml.RegressorFeatureTransformer;
import com.facebook.presto.ml.SvmClassifier;
import com.facebook.presto.ml.SvmRegressor;
import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.ml.type.RegressorType;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.AccumulatorFactory;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.UnknownType;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Aggregation that trains an SVM model from {@code (label, features)} rows and returns the
 * serialized model.
 *
 * <p>The model flavour is chosen by {@code modelType}:
 * <ul>
 *   <li>{@link ClassifierType#CLASSIFIER} registers as {@code learn_classifier}.</li>
 *   <li>Any other type registers as {@code learn_regressor}.</li>
 * </ul>
 *
 * <p>Features are supplied as a VARCHAR and parsed with {@link ModelUtils#jsonToFeatures}.
 *
 * <p>The aggregation is not decomposable. Every row is buffered in a single accumulator and
 * training happens once, in {@link LearnAccumulatorFactory.LearnAccumulator#evaluateFinal()}.
 * Masking, approximation, sampling weights and GROUP BY are all rejected.
 */
public class LearnAggregation
        implements InternalAggregationFunction {
    private final Type modelType;
    private final Type labelType;

    /**
     * @param modelType result type; {@link RegressorType#REGRESSOR} selects an SVM regressor,
     *                  anything else an SVM classifier
     * @param labelType type of the label argument; {@link BigintType#BIGINT} labels are widened
     *                  to double, otherwise labels are read as DOUBLE
     */
    public LearnAggregation(Type modelType, Type labelType) {
        this.modelType = modelType;
        this.labelType = labelType;
    }

    /** Returns {@code learn_classifier} for the classifier model type, else {@code learn_regressor}. */
    public String name() {
        return this.modelType == ClassifierType.CLASSIFIER ? "learn_classifier" : "learn_regressor";
    }

    /** Returns the argument types: the label followed by the VARCHAR-encoded features. */
    public List<Type> getParameterTypes() {
        // Explicit type witness: without it the arguments infer ImmutableList<Object>,
        // which is not assignable to List<Type>.
        return ImmutableList.<Type>of(this.labelType, VarcharType.VARCHAR);
    }

    /** Returns the model type supplied at construction. */
    public Type getFinalType() {
        return this.modelType;
    }

    /** Returns UNKNOWN: this aggregation has no intermediate representation. */
    public Type getIntermediateType() {
        return UnknownType.UNKNOWN;
    }

    /** Always false: all rows must reach a single accumulator. */
    public boolean isDecomposable() {
        return false;
    }

    /** Always false: approximate evaluation is not supported. */
    public boolean isApproximate() {
        return false;
    }

    /**
     * Creates a factory for accumulators that read the label from {@code inputChannels.get(0)}
     * and the features from {@code inputChannels.get(1)}.
     *
     * @throws IllegalArgumentException if a mask channel or sample-weight channel is present,
     *                                  or if {@code confidence != 1.0}
     */
    public AccumulatorFactory bind(List<Integer> inputChannels, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence) {
        Preconditions.checkArgument(!maskChannel.isPresent(), "masking is not supported");
        Preconditions.checkArgument(confidence == 1.0, "approximation is not supported");
        Preconditions.checkArgument(!sampleWeightChannel.isPresent(), "sample weight is not supported");
        return new LearnAccumulatorFactory(inputChannels, this.labelType == BigintType.BIGINT, this.modelType == RegressorType.REGRESSOR);
    }

    /** Factory producing the single, non-grouped {@link LearnAccumulator}. */
    public static class LearnAccumulatorFactory
            implements AccumulatorFactory {
        private final List<Integer> inputChannels;
        private final boolean labelIsLong;
        private final boolean regression;

        /**
         * @param inputChannels channel indices; index 0 is the label, index 1 the features
         * @param labelIsLong   true if the label channel holds BIGINT values
         * @param regression    true to train a regressor, false to train a classifier
         * @throws NullPointerException if {@code inputChannels} is null
         */
        public LearnAccumulatorFactory(List<Integer> inputChannels, boolean labelIsLong, boolean regression) {
            this.inputChannels = ImmutableList.copyOf(Preconditions.checkNotNull(inputChannels, "inputChannels is null"));
            this.labelIsLong = labelIsLong;
            this.regression = regression;
        }

        /** Returns an immutable copy of the channels passed to the constructor. */
        public List<Integer> getInputChannels() {
            return this.inputChannels;
        }

        /** Creates an accumulator reading the label from channel 0 and the features from channel 1. */
        public Accumulator createAccumulator() {
            return new LearnAccumulator(this.inputChannels.get(0), this.inputChannels.get(1), this.labelIsLong, this.regression);
        }

        /** Unsupported: training cannot be split across workers. */
        public Accumulator createIntermediateAccumulator() {
            throw new UnsupportedOperationException("LEARN must run on a single machine");
        }

        /** Unsupported: LEARN has no grouped form. */
        public GroupedAccumulator createGroupedAccumulator() {
            throw new UnsupportedOperationException("LEARN doesn't support GROUP BY");
        }

        /** Unsupported: LEARN has no grouped form. */
        public GroupedAccumulator createGroupedIntermediateAccumulator() {
            throw new UnsupportedOperationException("LEARN doesn't support GROUP BY");
        }

        /**
         * Buffers every (label, feature vector) pair in memory and trains the model once, in
         * {@link #evaluateFinal()}. {@code labels} and {@code rows} are kept index-aligned.
         */
        public static class LearnAccumulator
                implements Accumulator {
            private final int labelChannel;
            private final int featuresChannel;
            private final boolean labelIsLong;
            private final boolean regression;
            private final List<Double> labels = new ArrayList<Double>();
            private final List<FeatureVector> rows = new ArrayList<FeatureVector>();
            // Running sum of FeatureVector#getEstimatedSize() over everything in `rows`.
            private long rowsSize;

            public LearnAccumulator(int labelChannel, int featuresChannel, boolean labelIsLong, boolean regression) {
                this.labelChannel = labelChannel;
                this.featuresChannel = featuresChannel;
                this.labelIsLong = labelIsLong;
                this.regression = regression;
            }

            /** Rough estimate: 8 bytes per buffered label plus the summed feature-vector estimates. */
            public long getEstimatedSize() {
                return 8L * (long) this.labels.size() + this.rowsSize;
            }

            /** Returns VARCHAR: the result is the model serialized by {@link ModelUtils#serialize}. */
            public Type getFinalType() {
                return VarcharType.VARCHAR;
            }

            /** Unsupported: training cannot be split across workers. */
            public Type getIntermediateType() {
                throw new UnsupportedOperationException("LEARN must run on a single machine");
            }

            /**
             * Appends one (label, features) pair per non-NULL row of {@code page}.
             *
             * <p>Rows whose label or features are NULL are skipped. Skipping only one side
             * would misalign {@code labels} with {@code rows}. Features are parsed before
             * either list is touched, so a parse failure leaves both lists consistent.
             */
            public void addInput(Page page) {
                Block labelBlock = page.getBlock(this.labelChannel);
                Block featuresBlock = page.getBlock(this.featuresChannel);
                // NOTE: assumes both blocks of the page share one position count.
                for (int position = 0; position < labelBlock.getPositionCount(); position++) {
                    if (labelBlock.isNull(position) || featuresBlock.isNull(position)) {
                        continue;
                    }
                    FeatureVector featureVector = ModelUtils.jsonToFeatures(VarcharType.VARCHAR.getSlice(featuresBlock, position));
                    // BIGINT labels are widened to double so both label kinds share one list.
                    double label = this.labelIsLong
                            ? (double) BigintType.BIGINT.getLong(labelBlock, position)
                            : DoubleType.DOUBLE.getDouble(labelBlock, position);
                    this.labels.add(label);
                    this.rows.add(featureVector);
                    this.rowsSize += featureVector.getEstimatedSize();
                }
            }

            /** Unsupported: there is no intermediate state to merge. */
            public void addIntermediate(Block block) {
                throw new UnsupportedOperationException("LEARN must run on a single machine");
            }

            /** Unsupported: there is no intermediate state to emit. */
            public Block evaluateIntermediate() {
                throw new UnsupportedOperationException("LEARN must run on a single machine");
            }

            /**
             * Trains the model over all buffered rows and returns a single-position VARCHAR
             * block holding the serialized model.
             *
             * <p>Features are normalized with {@link FeatureUnitNormalizer} before being fed to
             * an {@link SvmRegressor} or an {@link SvmClassifier}, depending on {@code regression}.
             */
            public Block evaluateFinal() {
                Dataset dataset = new Dataset(this.labels, this.rows);
                Model model;
                if (this.regression) {
                    model = new RegressorFeatureTransformer(new SvmRegressor(), new FeatureUnitNormalizer());
                }
                else {
                    model = new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureUnitNormalizer());
                }
                model.train(dataset);
                BlockBuilder builder = this.getFinalType().createBlockBuilder(new BlockBuilderStatus());
                this.getFinalType().writeSlice(builder, ModelUtils.serialize(model));
                return builder.build();
            }
        }
    }
}

