/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.beans.ConstructorProperties;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Accumulates the counts needed to build a ROC (receiver operating characteristic) curve for a
 * binary classifier, and derives true/false positive rates and the area under the curve from them.
 *
 * <p>The curve is sampled at {@code thresholdSteps + 1} evenly spaced thresholds
 * {@code i / thresholdSteps} for {@code i = 0..thresholdSteps}. Each call to
 * {@link #eval(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)} adds
 * to the running totals, so the evaluation can be fed one minibatch at a time.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class ROC implements Serializable {
    /** Number of threshold intervals; the curve has {@code thresholdSteps + 1} points. */
    private final int thresholdSteps;
    /** Running sum of the positive-class label column over all evaluated examples. */
    private long countActualPositive;
    /** Running sum of the negative-class label column over all evaluated examples. */
    private long countActualNegative;
    /** Per-threshold counters, keyed by threshold and kept in increasing-threshold insertion order. */
    private final Map<Double, CountsForThreshold> counts = new LinkedHashMap<Double, CountsForThreshold>();

    /**
     * Creates an evaluator with one counter per threshold {@code i / thresholdSteps},
     * {@code i = 0..thresholdSteps}.
     *
     * @param thresholdSteps number of threshold intervals; must be at least 1
     * @throws IllegalArgumentException if {@code thresholdSteps} is less than 1 (such a value would
     *         otherwise produce NaN thresholds or an empty curve)
     */
    public ROC(int thresholdSteps) {
        if (thresholdSteps < 1) {
            throw new IllegalArgumentException("Invalid thresholdSteps: must be >= 1, got " + thresholdSteps);
        }
        this.thresholdSteps = thresholdSteps;
        double step = 1.0 / (double)thresholdSteps;
        for (int i = 0; i <= thresholdSteps; ++i) {
            double currThreshold = (double)i * step;
            this.counts.put(currThreshold, new CountsForThreshold(currThreshold));
        }
    }

    /**
     * Adds a batch of labels and predictions to the running counts.
     *
     * <p>If both arrays are rank 3 the call is forwarded to
     * {@link #evalTimeSeries(INDArray, INDArray)} and nothing else happens here. Otherwise both
     * arrays must be at most rank 2 with equal {@code size(1)}. With a single column, the column is
     * read as the positive class and {@code 1 - labels} as the negative class; otherwise column 1 is
     * read as the positive class and column 0 as the negative class. The column count itself is not
     * validated beyond the equality check.
     *
     * @param labels      actual labels
     * @param predictions predicted values, compared against each threshold
     * @throws IllegalArgumentException if the ranks are unsupported or the column counts differ
     */
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            this.evalTimeSeries(labels, predictions);
            // The time series path reshapes to 2d and calls back into eval(); it must not fall
            // through to the rank check below, which would reject the rank 3 arrays.
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
        }

        INDArray positiveActualClassColumn;
        INDArray negativeActualClassColumn;
        INDArray positivePredictedClassColumn;
        if (labels.size(1) == 1) {
            // Single output column: the negative labels are the complement of the positive ones.
            positiveActualClassColumn = labels;
            negativeActualClassColumn = labels.rsub((Number)1.0);
            positivePredictedClassColumn = predictions;
        } else {
            positiveActualClassColumn = labels.getColumn(1);
            negativeActualClassColumn = labels.getColumn(0);
            positivePredictedClassColumn = predictions.getColumn(1);
        }

        // Read the sums as long: the accumulators are long, and narrowing to int first would
        // cap the contribution of very large batches.
        this.countActualPositive += positiveActualClassColumn.sumNumber().longValue();
        this.countActualNegative += negativeActualClassColumn.sumNumber().longValue();

        // Walk the counters created by the constructor (insertion order == increasing threshold)
        // instead of recomputing each threshold and looking it up by Double key.
        for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
            double currThreshold = entry.getKey();
            CountsForThreshold thresholdCounts = entry.getValue();

            // Binarize a copy of the predictions: first write 1.0 where the >= condition holds,
            // then write 0.0 where the <= condition holds.
            // NOTE(review): the second pass runs on the output of the first, so at threshold 1.0
            // it presumably also matches the freshly written 1.0 values - confirm this is intended.
            Condition condGeq = Conditions.greaterThanOrEqual((Number)currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual((Number)currThreshold);
            CompareAndSet op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op)op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op)op);

            INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn);
            INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn);
            thresholdCounts.incrementTruePositive(isTruePositive.sumNumber().longValue());
            thresholdCounts.incrementFalsePositive(isFalsePositive.sumNumber().longValue());
        }
    }

    /**
     * Adds a batch of rank 3 (time series) labels and predictions, without a mask.
     *
     * @param labels      rank 3 label array
     * @param predictions rank 3 prediction array
     * @throws IllegalArgumentException if either array is not rank 3
     */
    public void evalTimeSeries(INDArray labels, INDArray predictions) {
        this.evalTimeSeries(labels, predictions, null);
    }

    /**
     * Adds a batch of rank 3 (time series) labels and predictions, optionally restricted by a mask.
     *
     * <p>Both arrays are copied in 'f' order, flattened to 2d and passed to
     * {@link #eval(INDArray, INDArray)}. When a mask is given, only the rows whose flattened mask
     * value is exactly {@code 1.0f} are kept.
     *
     * @param labels     rank 3 label array
     * @param predicted  rank 3 prediction array
     * @param outputMask mask array, or {@code null} to use every row
     * @throws IllegalArgumentException if either array is not rank 3
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask) {
        if (labels.rank() != 3 || predicted.rank() != 3) {
            throw new IllegalArgumentException("Invalid data: expect rank 3 arrays. Got arrays with shapes labels=" + Arrays.toString(labels.shape()) + ", predictions=" + Arrays.toString(predicted.shape()));
        }
        INDArray labels2d = ROC.reshape2d(labels.dup('f'));
        INDArray predicted2d = ROC.reshape2d(predicted.dup('f'));
        if (outputMask == null) {
            this.eval(labels2d, predicted2d);
            return;
        }

        // Collect the indices of the rows the mask marks as present.
        INDArray oneDMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(outputMask);
        float[] maskValues = oneDMask.dup().data().asFloat();
        int[] rowsToPull = new int[maskValues.length];
        int usedCount = 0;
        for (int i = 0; i < maskValues.length; ++i) {
            if (maskValues[i] == 1.0f) {
                rowsToPull[usedCount++] = i;
            }
        }
        rowsToPull = Arrays.copyOfRange(rowsToPull, 0, usedCount);

        labels2d = Nd4j.pullRows((INDArray)labels2d, (int)1, (int[])rowsToPull);
        predicted2d = Nd4j.pullRows((INDArray)predicted2d, (int)1, (int[])rowsToPull);
        this.eval(labels2d, predicted2d);
    }

    /**
     * Flattens a rank 3 array to 2d, with the original dimension 1 as the column dimension.
     *
     * <p>Arrays whose dimension 0 or dimension 2 has size 1 are handled through a tensor view; the
     * general case permutes to {@code [0, 2, 1]} and reshapes in 'f' order to
     * {@code [shape[0] * shape[2], shape[1]]}.
     *
     * @param labels rank 3 array (the name is historical; predictions are passed here too)
     * @return the 2d form of the input
     */
    private static INDArray reshape2d(INDArray labels) {
        int[] labelsShape = labels.shape();
        if (labelsShape[0] == 1) {
            return labels.tensorAlongDimension(0, new int[]{1, 2}).permutei(new int[]{1, 0});
        }
        if (labelsShape[2] == 1) {
            return labels.tensorAlongDimension(0, new int[]{1, 0});
        }
        INDArray permuted = labels.permute(new int[]{0, 2, 1});
        return permuted.reshape('f', labelsShape[0] * labelsShape[2], labelsShape[1]);
    }

    /**
     * Returns one curve point per threshold, in increasing threshold order.
     *
     * <p>Rates are the accumulated counts divided by the accumulated actual positive / negative
     * totals; they are NaN for a class with a zero total and a zero count.
     *
     * @return list of (threshold, true positive rate, false positive rate) values
     */
    public List<ROCValue> getResults() {
        List<ROCValue> out = new ArrayList<ROCValue>(this.counts.size());
        for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
            double t = entry.getKey();
            CountsForThreshold c = entry.getValue();
            double tpr = (double)c.getCountTruePositive() / (double)this.countActualPositive;
            double fpr = (double)c.getCountFalsePositive() / (double)this.countActualNegative;
            out.add(new ROCValue(t, tpr, fpr));
        }
        return out;
    }

    /**
     * Returns the curve as a {@code [2][thresholdSteps + 1]} array, in increasing threshold order.
     *
     * @return array whose row 0 holds the false positive rates and row 1 the true positive rates
     */
    public double[][] getResultsAsArray() {
        double[][] out = new double[2][this.thresholdSteps + 1];
        int i = 0;
        for (CountsForThreshold c : this.counts.values()) {
            out[0][i] = (double)c.getCountFalsePositive() / (double)this.countActualNegative;
            out[1][i] = (double)c.getCountTruePositive() / (double)this.countActualPositive;
            ++i;
        }
        return out;
    }

    /**
     * Computes the area under the curve with the trapezoidal rule over consecutive points of
     * {@link #getResults()}: for each pair, the absolute difference in false positive rate times
     * the mean of the two true positive rates.
     *
     * @return the summed trapezoid areas
     */
    public double calculateAUC() {
        List<ROCValue> list = this.getResults();
        double auc = 0.0;
        for (int i = 0; i < list.size() - 1; ++i) {
            ROCValue left = list.get(i);
            ROCValue right = list.get(i + 1);
            double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate());
            double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0;
            auc += deltaX * avg;
        }
        return auc;
    }

    /** @return the number of threshold intervals this instance was created with */
    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    /** @return the accumulated positive-class label total */
    public long getCountActualPositive() {
        return this.countActualPositive;
    }

    /** @return the accumulated negative-class label total */
    public long getCountActualNegative() {
        return this.countActualNegative;
    }

    /** @return the live (not copied) map of per-threshold counters */
    public Map<Double, CountsForThreshold> getCounts() {
        return this.counts;
    }

    /**
     * Mutable true positive / false positive counters for a single threshold.
     */
    private static class CountsForThreshold
    implements Serializable {
        private double threshold;
        private long countTruePositive;
        private long countFalsePositive;

        /** Creates counters for {@code threshold}, both starting at zero. */
        private CountsForThreshold(double threshold) {
            this(threshold, 0L, 0L);
        }

        /** Adds {@code count} to the true positive counter. */
        private void incrementTruePositive(long count) {
            this.countTruePositive += count;
        }

        /** Adds {@code count} to the false positive counter. */
        private void incrementFalsePositive(long count) {
            this.countFalsePositive += count;
        }

        @ConstructorProperties(value={"threshold", "countTruePositive", "countFalsePositive"})
        public CountsForThreshold(double threshold, long countTruePositive, long countFalsePositive) {
            this.threshold = threshold;
            this.countTruePositive = countTruePositive;
            this.countFalsePositive = countFalsePositive;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public long getCountTruePositive() {
            return this.countTruePositive;
        }

        public long getCountFalsePositive() {
            return this.countFalsePositive;
        }

        public void setThreshold(double threshold) {
            this.threshold = threshold;
        }

        public void setCountTruePositive(long countTruePositive) {
            this.countTruePositive = countTruePositive;
        }

        public void setCountFalsePositive(long countFalsePositive) {
            this.countFalsePositive = countFalsePositive;
        }

        /** Two instances are equal when threshold and both counters match. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof CountsForThreshold)) {
                return false;
            }
            CountsForThreshold other = (CountsForThreshold)o;
            return other.canEqual(this)
                    && Double.compare(this.getThreshold(), other.getThreshold()) == 0
                    && this.getCountTruePositive() == other.getCountTruePositive()
                    && this.getCountFalsePositive() == other.getCountFalsePositive();
        }

        /** Hook that lets a subclass opt out of equality with this type. */
        protected boolean canEqual(Object other) {
            return other instanceof CountsForThreshold;
        }

        @Override
        public int hashCode() {
            // Each 64-bit field is folded to 32 bits by XOR-ing its two halves, then mixed in
            // with the multiplier 59.
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * 59 + (int)(thresholdBits >>> 32 ^ thresholdBits);
            long truePositive = this.getCountTruePositive();
            result = result * 59 + (int)(truePositive >>> 32 ^ truePositive);
            long falsePositive = this.getCountFalsePositive();
            result = result * 59 + (int)(falsePositive >>> 32 ^ falsePositive);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.CountsForThreshold(threshold=" + this.getThreshold() + ", countTruePositive=" + this.getCountTruePositive() + ", countFalsePositive=" + this.getCountFalsePositive() + ")";
        }
    }

    /**
     * Immutable curve point: a threshold with the true and false positive rates observed at it.
     */
    public static class ROCValue {
        private final double threshold;
        private final double truePositiveRate;
        private final double falsePositiveRate;

        @ConstructorProperties(value={"threshold", "truePositiveRate", "falsePositiveRate"})
        public ROCValue(double threshold, double truePositiveRate, double falsePositiveRate) {
            this.threshold = threshold;
            this.truePositiveRate = truePositiveRate;
            this.falsePositiveRate = falsePositiveRate;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public double getTruePositiveRate() {
            return this.truePositiveRate;
        }

        public double getFalsePositiveRate() {
            return this.falsePositiveRate;
        }

        /** Two instances are equal when all three values compare equal via {@link Double#compare}. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ROCValue)) {
                return false;
            }
            ROCValue other = (ROCValue)o;
            return other.canEqual(this)
                    && Double.compare(this.getThreshold(), other.getThreshold()) == 0
                    && Double.compare(this.getTruePositiveRate(), other.getTruePositiveRate()) == 0
                    && Double.compare(this.getFalsePositiveRate(), other.getFalsePositiveRate()) == 0;
        }

        /** Hook that lets a subclass opt out of equality with this type. */
        protected boolean canEqual(Object other) {
            return other instanceof ROCValue;
        }

        @Override
        public int hashCode() {
            // Each double is hashed from its raw long bits, folded to 32 bits by XOR-ing the two
            // halves, then mixed in with the multiplier 59.
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * 59 + (int)(thresholdBits >>> 32 ^ thresholdBits);
            long tprBits = Double.doubleToLongBits(this.getTruePositiveRate());
            result = result * 59 + (int)(tprBits >>> 32 ^ tprBits);
            long fprBits = Double.doubleToLongBits(this.getFalsePositiveRate());
            result = result * 59 + (int)(fprBits >>> 32 ^ fprBits);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.ROCValue(threshold=" + this.getThreshold() + ", truePositiveRate=" + this.getTruePositiveRate() + ", falsePositiveRate=" + this.getFalsePositiveRate() + ")";
        }
    }
}

