/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.ml.ClassifierFeatureTransformer;
import com.facebook.presto.ml.Dataset;
import com.facebook.presto.ml.FeatureVector;
import com.facebook.presto.ml.FeatureVectorUnitNormalizer;
import com.facebook.presto.ml.Model;
import com.facebook.presto.ml.ModelUtils;
import com.facebook.presto.ml.RegressorFeatureTransformer;
import com.facebook.presto.ml.SvmClassifier;
import com.facebook.presto.ml.SvmRegressor;
import com.facebook.presto.ml.type.RegressorType;
import com.facebook.presto.operator.Page;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.BlockCursor;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;

public class LearnAggregation
implements AggregationFunction {
    private final Type modelType;
    private final Type labelType;

    public LearnAggregation(Type modelType, Type labelType) {
        this.modelType = modelType;
        this.labelType = labelType;
    }

    public List<Type> getParameterTypes() {
        return ImmutableList.of((Object)this.labelType, (Object)VarcharType.VARCHAR);
    }

    public Type getFinalType() {
        return this.modelType;
    }

    public Type getIntermediateType() {
        throw new UnsupportedOperationException("LEARN must run on a single machine");
    }

    public boolean isDecomposable() {
        return false;
    }

    public Accumulator createAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeight, double confidence, int ... argumentChannels) {
        Preconditions.checkArgument((!maskChannel.isPresent() ? 1 : 0) != 0, (Object)"masking is not supported");
        Preconditions.checkArgument((confidence == 1.0 ? 1 : 0) != 0, (Object)"approximation is not supported");
        Preconditions.checkArgument((!sampleWeight.isPresent() ? 1 : 0) != 0, (Object)"sample weight is not supported");
        return new LearnAccumulator(argumentChannels[0], argumentChannels[1], this.labelType == BigintType.BIGINT, this.modelType == RegressorType.REGRESSOR);
    }

    public Accumulator createIntermediateAggregation(double confidence) {
        throw new UnsupportedOperationException("LEARN must run on a single machine");
    }

    public GroupedAccumulator createGroupedAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeight, double confidence, int ... argumentChannels) {
        throw new UnsupportedOperationException("LEARN doesn't support GROUP BY");
    }

    public GroupedAccumulator createGroupedIntermediateAggregation(double confidence) {
        throw new UnsupportedOperationException("LEARN doesn't support GROUP BY");
    }

    public static class LearnAccumulator
    implements Accumulator {
        private final int labelChannel;
        private final int featuresChannel;
        private final boolean labelIsLong;
        private final boolean regression;
        private final List<Double> labels = new ArrayList<Double>();
        private final List<FeatureVector> rows = new ArrayList<FeatureVector>();
        private long rowsSize;

        public LearnAccumulator(int labelChannel, int featuresChannel, boolean labelIsLong, boolean regression) {
            this.labelChannel = labelChannel;
            this.featuresChannel = featuresChannel;
            this.labelIsLong = labelIsLong;
            this.regression = regression;
        }

        public long getEstimatedSize() {
            return 8L * (long)this.labels.size() + this.rowsSize;
        }

        public Type getFinalType() {
            return VarcharType.VARCHAR;
        }

        public Type getIntermediateType() {
            throw new UnsupportedOperationException("LEARN must run on a single machine");
        }

        public void addInput(Page page) {
            BlockCursor cursor = page.getBlock(this.labelChannel).cursor();
            while (cursor.advanceNextPosition()) {
                if (this.labelIsLong) {
                    this.labels.add(Double.valueOf(cursor.getLong()));
                    continue;
                }
                this.labels.add(cursor.getDouble());
            }
            cursor = page.getBlock(this.featuresChannel).cursor();
            while (cursor.advanceNextPosition()) {
                FeatureVector featureVector = ModelUtils.jsonToFeatures(cursor.getSlice());
                this.rowsSize += featureVector.getEstimatedSize();
                this.rows.add(featureVector);
            }
        }

        public void addIntermediate(Block block) {
            throw new UnsupportedOperationException("LEARN must run on a single machine");
        }

        public Block evaluateIntermediate() {
            throw new UnsupportedOperationException("LEARN must run on a single machine");
        }

        public Block evaluateFinal() {
            Dataset dataset = new Dataset(this.labels, this.rows);
            Model model = this.regression ? new RegressorFeatureTransformer(new SvmRegressor(), new FeatureVectorUnitNormalizer()) : new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer());
            model.train(dataset);
            BlockBuilder builder = this.getFinalType().createBlockBuilder(new BlockBuilderStatus());
            builder.appendSlice(ModelUtils.serialize(model));
            return builder.build();
        }
    }
}

