/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.ml;

import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.CombineFunction;
import com.facebook.presto.operator.aggregation.InputFunction;
import com.facebook.presto.operator.aggregation.OutputFunction;
import com.facebook.presto.operator.aggregation.state.AccumulatorState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.SqlType;
import com.google.common.base.Preconditions;
import java.util.Locale;

/**
 * SQL aggregation {@code evaluate_classifier_predictions(truth bigint, prediction bigint) -> varchar}.
 *
 * <p>Builds a binary confusion matrix from (truth, prediction) pairs, where both labels must be
 * 0 or 1. The result is a three-line report of accuracy, precision and recall. Each line shows
 * numerator/denominator and a percentage with two decimals.
 */
@AggregationFunction("evaluate_classifier_predictions")
public final class EvaluateClassifierPredictionsAggregation {
    /** Not instantiable: this class only hosts the static aggregation callbacks. */
    private EvaluateClassifierPredictionsAggregation() {
    }

    /**
     * Adds one (truth, prediction) pair to the confusion-matrix counters in {@code state}.
     *
     * @param state accumulator holding the four confusion-matrix counts
     * @param truth the actual label; must be 0 or 1
     * @param prediction the predicted label; must be 0 or 1
     * @throws IllegalArgumentException if {@code prediction} or {@code truth} is not 0 or 1
     */
    @InputFunction
    public static void input(EvaluateClassifierPredictionsState state, @SqlType("bigint") long truth, @SqlType("bigint") long prediction) {
        // Prediction is validated first, matching the original check order.
        Preconditions.checkArgument(isBinaryLabel(prediction),
                "evaluate_classifier_predictions only supports binary classifiers: prediction must be 0 or 1, but was %s", prediction);
        Preconditions.checkArgument(isBinaryLabel(truth),
                "evaluate_classifier_predictions only supports binary classifiers: truth must be 0 or 1, but was %s", truth);

        if (truth == 1L) {
            if (prediction == 1L) {
                state.setTruePositives(state.getTruePositives() + 1L);
            } else {
                state.setFalseNegatives(state.getFalseNegatives() + 1L);
            }
        } else if (prediction == 1L) {
            state.setFalsePositives(state.getFalsePositives() + 1L);
        } else {
            state.setTrueNegatives(state.getTrueNegatives() + 1L);
        }
    }

    /**
     * Merges the counts of {@code scratchState} into {@code state} by summing each counter.
     *
     * @param state accumulator that receives the merged counts
     * @param scratchState accumulator whose counts are added; it is only read
     */
    @CombineFunction
    public static void combine(EvaluateClassifierPredictionsState state, EvaluateClassifierPredictionsState scratchState) {
        state.setTruePositives(state.getTruePositives() + scratchState.getTruePositives());
        state.setFalsePositives(state.getFalsePositives() + scratchState.getFalsePositives());
        state.setTrueNegatives(state.getTrueNegatives() + scratchState.getTrueNegatives());
        state.setFalseNegatives(state.getFalseNegatives() + scratchState.getFalseNegatives());
    }

    /**
     * Writes the accuracy / precision / recall report as a single VARCHAR value to {@code out}.
     *
     * <p>The lines are separated by {@code '\n'}, with no trailing newline. A zero denominator
     * (for example, no positive predictions for precision) makes the percentage a floating-point
     * division by zero; 0/0 is rendered as {@code NaN}.
     *
     * @param state accumulator holding the final confusion-matrix counts
     * @param out block builder that receives the report string
     */
    @OutputFunction("varchar")
    public static void output(EvaluateClassifierPredictionsState state, BlockBuilder out) {
        long truePositives = state.getTruePositives();
        long falsePositives = state.getFalsePositives();
        long trueNegatives = state.getTrueNegatives();
        long falseNegatives = state.getFalseNegatives();

        long correct = truePositives + trueNegatives;
        long total = truePositives + trueNegatives + falsePositives + falseNegatives;

        StringBuilder report = new StringBuilder();
        appendRatio(report, "Accuracy", correct, total);
        report.append('\n');
        appendRatio(report, "Precision", truePositives, truePositives + falsePositives);
        report.append('\n');
        appendRatio(report, "Recall", truePositives, truePositives + falseNegatives);
        VarcharType.VARCHAR.writeString(out, report.toString());
    }

    /** Returns true when {@code label} is one of the two supported binary labels, 0 or 1. */
    private static boolean isBinaryLabel(long label) {
        return label == 0L || label == 1L;
    }

    /**
     * Appends {@code "<label>: <numerator>/<denominator> (<pct>%)"} to {@code sb}.
     * The percentage has two decimals and is always formatted with {@link Locale#US}.
     */
    private static void appendRatio(StringBuilder sb, String label, long numerator, long denominator) {
        double percent = 100.0 * (double) numerator / (double) denominator;
        sb.append(String.format(Locale.US, "%s: %d/%d (%.2f%%)", label, numerator, denominator, percent));
    }

    /** Accumulator state: the four cells of a binary confusion matrix. */
    public interface EvaluateClassifierPredictionsState
            extends AccumulatorState {
        long getTruePositives();

        void setTruePositives(long value);

        long getFalsePositives();

        void setFalsePositives(long value);

        long getTrueNegatives();

        void setTrueNegatives(long value);

        long getFalseNegatives();

        void setFalseNegatives(long value);
    }
}

