/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.ROC;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * ROC (receiver operating characteristic) evaluation for multi-class classifiers, computed
 * one-vs-all: column {@code i} of the labels/predictions is treated as an independent binary
 * problem in which class {@code i} is the positive class.
 *
 * <p>True/false positive counts are accumulated at {@code thresholdSteps + 1} evenly spaced
 * thresholds {@code 0, 1/thresholdSteps, ..., 1.0}. The number of classes is fixed by the first
 * call to {@link #eval(INDArray, INDArray)} (or the first non-empty {@link #merge}); later calls
 * must supply the same number of classes.
 */
public class ROCMultiClass
extends BaseEvaluation<ROCMultiClass> {
    /** Number of equal-width threshold intervals in [0, 1]; there are {@code thresholdSteps + 1} thresholds. */
    private final int thresholdSteps;
    /** Per-class number of positive-label examples seen so far; {@code null} until data has been seen. */
    private long[] countActualPositive;
    /** Per-class number of negative-label examples seen so far; {@code null} until data has been seen. */
    private long[] countActualNegative;
    /** Class index -> (threshold -> TP/FP counts), each inner map in ascending threshold (insertion) order. */
    private final Map<Integer, Map<Double, ROC.CountsForThreshold>> counts = new LinkedHashMap<Integer, Map<Double, ROC.CountsForThreshold>>();

    /**
     * Creates an empty multi-class ROC evaluation.
     *
     * @param thresholdSteps number of threshold intervals in [0, 1]; must be positive
     * @throws IllegalArgumentException if {@code thresholdSteps <= 0}
     */
    public ROCMultiClass(int thresholdSteps) {
        if (thresholdSteps <= 0) {
            // A step of 1.0/0 would yield NaN threshold keys; a negative value yields no thresholds at all.
            throw new IllegalArgumentException("thresholdSteps must be > 0, got " + thresholdSteps);
        }
        this.thresholdSteps = thresholdSteps;
    }

    /**
     * Accumulates statistics for one batch.
     *
     * <p>NOTE(review): assumes labels are 0/1 per column (one-hot). The negative count is derived
     * as column length minus the column's label sum.
     *
     * @param labels      rank-2 label array {@code [examples, classes]}, or rank 3 for time series
     * @param predictions predictions with the same layout as {@code labels}
     * @throws IllegalArgumentException on invalid shapes or a class count differing from earlier calls
     */
    @Override
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            // Delegated to the base-class time-series handling. Without this return, the rank check
            // below would reject the rank-3 input after it had already been processed.
            this.evalTimeSeries(labels, predictions);
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 arrays with labels.size(1) == predictions.size(1)");
        }
        if (this.countActualPositive == null) {
            this.initializeCounts(labels.size(1));
        }
        if (this.countActualPositive.length != labels.size(1)) {
            throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ") vs. expected number of label classes = " + this.countActualPositive.length);
        }
        for (int i = 0; i < this.countActualPositive.length; ++i) {
            this.evalClass(i, labels.getColumn(i), predictions.getColumn(i));
        }
    }

    /** Allocates zeroed per-class totals and one TP/FP counter per (class, threshold) pair. */
    private void initializeCounts(int numClasses) {
        this.countActualPositive = new long[numClasses];
        this.countActualNegative = new long[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            Map<Double, ROC.CountsForThreshold> thresholdMap = new LinkedHashMap<Double, ROC.CountsForThreshold>();
            for (int j = 0; j <= this.thresholdSteps; ++j) {
                double currThreshold = this.thresholdAt(j);
                thresholdMap.put(currThreshold, new ROC.CountsForThreshold(currThreshold));
            }
            this.counts.put(i, thresholdMap);
        }
    }

    /** Adds one class column's label totals and per-threshold TP/FP counts to the running state. */
    private void evalClass(int classIdx, INDArray positiveActualColumn, INDArray positivePredictedColumn) {
        long currBatchPositiveActualCount = positiveActualColumn.sumNumber().longValue();
        this.countActualPositive[classIdx] += currBatchPositiveActualCount;
        this.countActualNegative[classIdx] += (long) positiveActualColumn.length() - currBatchPositiveActualCount;

        // 1 - label marks examples not of this class. It does not depend on the threshold, so compute it once.
        INDArray negativeActualColumn = positiveActualColumn.rsub((Number) 1.0);
        Map<Double, ROC.CountsForThreshold> classCounts = this.counts.get(classIdx);

        for (int j = 0; j <= this.thresholdSteps; ++j) {
            double currThreshold = this.thresholdAt(j);
            Condition condGeq = Conditions.greaterThanOrEqual((Number) currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual((Number) currThreshold);
            // Binarize a copy of the predictions in two passes: set p >= t to 1, then set values <= t to 0.
            // At t = 1.0 the second pass also zeroes the 1s written by the first, so nothing is predicted positive.
            CompareAndSet op = new CompareAndSet(positivePredictedColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op) op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op) op);

            long truePositiveCount = predictedClass1.mul(positiveActualColumn).sumNumber().longValue();
            long falsePositiveCount = predictedClass1.mul(negativeActualColumn).sumNumber().longValue();

            ROC.CountsForThreshold thresholdCounts = classCounts.get(currThreshold);
            thresholdCounts.incrementTruePositive(truePositiveCount);
            thresholdCounts.incrementFalsePositive(falsePositiveCount);
        }
    }

    /**
     * Returns the threshold for step {@code j}. It is computed as {@code j * (1.0 / thresholdSteps)}
     * so that the Double map keys are produced identically everywhere they are looked up.
     */
    private double thresholdAt(int j) {
        double step = 1.0 / (double) this.thresholdSteps;
        return (double) j * step;
    }

    /**
     * Returns the ROC points for one class, one per threshold, in ascending threshold order.
     * TPR (FPR) is 0/0, i.e. NaN, if no positive (negative) examples of this class have been seen.
     *
     * @param classIdx class index in {@code [0, numClasses)}
     * @throws IllegalStateException    if no data has been evaluated
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    public List<ROC.ROCValue> getResults(int classIdx) {
        this.assertHasBeenFit(classIdx);
        Map<Double, ROC.CountsForThreshold> classCounts = this.counts.get(classIdx);
        List<ROC.ROCValue> out = new ArrayList<ROC.ROCValue>(classCounts.size());
        for (Map.Entry<Double, ROC.CountsForThreshold> entry : classCounts.entrySet()) {
            double t = entry.getKey();
            ROC.CountsForThreshold c = entry.getValue();
            double tpr = (double) c.getCountTruePositive() / (double) this.countActualPositive[classIdx];
            double fpr = (double) c.getCountFalsePositive() / (double) this.countActualNegative[classIdx];
            out.add(new ROC.ROCValue(t, tpr, fpr));
        }
        return out;
    }

    /**
     * Returns the ROC curve for one class as {@code [2][numThresholds]}: row 0 is FPR and row 1 is TPR,
     * in ascending threshold order.
     *
     * @param classIdx class index in {@code [0, numClasses)}
     */
    public double[][] getResultsAsArray(int classIdx) {
        this.assertHasBeenFit(classIdx);
        Map<Double, ROC.CountsForThreshold> classCounts = this.counts.get(classIdx);
        // Size from the actual map rather than thresholdSteps, which a merge may not match.
        double[][] out = new double[2][classCounts.size()];
        int i = 0;
        for (ROC.CountsForThreshold c : classCounts.values()) {
            out[0][i] = (double) c.getCountFalsePositive() / (double) this.countActualNegative[classIdx];
            out[1][i] = (double) c.getCountTruePositive() / (double) this.countActualPositive[classIdx];
            ++i;
        }
        return out;
    }

    /**
     * Approximates the area under the ROC curve for one class with the trapezoidal rule over
     * consecutive threshold points.
     *
     * @param classIdx class index in {@code [0, numClasses)}
     */
    public double calculateAUC(int classIdx) {
        this.assertHasBeenFit(classIdx);
        List<ROC.ROCValue> list = this.getResults(classIdx);
        double auc = 0.0;
        for (int i = 0; i < list.size() - 1; ++i) {
            ROC.ROCValue left = list.get(i);
            ROC.ROCValue right = list.get(i + 1);
            double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate());
            double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0;
            auc += deltaX * avg;
        }
        return auc;
    }

    /** Returns the unweighted mean of {@link #calculateAUC(int)} over all classes. */
    public double calculateAverageAUC() {
        this.assertHasBeenFit(0);
        double sum = 0.0;
        for (int i = 0; i < this.countActualPositive.length; ++i) {
            sum += this.calculateAUC(i);
        }
        return sum / (double) this.countActualPositive.length;
    }

    /**
     * Returns precision/recall points for one class, one per threshold, in ascending threshold order.
     * Precision is reported as 1.0 when there are no predicted positives. Recall is 1.0 when there
     * are no actual positives.
     *
     * @param classIndex class index in {@code [0, numClasses)}
     * @throws IllegalStateException    if no data has been evaluated
     * @throws IllegalArgumentException if {@code classIndex} is out of range
     */
    public List<ROC.PrecisionRecallPoint> getPrecisionRecallCurve(int classIndex) {
        this.assertHasBeenFit(classIndex);
        Map<Double, ROC.CountsForThreshold> classCounts = this.counts.get(classIndex);
        List<ROC.PrecisionRecallPoint> out = new ArrayList<ROC.PrecisionRecallPoint>(classCounts.size());
        for (ROC.CountsForThreshold c : classCounts.values()) {
            long tpCount = c.getCountTruePositive();
            long fpCount = c.getCountFalsePositive();
            double precision = tpCount == 0L && fpCount == 0L ? 1.0 : (double) tpCount / (double) (tpCount + fpCount);
            double recall = this.countActualPositive[classIndex] == 0L ? 1.0 : (double) tpCount / (double) this.countActualPositive[classIndex];
            out.add(new ROC.PrecisionRecallPoint(c.getThreshold(), precision, recall));
        }
        return out;
    }

    /**
     * Adds the statistics of {@code other} into this instance. If this instance is empty it becomes a
     * deep copy of {@code other}'s statistics. Otherwise both must have the same number of classes,
     * and {@code other} must contain every threshold this instance tracks.
     *
     * @throws IllegalArgumentException if the two instances are incompatible; this instance is left unchanged
     */
    @Override
    public void merge(ROCMultiClass other) {
        if (other.countActualPositive == null) {
            return;
        }
        if (this.countActualPositive == null) {
            this.countActualPositive = Arrays.copyOf(other.countActualPositive, other.countActualPositive.length);
            this.countActualNegative = Arrays.copyOf(other.countActualNegative, other.countActualNegative.length);
            for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> e : other.counts.entrySet()) {
                Map<Double, ROC.CountsForThreshold> mClone = new LinkedHashMap<Double, ROC.CountsForThreshold>();
                for (Map.Entry<Double, ROC.CountsForThreshold> e2 : e.getValue().entrySet()) {
                    mClone.put(e2.getKey(), e2.getValue().clone());
                }
                this.counts.put(e.getKey(), mClone);
            }
            return;
        }
        // Validate everything before mutating, so that a failed merge leaves this instance intact.
        this.assertMergeCompatible(other);
        for (int i = 0; i < this.countActualPositive.length; ++i) {
            this.countActualPositive[i] += other.countActualPositive[i];
            this.countActualNegative[i] += other.countActualNegative[i];
        }
        for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> e : this.counts.entrySet()) {
            Map<Double, ROC.CountsForThreshold> otherMap = other.counts.get(e.getKey());
            for (Map.Entry<Double, ROC.CountsForThreshold> t : e.getValue().entrySet()) {
                ROC.CountsForThreshold otherC = otherMap.get(t.getKey());
                t.getValue().incrementTruePositive(otherC.getCountTruePositive());
                t.getValue().incrementFalsePositive(otherC.getCountFalsePositive());
            }
        }
    }

    /** Throws if {@code other} (already known to be non-empty) cannot be merged into this non-empty instance. */
    private void assertMergeCompatible(ROCMultiClass other) {
        if (other.countActualPositive.length != this.countActualPositive.length) {
            throw new IllegalArgumentException("Cannot merge ROCMultiClass instances with different numbers of classes: " + this.countActualPositive.length + " vs. " + other.countActualPositive.length);
        }
        for (Map.Entry<Integer, Map<Double, ROC.CountsForThreshold>> e : this.counts.entrySet()) {
            Map<Double, ROC.CountsForThreshold> otherMap = other.counts.get(e.getKey());
            if (otherMap == null || !otherMap.keySet().containsAll(e.getValue().keySet())) {
                throw new IllegalArgumentException("Cannot merge ROCMultiClass instances with incompatible thresholds (thresholdSteps " + this.thresholdSteps + " vs. " + other.thresholdSteps + ")");
            }
        }
    }

    /** Throws unless data has been collected and {@code classIdx} is a valid class index. */
    private void assertHasBeenFit(int classIdx) {
        if (this.countActualPositive == null) {
            throw new IllegalStateException("Cannot get results: no data has been collected");
        }
        if (classIdx < 0 || classIdx >= this.countActualPositive.length) {
            throw new IllegalArgumentException("Invalid class index (" + classIdx + "): must be in range 0 to numClasses = " + this.countActualPositive.length);
        }
    }

    /** @return number of threshold intervals configured at construction */
    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    /** @return the internal per-class positive counts (not a copy), or {@code null} before any data */
    public long[] getCountActualPositive() {
        return this.countActualPositive;
    }

    /** @return the internal per-class negative counts (not a copy), or {@code null} before any data */
    public long[] getCountActualNegative() {
        return this.countActualNegative;
    }

    /** @return the internal class -> threshold -> counts map (not a copy) */
    public Map<Integer, Map<Double, ROC.CountsForThreshold>> getCounts() {
        return this.counts;
    }
}

