/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.operator.Page;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.BlockCursor;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import java.util.List;

/**
 * Aggregation that evaluates a binary classifier by counting true/false positives and
 * negatives over (label, prediction) pairs, both BIGINT and restricted to the values 0 and 1.
 *
 * <p>The intermediate state is a VARCHAR slice holding four longs (see
 * {@link #createIntermediate}); the final value is a human-readable VARCHAR summary produced
 * by {@link #formatResults}. Masking, sample weights and approximation are not supported.
 */
public class EvaluateClassifierPredictionsAggregation
implements AggregationFunction {
    private static final String BINARY_ONLY_MESSAGE = "evaluate_predictions only supports binary classifiers";

    // Byte offsets of the counters inside the intermediate slice.
    private static final int TRUE_POSITIVES_OFFSET = 0;
    private static final int FALSE_POSITIVES_OFFSET = 8;
    private static final int TRUE_NEGATIVES_OFFSET = 16;
    private static final int FALSE_NEGATIVES_OFFSET = 24;
    private static final int INTERMEDIATE_SIZE = 32;

    /**
     * Returns the argument types: (label BIGINT, prediction BIGINT).
     */
    public List<Type> getParameterTypes() {
        // Explicit type witness: without it the list would be typed ImmutableList<BigintType>.
        return ImmutableList.<Type>of(BigintType.BIGINT, BigintType.BIGINT);
    }

    public Type getFinalType() {
        return VarcharType.VARCHAR;
    }

    public Type getIntermediateType() {
        return VarcharType.VARCHAR;
    }

    public boolean isDecomposable() {
        return true;
    }

    /**
     * Creates a non-grouped accumulator reading labels from {@code argumentChannels[0]} and
     * predictions from {@code argumentChannels[1]}.
     *
     * @throws IllegalArgumentException if a mask channel or sample weight is supplied, or if
     *         {@code confidence} is not exactly 1.0
     */
    public Accumulator createAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeight, double confidence, int ... argumentChannels) {
        checkUnsupportedOptions(maskChannel, sampleWeight, confidence);
        return new EvaluatePredictionsAccumulator(argumentChannels[0], argumentChannels[1]);
    }

    /**
     * Creates a non-grouped accumulator that only merges intermediate states (no input channels).
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0
     */
    public Accumulator createIntermediateAggregation(double confidence) {
        Preconditions.checkArgument(confidence == 1.0, "approximation is not supported");
        return new EvaluatePredictionsAccumulator(-1, -1);
    }

    /**
     * Creates a grouped accumulator reading labels from {@code argumentChannels[0]} and
     * predictions from {@code argumentChannels[1]}.
     *
     * @throws IllegalArgumentException if a mask channel or sample weight is supplied, or if
     *         {@code confidence} is not exactly 1.0
     */
    public GroupedAccumulator createGroupedAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeight, double confidence, int ... argumentChannels) {
        checkUnsupportedOptions(maskChannel, sampleWeight, confidence);
        return new EvaluatePredictionsGroupedAccumulator(argumentChannels[0], argumentChannels[1]);
    }

    /**
     * Creates a grouped accumulator that only merges intermediate states (no input channels).
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0
     */
    public GroupedAccumulator createGroupedIntermediateAggregation(double confidence) {
        Preconditions.checkArgument(confidence == 1.0, "approximation is not supported");
        return new EvaluatePredictionsGroupedAccumulator(-1, -1);
    }

    /**
     * Formats accuracy, precision and recall as a single line, e.g.
     * {@code "Accuracy: 3/4 (75.00%), Precision: 1/2 (50.00%), Recall: 1/1 (100.00%)"}.
     * A zero denominator yields {@code NaN} in the percentage, since the division is done in
     * floating point.
     */
    public static String formatResults(long truePositives, long falsePositives, long trueNegatives, long falseNegatives) {
        StringBuilder sb = new StringBuilder();
        long correct = trueNegatives + truePositives;
        long total = truePositives + trueNegatives + falsePositives + falseNegatives;
        long predictedPositive = truePositives + falsePositives;
        long actualPositive = truePositives + falseNegatives;
        sb.append(String.format("Accuracy: %d/%d (%.2f%%), ", correct, total, 100.0 * (double)correct / (double)total));
        sb.append(String.format("Precision: %d/%d (%.2f%%), ", truePositives, predictedPositive, 100.0 * (double)truePositives / (double)predictedPositive));
        sb.append(String.format("Recall: %d/%d (%.2f%%)", truePositives, actualPositive, 100.0 * (double)truePositives / (double)actualPositive));
        return sb.toString();
    }

    /**
     * Serializes the four counters into a 32-byte intermediate slice, in the order
     * true positives, false positives, true negatives, false negatives (8 bytes each).
     */
    public static Slice createIntermediate(long truePositives, long falsePositives, long trueNegatives, long falseNegatives) {
        Slice slice = Slices.allocate(INTERMEDIATE_SIZE);
        slice.setLong(TRUE_POSITIVES_OFFSET, truePositives);
        slice.setLong(FALSE_POSITIVES_OFFSET, falsePositives);
        slice.setLong(TRUE_NEGATIVES_OFFSET, trueNegatives);
        slice.setLong(FALSE_NEGATIVES_OFFSET, falseNegatives);
        return slice;
    }

    /** Rejects masking, sample weights and approximate (confidence != 1.0) evaluation. */
    private static void checkUnsupportedOptions(Optional<Integer> maskChannel, Optional<Integer> sampleWeight, double confidence) {
        Preconditions.checkArgument(!maskChannel.isPresent(), "masking is not supported");
        Preconditions.checkArgument(confidence == 1.0, "approximation is not supported");
        Preconditions.checkArgument(!sampleWeight.isPresent(), "sample weight is not supported");
    }

    /** Ensures both the label and the prediction are 0 or 1. */
    private static void checkBinary(long label, long predicted) {
        Preconditions.checkArgument(predicted == 1L || predicted == 0L, BINARY_ONLY_MESSAGE);
        Preconditions.checkArgument(label == 1L || label == 0L, BINARY_ONLY_MESSAGE);
    }

    /**
     * Grouped accumulator: keeps one set of four counters per group id.
     */
    public static class EvaluatePredictionsGroupedAccumulator
    implements GroupedAccumulator {
        private final int labelChannel;
        private final int predictionChannel;
        private final LongBigArray truePositives = new LongBigArray();
        private final LongBigArray falsePositives = new LongBigArray();
        private final LongBigArray trueNegatives = new LongBigArray();
        private final LongBigArray falseNegatives = new LongBigArray();

        /**
         * @param labelChannel channel of the label column, or -1 for intermediate-only use
         * @param predictionChannel channel of the prediction column, or -1 for intermediate-only use
         */
        public EvaluatePredictionsGroupedAccumulator(int labelChannel, int predictionChannel) {
            this.labelChannel = labelChannel;
            this.predictionChannel = predictionChannel;
        }

        public Type getFinalType() {
            return VarcharType.VARCHAR;
        }

        public Type getIntermediateType() {
            return VarcharType.VARCHAR;
        }

        public long getEstimatedSize() {
            return this.truePositives.sizeOf() + this.falsePositives.sizeOf() + this.trueNegatives.sizeOf() + this.falseNegatives.sizeOf();
        }

        /** Grows every counter array so that all group ids in the block are addressable. */
        private void ensureCapacity(GroupByIdBlock groupIdsBlock) {
            long groupCount = groupIdsBlock.getGroupCount();
            this.truePositives.ensureCapacity(groupCount);
            this.falsePositives.ensureCapacity(groupCount);
            this.trueNegatives.ensureCapacity(groupCount);
            this.falseNegatives.ensureCapacity(groupCount);
        }

        /**
         * Classifies each (label, prediction) row and increments the matching counter of the
         * row's group.
         *
         * @throws IllegalArgumentException if a label or prediction is not 0 or 1
         */
        public void addInput(GroupByIdBlock groupIdsBlock, Page page) {
            this.ensureCapacity(groupIdsBlock);
            BlockCursor labelCursor = page.getBlock(this.labelChannel).cursor();
            BlockCursor predictionCursor = page.getBlock(this.predictionChannel).cursor();
            for (int position = 0; position < groupIdsBlock.getPositionCount(); ++position) {
                long groupId = groupIdsBlock.getGroupId(position);
                Preconditions.checkState(labelCursor.advanceNextPosition());
                Preconditions.checkState(predictionCursor.advanceNextPosition());
                long predicted = predictionCursor.getLong();
                long label = labelCursor.getLong();
                checkBinary(label, predicted);
                if (label == 1L) {
                    if (predicted == 1L) {
                        this.truePositives.increment(groupId);
                    }
                    else {
                        this.falseNegatives.increment(groupId);
                    }
                }
                else if (predicted == 0L) {
                    this.trueNegatives.increment(groupId);
                }
                else {
                    this.falsePositives.increment(groupId);
                }
            }
        }

        /**
         * Merges one intermediate slice per position into the counters of that position's group.
         * The block must contain exactly as many positions as {@code groupIdsBlock}.
         */
        public void addIntermediate(GroupByIdBlock groupIdsBlock, Block block) {
            this.ensureCapacity(groupIdsBlock);
            BlockCursor cursor = block.cursor();
            for (int position = 0; position < groupIdsBlock.getPositionCount(); ++position) {
                Preconditions.checkState(cursor.advanceNextPosition());
                long groupId = groupIdsBlock.getGroupId(position);
                Slice slice = cursor.getSlice();
                this.truePositives.add(groupId, slice.getLong(TRUE_POSITIVES_OFFSET));
                this.falsePositives.add(groupId, slice.getLong(FALSE_POSITIVES_OFFSET));
                this.trueNegatives.add(groupId, slice.getLong(TRUE_NEGATIVES_OFFSET));
                this.falseNegatives.add(groupId, slice.getLong(FALSE_NEGATIVES_OFFSET));
            }
            Preconditions.checkState(!cursor.advanceNextPosition());
        }

        /** Appends the serialized counters of {@code groupId} to {@code output}. */
        public void evaluateIntermediate(int groupId, BlockBuilder output) {
            // Only append: the caller owns the builder and builds it once all groups are written.
            output.appendSlice(createIntermediate(this.truePositives.get((long)groupId), this.falsePositives.get((long)groupId), this.trueNegatives.get((long)groupId), this.falseNegatives.get((long)groupId)));
        }

        /** Appends the formatted accuracy/precision/recall summary of {@code groupId} to {@code output}. */
        public void evaluateFinal(int groupId, BlockBuilder output) {
            output.appendSlice(Slices.utf8Slice(formatResults(this.truePositives.get((long)groupId), this.falsePositives.get((long)groupId), this.trueNegatives.get((long)groupId), this.falseNegatives.get((long)groupId))));
        }
    }

    /**
     * Non-grouped accumulator: keeps a single set of four counters.
     */
    public static class EvaluatePredictionsAccumulator
    implements Accumulator {
        private final int labelChannel;
        private final int predictionChannel;
        private long truePositives;
        private long falsePositives;
        private long trueNegatives;
        private long falseNegatives;

        /**
         * @param labelChannel channel of the label column, or -1 for intermediate-only use
         * @param predictionChannel channel of the prediction column, or -1 for intermediate-only use
         */
        public EvaluatePredictionsAccumulator(int labelChannel, int predictionChannel) {
            this.labelChannel = labelChannel;
            this.predictionChannel = predictionChannel;
        }

        public long getEstimatedSize() {
            return 0L;
        }

        public Type getFinalType() {
            return VarcharType.VARCHAR;
        }

        public Type getIntermediateType() {
            return VarcharType.VARCHAR;
        }

        /**
         * Classifies each (label, prediction) row of the page and increments the matching counter.
         *
         * @throws IllegalArgumentException if a label or prediction is not 0 or 1
         */
        public void addInput(Page page) {
            BlockCursor labelCursor = page.getBlock(this.labelChannel).cursor();
            BlockCursor predictionCursor = page.getBlock(this.predictionChannel).cursor();
            while (labelCursor.advanceNextPosition()) {
                Preconditions.checkState(predictionCursor.advanceNextPosition());
                long predicted = predictionCursor.getLong();
                long label = labelCursor.getLong();
                checkBinary(label, predicted);
                if (label == 1L) {
                    if (predicted == 1L) {
                        ++this.truePositives;
                    }
                    else {
                        ++this.falseNegatives;
                    }
                }
                else if (predicted == 0L) {
                    ++this.trueNegatives;
                }
                else {
                    ++this.falsePositives;
                }
            }
        }

        /** Merges a block containing exactly one intermediate slice into the counters. */
        public void addIntermediate(Block block) {
            BlockCursor cursor = block.cursor();
            Preconditions.checkState(cursor.advanceNextPosition());
            Slice slice = cursor.getSlice();
            Preconditions.checkState(!cursor.advanceNextPosition());
            this.truePositives += slice.getLong(TRUE_POSITIVES_OFFSET);
            this.falsePositives += slice.getLong(FALSE_POSITIVES_OFFSET);
            this.trueNegatives += slice.getLong(TRUE_NEGATIVES_OFFSET);
            this.falseNegatives += slice.getLong(FALSE_NEGATIVES_OFFSET);
        }

        /** Returns a single-position block holding the serialized counters. */
        public Block evaluateIntermediate() {
            BlockBuilder builder = this.getIntermediateType().createBlockBuilder(new BlockBuilderStatus());
            return builder.appendSlice(createIntermediate(this.truePositives, this.falsePositives, this.trueNegatives, this.falseNegatives)).build();
        }

        /**
         * Returns a single-position block with the accuracy/precision/recall summary, formatted
         * by the same helper as the grouped accumulator so both paths produce identical text.
         */
        public Block evaluateFinal() {
            BlockBuilder builder = this.getFinalType().createBlockBuilder(new BlockBuilderStatus());
            builder.appendSlice(Slices.utf8Slice(formatResults(this.truePositives, this.falsePositives, this.trueNegatives, this.falseNegatives)));
            return builder.build();
        }
    }
}

