/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCArraySerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * ROC evaluation for outputs that are scored column by column: one {@link ROC}
 * instance is kept per label column, and each column of the labels/predictions
 * arrays is forwarded to its own {@code ROC}.
 *
 * <p>State is created lazily on the first call to {@code eval}; until then
 * {@link #numLabels()} returns {@code -1} and the per-output accessors throw
 * {@link UnsupportedOperationException}.
 */
public class ROCBinary extends BaseEvaluation<ROCBinary> {
    /** Number of decimal places used by {@link #stats()}. */
    public static final int DEFAULT_STATS_PRECISION = 4;

    /** One ROC accumulator per label column; {@code null} until the first eval call. */
    @JsonSerialize(using=ROCArraySerializer.class)
    private ROC[] underlying;
    /** Threshold step count handed to every underlying {@link ROC}. */
    private int thresholdSteps;
    /** Flag handed to every underlying {@link ROC}. */
    private boolean rocRemoveRedundantPts;
    /** Optional display names used by {@link #stats(int)}; may be {@code null}. */
    private List<String> labels;

    /** Creates an instance with {@code thresholdSteps = 0} and {@code rocRemoveRedundantPts = true}. */
    public ROCBinary() {
        this(0);
    }

    /**
     * @param thresholdSteps threshold step count passed to each underlying {@link ROC}
     *                       (values above zero trigger the "thresholded" note in {@link #stats(int)})
     */
    public ROCBinary(int thresholdSteps) {
        this(thresholdSteps, true);
    }

    /**
     * @param thresholdSteps        threshold step count passed to each underlying {@link ROC}
     * @param rocRemoveRedundantPts flag passed to each underlying {@link ROC}
     */
    public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts) {
        this.thresholdSteps = thresholdSteps;
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** Discards all accumulated per-output state; configuration and label names are kept. */
    @Override
    public void reset() {
        this.underlying = null;
    }

    /** Evaluates without a mask; equivalent to {@code eval(labels, networkPredictions, null)}. */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions) {
        // The cast selects the INDArray-mask overload.
        this.eval(labels, networkPredictions, (INDArray)null);
    }

    /**
     * Accumulates one batch of labels and predictions, column by column.
     *
     * <p>Rank-3 labels are delegated to {@code evalTimeSeries}. Otherwise column {@code i}
     * of both arrays is passed to the {@code i}-th underlying {@link ROC}.
     *
     * @param labels             label array; {@code size(1)} is the number of outputs
     * @param networkPredictions predictions, read with the same column indices as {@code labels}
     * @param maskArray          optional mask. A column vector (or scalar) is applied to every
     *                           output; any other shape is read per output column. Rows whose
     *                           mask value is non-zero are kept.
     * @throws IllegalStateException if earlier calls established a different number of outputs
     */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
        if (this.underlying != null && (long)this.underlying.length != labels.size(1)) {
            throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with size " + this.underlying.length + ", got labels array with size " + labels.size(1));
        }
        if (labels.rank() == 3) {
            this.evalTimeSeries(labels, networkPredictions, maskArray);
            return;
        }
        int n = (int)labels.size(1);
        if (this.underlying == null) {
            this.underlying = new ROC[n];
            for (int i = 0; i < n; ++i) {
                this.underlying[i] = new ROC(this.thresholdSteps, this.rocRemoveRedundantPts);
            }
        }
        // A per-example mask selects the same rows for every output, so compute it only once.
        int[] perExampleNonMaskedIdxs = null;
        for (int i = 0; i < n; ++i) {
            INDArray prob = networkPredictions.getColumn(i);
            INDArray label = labels.getColumn(i);
            if (maskArray != null) {
                int[] rowsToPull;
                if (maskArray.isColumnVectorOrScalar()) {
                    if (perExampleNonMaskedIdxs == null) {
                        perExampleNonMaskedIdxs = ROCBinary.nonMaskedRowIndices(maskArray);
                    }
                    rowsToPull = perExampleNonMaskedIdxs;
                } else {
                    rowsToPull = ROCBinary.nonMaskedRowIndices(maskArray.getColumn(i));
                }
                prob = Nd4j.pullRows(prob, 1, rowsToPull);
                label = Nd4j.pullRows(label, 1, rowsToPull);
            }
            this.underlying[i].eval(label, prob);
        }
    }

    /**
     * Returns the indices (along dimension 0) of all non-zero entries of a mask vector.
     *
     * <p>The result is sized from the actual number of non-zero entries rather than from the
     * mask's sum, so masks holding values other than exactly 0 and 1 neither overflow the
     * index buffer nor leave trailing zero indices behind.
     */
    private static int[] nonMaskedRowIndices(INDArray m) {
        long maskSize = m.size(0);
        int[] buffer = new int[(int)maskSize];
        int used = 0;
        int j = 0;
        while ((long)j < maskSize) {
            if (m.getDouble((long)j) != 0.0) {
                buffer[used++] = j;
            }
            ++j;
        }
        return Arrays.copyOf(buffer, used);
    }

    /**
     * Merges the statistics of {@code other} into this instance, output by output.
     *
     * <p>If this instance has no state yet it adopts {@code other}'s array directly
     * (the array is shared, not copied). If {@code other} has no state nothing happens.
     *
     * @throws UnsupportedOperationException if both have state but differ in number of outputs
     */
    @Override
    public void merge(ROCBinary other) {
        if (this.underlying == null) {
            this.underlying = other.underlying;
            return;
        }
        if (other.underlying == null) {
            return;
        }
        if (this.underlying.length != other.underlying.length) {
            throw new UnsupportedOperationException("Cannot merge ROCBinary: this expects " + this.underlying.length + " outputs, other expects " + other.underlying.length + " outputs");
        }
        for (int i = 0; i < this.underlying.length; ++i) {
            this.underlying[i].merge(other.underlying[i]);
        }
    }

    /**
     * Validates that statistics exist and that {@code outputNum} addresses one of them.
     *
     * @throws UnsupportedOperationException if eval has not been called yet
     * @throws IllegalArgumentException      if {@code outputNum} is outside {@code [0, numLabels())}
     */
    private void assertIndex(int outputNum) {
        if (this.underlying == null) {
            throw new UnsupportedOperationException("ROCBinary does not have any stats: eval must be called first");
        }
        if (outputNum < 0 || outputNum >= this.underlying.length) {
            // Report the real upper bound, not one derived from the rejected argument.
            throw new IllegalArgumentException("Invalid input: output number must be between 0 and " + (this.underlying.length - 1));
        }
    }

    /** @return the number of outputs seen so far, or {@code -1} if nothing has been evaluated */
    public int numLabels() {
        if (this.underlying == null) {
            return -1;
        }
        return this.underlying.length;
    }

    /** @return the actual-positive count reported by the ROC for output {@code outputNum} */
    public long getCountActualPositive(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualPositive();
    }

    /** @return the actual-negative count reported by the ROC for output {@code outputNum} */
    public long getCountActualNegative(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualNegative();
    }

    /** @return the ROC curve for output {@code outputNum} */
    public RocCurve getRocCurve(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getRocCurve();
    }

    /** @return the precision-recall curve for output {@code outputNum} */
    public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getPrecisionRecallCurve();
    }

    /**
     * @return the arithmetic mean of {@link #calculateAUC(int)} over all outputs.
     *         With no data evaluated, {@link #numLabels()} is {@code -1}, the loop does not
     *         run, and the result is {@code 0.0 / -1.0}, which is not a meaningful AUC.
     */
    public double calculateAverageAuc() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.calculateAUC(i);
        }
        return ret / (double)this.numLabels();
    }

    /**
     * @return the arithmetic mean of {@link #calculateAUCPR(int)} over all outputs.
     *         With no data evaluated, {@link #numLabels()} is {@code -1}, the loop does not
     *         run, and the result is {@code 0.0 / -1.0}, which is not a meaningful value.
     */
    public double calculateAverageAUCPR() {
        double ret = 0.0;
        for (int i = 0; i < this.numLabels(); ++i) {
            ret += this.calculateAUCPR(i);
        }
        return ret / (double)this.numLabels();
    }

    /** @return the AUC computed by the ROC for output {@code outputNum} */
    public double calculateAUC(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].calculateAUC();
    }

    /** @return the AUCPR computed by the ROC for output {@code outputNum} */
    public double calculateAUCPR(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].calculateAUCPR();
    }

    /**
     * Sets the display names used by {@link #stats(int)}. The list is copied;
     * {@code null} clears the names.
     */
    public void setLabelNames(List<String> labels) {
        if (labels == null) {
            this.labels = null;
            return;
        }
        this.labels = new ArrayList<String>(labels);
    }

    /** @return {@link #stats(int)} with {@link #DEFAULT_STATS_PRECISION} decimal places */
    @Override
    public String stats() {
        return this.stats(DEFAULT_STATS_PRECISION);
    }

    /**
     * Builds a text table with one row per output: label, AUC, positive count, negative count.
     *
     * <p>An output is shown under its numeric index when no label names are set, when the
     * name list is shorter than the number of outputs, or when its name is {@code null}.
     *
     * @param printPrecision number of decimal places for the AUC column
     */
    public String stats(int printPrecision) {
        StringBuilder sb = new StringBuilder();
        int maxLabelsLength = 15;
        if (this.labels != null) {
            for (String s : this.labels) {
                if (s != null) {
                    maxLabelsLength = Math.max(s.length(), maxLabelsLength);
                }
            }
        }
        String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s";
        String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg");
        String pattern = "%-" + (maxLabelsLength + 5) + "s%-12." + printPrecision + "f%-10d%-10d";
        sb.append(header);
        if (this.underlying != null) {
            for (int i = 0; i < this.underlying.length; ++i) {
                double auc = this.calculateAUC(i);
                String label = null;
                if (this.labels != null && i < this.labels.size()) {
                    label = this.labels.get(i);
                }
                if (label == null) {
                    label = String.valueOf(i);
                }
                sb.append("\n").append(String.format(pattern, label, auc, this.getCountActualPositive(i), this.getCountActualNegative(i)));
            }
            if (this.thresholdSteps > 0) {
                sb.append("\n");
                sb.append("[Note: Thresholded AUC/AUPRC calculation used with ").append(this.thresholdSteps).append(" steps); accuracy may reduced compared to exact mode]");
            }
        } else {
            sb.append("\n-- No Data --\n");
        }
        return sb.toString();
    }

    /** Deserializes a {@code ROCBinary} from its JSON representation. */
    public static ROCBinary fromJson(String json) {
        return ROCBinary.fromJson(json, ROCBinary.class);
    }

    /**
     * Two instances are equal when the superclass considers them equal and their
     * underlying ROC arrays, threshold steps, redundant-point flag and label names match.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ROCBinary)) {
            return false;
        }
        ROCBinary other = (ROCBinary)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (!Arrays.deepEquals(this.getUnderlying(), other.getUnderlying())) {
            return false;
        }
        if (this.getThresholdSteps() != other.getThresholdSteps()) {
            return false;
        }
        if (this.isRocRemoveRedundantPts() != other.isRocRemoveRedundantPts()) {
            return false;
        }
        List<String> thisLabels = this.getLabels();
        List<String> otherLabels = other.getLabels();
        return thisLabels == null ? otherLabels == null : thisLabels.equals(otherLabels);
    }

    /** Symmetry hook for {@link #equals(Object)}: only other {@code ROCBinary} instances qualify. */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof ROCBinary;
    }

    /** Hash over the same fields {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Arrays.deepHashCode(this.getUnderlying());
        result = result * prime + this.getThresholdSteps();
        result = result * prime + (this.isRocRemoveRedundantPts() ? 79 : 97);
        List<String> labelNames = this.getLabels();
        result = result * prime + (labelNames == null ? 43 : labelNames.hashCode());
        return result;
    }

    /** @return the live per-output ROC array (not a copy), or {@code null} before any eval */
    public ROC[] getUnderlying() {
        return this.underlying;
    }

    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    public boolean isRocRemoveRedundantPts() {
        return this.rocRemoveRedundantPts;
    }

    /** @return the live label-name list (not a copy), or {@code null} if none is set */
    public List<String> getLabels() {
        return this.labels;
    }

    /** Replaces the per-output ROC array; the array is stored as given, not copied. */
    public void setUnderlying(ROC[] underlying) {
        this.underlying = underlying;
    }

    public void setThresholdSteps(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
    }

    public void setRocRemoveRedundantPts(boolean rocRemoveRedundantPts) {
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** Stores the given list as-is; unlike {@link #setLabelNames(List)} no copy is made. */
    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    @Override
    public String toString() {
        return "ROCBinary(underlying=" + Arrays.deepToString(this.getUnderlying()) + ", thresholdSteps=" + this.getThresholdSteps() + ", rocRemoveRedundantPts=" + this.isRocRemoveRedundantPts() + ", labels=" + this.getLabels() + ")";
    }
}

