/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.curves.PrecisionRecallCurve;
import org.deeplearning4j.eval.curves.RocCurve;
import org.deeplearning4j.eval.serde.ROCSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

@JsonIgnoreProperties(value={"probAndLabel", "exactAllocBlockSize"})
@JsonSerialize(using=ROCSerializer.class)
@JsonTypeInfo(use=JsonTypeInfo.Id.CLASS, include=JsonTypeInfo.As.PROPERTY)
/**
 * Binary-classifier evaluation. It computes the ROC curve and the area under it (AUC), plus the
 * precision/recall curve and its area (AUPRC).
 *
 * <p>The mode is chosen by {@code thresholdSteps} at construction time:
 * <ul>
 *   <li><b>Exact</b> ({@code thresholdSteps <= 0}): every (probability, label) pair is stored in an
 *   N x 2 buffer. The buffer grows by at least {@code exactAllocBlockSize} rows at a time. Curves are
 *   built by sorting all stored pairs.</li>
 *   <li><b>Thresholded</b> ({@code thresholdSteps > 0}): only true/false positive counts are kept.
 *   They are kept at the {@code thresholdSteps + 1} thresholds {@code i * (1.0 / thresholdSteps)},
 *   for {@code i = 0..thresholdSteps}.</li>
 * </ul>
 *
 * <p>Input to {@link #eval(INDArray, INDArray)} takes one of two forms:
 * <ul>
 *   <li>One column: the positive-class probability and a 0/1 label.</li>
 *   <li>Two columns: column 1 is treated as the positive class.</li>
 * </ul>
 * Rank-3 (time series) input is delegated to {@code evalTimeSeries}.
 */
public class ROC extends BaseEvaluation<ROC> {
    /** Default minimum number of rows by which the exact-mode buffer grows. */
    private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048;

    /** Number of threshold steps in thresholded mode; left at 0 in exact mode. */
    private int thresholdSteps;
    /** Total (summed) positive labels seen so far. */
    private long countActualPositive;
    /** Total negative labels seen so far. */
    private long countActualNegative;
    /** Thresholded mode only: per-threshold TP/FP counts, keyed by threshold, in insertion order. */
    private final Map<Double, CountsForThreshold> counts = new LinkedHashMap<Double, CountsForThreshold>();
    /** Cached AUC; null when it must be recomputed. */
    private Double auc;
    /** Cached AUPRC; null when it must be recomputed. */
    private Double auprc;
    /** Cached ROC curve (exact mode only); null when it must be recomputed. */
    private RocCurve rocCurve;
    /** Cached precision/recall curve; null when it must be recomputed. */
    private PrecisionRecallCurve prCurve;
    /** True when all predictions are stored (thresholdSteps <= 0 at construction). */
    private boolean isExact;
    /** Exact mode only: column 0 = positive-class probability, column 1 = label; first exampleCount rows used. */
    private INDArray probAndLabel;
    /** Number of examples evaluated (rows used in probAndLabel in exact mode). */
    private int exampleCount = 0;
    /** Whether exact-mode curves drop interior points that are collinear horizontally or vertically. */
    private boolean rocRemoveRedundantPts;
    /** Minimum growth (in rows) of the exact-mode buffer. */
    private int exactAllocBlockSize;

    /** Creates an exact-mode ROC evaluation. */
    public ROC() {
        this(0);
    }

    /**
     * @param thresholdSteps number of threshold steps; values {@code <= 0} select exact mode
     */
    public ROC(int thresholdSteps) {
        this(thresholdSteps, true);
    }

    /**
     * @param thresholdSteps        number of threshold steps; values {@code <= 0} select exact mode
     * @param rocRemoveRedundantPts whether to drop redundant points from exact-mode curves
     */
    public ROC(int thresholdSteps, boolean rocRemoveRedundantPts) {
        this(thresholdSteps, rocRemoveRedundantPts, DEFAULT_EXACT_ALLOC_BLOCK_SIZE);
    }

    /**
     * @param thresholdSteps        number of threshold steps; values {@code <= 0} select exact mode
     * @param rocRemoveRedundantPts whether to drop redundant points from exact-mode curves
     * @param exactAllocBlockSize   minimum number of rows by which the exact-mode buffer grows
     */
    public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize) {
        if (thresholdSteps > 0) {
            this.thresholdSteps = thresholdSteps;
            this.isExact = false;
            initThresholdCounts();
        } else {
            this.isExact = true;
        }
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
        this.exactAllocBlockSize = exactAllocBlockSize;
    }

    /**
     * Returns the threshold for step {@code i}.
     *
     * <p>This is the single place thresholds are computed, so the keys used to populate and look up
     * {@link #counts} are bit-for-bit identical.
     */
    private double thresholdAt(int i) {
        double step = 1.0 / (double) this.thresholdSteps;
        return (double) i * step;
    }

    /** Adds a zeroed {@link CountsForThreshold} entry for every threshold step. */
    private void initThresholdCounts() {
        for (int i = 0; i <= this.thresholdSteps; ++i) {
            double currThreshold = thresholdAt(i);
            this.counts.put(currThreshold, new CountsForThreshold(currThreshold));
        }
    }

    /**
     * Builds an index covering all columns of rows {@code from} (inclusive) to {@code to} (exclusive).
     * A fresh array and fresh index objects are returned on every call.
     */
    private static INDArrayIndex[] rowRange(int from, int to) {
        return new INDArrayIndex[]{NDArrayIndex.interval(from, to), NDArrayIndex.all()};
    }

    /**
     * Returns a view of the used rows of the exact-mode buffer.
     *
     * @return the first {@code exampleCount} rows, or null if nothing has been stored
     */
    protected INDArray getProbAndLabelUsed() {
        if (this.probAndLabel == null || this.exampleCount == 0) {
            return null;
        }
        return this.probAndLabel.get(rowRange(0, this.exampleCount));
    }

    /** Returns the AUC and caches it; used by {@link #toString()}. */
    private double getAuc() {
        if (this.auc != null) {
            return this.auc;
        }
        this.auc = this.calculateAUC();
        return this.auc;
    }

    /** Returns the AUPRC and caches it; used by {@link #toString()}. */
    private double getAuprc() {
        if (this.auprc != null) {
            return this.auprc;
        }
        this.auprc = this.calculateAUCPR();
        return this.auprc;
    }

    /**
     * Discards all evaluated data.
     *
     * <p>The configuration (mode, threshold steps, block size) is kept. In thresholded mode the
     * per-threshold counters are recreated at zero.
     */
    @Override
    public void reset() {
        this.countActualPositive = 0L;
        this.countActualNegative = 0L;
        this.counts.clear();
        if (this.isExact) {
            this.probAndLabel = null;
        } else {
            initThresholdCounts();
        }
        this.exampleCount = 0;
        this.auc = null;
        this.auprc = null;
        // The cached curves describe the discarded data and must not be returned after a reset.
        this.rocCurve = null;
        this.prCurve = null;
    }

    /** Returns a human-readable summary of AUC and AUPRC. */
    @Override
    public String stats() {
        StringBuilder sb = new StringBuilder();
        sb.append("AUC (Area under ROC Curve):                ").append(this.calculateAUC()).append("\n");
        sb.append("AUPRC (Area under Precision/Recall Curve): ").append(this.calculateAUCPR());
        if (!this.isExact) {
            sb.append("\n");
            sb.append("[Note: Thresholded AUC/AUPRC calculation used with ").append(this.thresholdSteps).append(" steps; accuracy may be reduced compared to exact mode]");
        }
        return sb.toString();
    }

    /**
     * Accumulates a minibatch of labels and predictions.
     *
     * @param labels      rank 2 with 1 or 2 columns (0/1 labels), or rank 3 time series
     * @param predictions same shape as {@code labels}; positive-class probabilities
     * @throws IllegalArgumentException if the shapes are not rank 2 with matching size(1) of 1 or 2
     */
    @Override
    public void eval(INDArray labels, INDArray predictions) {
        if (labels.rank() == 3 && predictions.rank() == 3) {
            // Time series input is handled by the base class. Return here: the rank check below
            // rejects rank 3 input, so falling through would always throw.
            this.evalTimeSeries(labels, predictions);
            return;
        }
        if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
        }
        boolean singleOutput = labels.size(1) == 1;
        if (this.isExact) {
            evalExact(labels, predictions, singleOutput);
        } else {
            evalThresholded(labels, predictions, singleOutput);
        }
        this.exampleCount += labels.size(0);
        // New data invalidates every cached metric and curve.
        this.auc = null;
        this.auprc = null;
        this.rocCurve = null;
        this.prCurve = null;
    }

    /**
     * Exact mode: appends the positive-class probability and label of each row to the buffer,
     * growing the buffer first if needed.
     */
    private void evalExact(INDArray labels, INDArray predictions, boolean singleOutput) {
        int currMinibatchSize = labels.size(0);
        if (this.probAndLabel == null) {
            int initialSize = Math.max(currMinibatchSize, this.exactAllocBlockSize);
            this.probAndLabel = Nd4j.create(new int[]{initialSize, 2}, 'c');
        }
        if (this.exampleCount + currMinibatchSize >= this.probAndLabel.size(0)) {
            int newSize = this.probAndLabel.size(0) + Math.max(this.exactAllocBlockSize, currMinibatchSize);
            INDArray newProbAndLabel = Nd4j.create(new int[]{newSize, 2}, 'c');
            if (this.exampleCount > 0) {
                newProbAndLabel.get(rowRange(0, this.exampleCount)).assign(this.probAndLabel.get(rowRange(0, this.exampleCount)));
            }
            this.probAndLabel = newProbAndLabel;
        }
        INDArray probClass1;
        INDArray labelClass1;
        if (singleOutput) {
            probClass1 = predictions;
            labelClass1 = labels;
        } else {
            probClass1 = predictions.getColumn(1);
            labelClass1 = labels.getColumn(1);
        }
        int end = this.exampleCount + currMinibatchSize;
        this.probAndLabel.get(new INDArrayIndex[]{NDArrayIndex.interval(this.exampleCount, end), NDArrayIndex.point(0)}).assign(probClass1);
        this.probAndLabel.get(new INDArrayIndex[]{NDArrayIndex.interval(this.exampleCount, end), NDArrayIndex.point(1)}).assign(labelClass1);
        int countClass1CurrMinibatch = labelClass1.sumNumber().intValue();
        this.countActualPositive += (long) countClass1CurrMinibatch;
        this.countActualNegative += (long) (currMinibatchSize - countClass1CurrMinibatch);
    }

    /**
     * Thresholded mode: for every threshold, binarises the predictions and adds the resulting
     * true/false positive counts to that threshold's counters.
     */
    private void evalThresholded(INDArray labels, INDArray predictions, boolean singleOutput) {
        INDArray positiveActualClassColumn;
        INDArray negativeActualClassColumn;
        INDArray positivePredictedClassColumn;
        if (singleOutput) {
            positiveActualClassColumn = labels;
            negativeActualClassColumn = labels.rsub((Number) 1.0);
            positivePredictedClassColumn = predictions;
        } else {
            positiveActualClassColumn = labels.getColumn(1);
            negativeActualClassColumn = labels.getColumn(0);
            positivePredictedClassColumn = predictions.getColumn(1);
        }
        this.countActualPositive += (long) positiveActualClassColumn.sumNumber().intValue();
        this.countActualNegative += (long) negativeActualClassColumn.sumNumber().intValue();
        // Work buffers reused across thresholds to avoid reallocating per step.
        INDArray ppc = null;
        INDArray itp = null;
        INDArray ifp = null;
        for (int i = 0; i <= this.thresholdSteps; ++i) {
            double currThreshold = thresholdAt(i);
            Condition condGeq = Conditions.greaterThanOrEqual((Number) currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual((Number) currThreshold);
            if (ppc == null) {
                ppc = positivePredictedClassColumn.dup(positiveActualClassColumn.ordering());
            } else {
                ppc.assign(positivePredictedClassColumn);
            }
            // Binarise in place: values >= threshold are set to 1.0, then values <= threshold to 0.0.
            CompareAndSet op = new CompareAndSet(ppc, 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op) op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn((Op) op);
            INDArray isTruePositive;
            INDArray isFalsePositive;
            if (i == 0) {
                isTruePositive = predictedClass1.mul(positiveActualClassColumn);
                isFalsePositive = predictedClass1.mul(negativeActualClassColumn);
                itp = isTruePositive;
                ifp = isFalsePositive;
            } else {
                // Reuse the buffers allocated on the first step as output arrays.
                isTruePositive = Nd4j.getExecutioner().execAndReturn((TransformOp) new OldMulOp(predictedClass1, positiveActualClassColumn, itp));
                isFalsePositive = Nd4j.getExecutioner().execAndReturn((TransformOp) new OldMulOp(predictedClass1, negativeActualClassColumn, ifp));
            }
            int truePositiveCount = isTruePositive.sumNumber().intValue();
            int falsePositiveCount = isFalsePositive.sumNumber().intValue();
            CountsForThreshold thresholdCounts = this.counts.get(currThreshold);
            thresholdCounts.incrementTruePositive(truePositiveCount);
            thresholdCounts.incrementFalsePositive(falsePositiveCount);
        }
    }

    /**
     * Returns the precision/recall curve, computing and caching it on first use.
     *
     * <p>In exact mode the stored rows are sorted on column 0 (probability) with {@code ascending = false}.
     * Cumulative positive counts give precision and recall at each stored prediction. Two endpoints are
     * added: threshold 1.0 (precision 1, recall 0) and a final point with recall 1. The arrays are then
     * reversed. If enabled, redundant points are removed.
     *
     * @throws IllegalStateException in exact mode when no examples have been evaluated
     */
    public PrecisionRecallCurve getPrecisionRecallCurve() {
        if (this.prCurve != null) {
            return this.prCurve;
        }
        double[] thresholdOut;
        double[] precisionOut;
        double[] recallOut;
        int[] tpCountOut;
        int[] fpCountOut;
        int[] fnCountOut;
        if (this.isExact) {
            INDArray pl = this.getProbAndLabelUsed();
            if (pl == null) {
                throw new IllegalStateException("Cannot calculate precision/recall curve: no examples have been evaluated");
            }
            INDArray sorted = Nd4j.sortRows(pl, 0, false);
            INDArray isPositive = sorted.getColumn(1);
            INDArray cumSumPos = isPositive.cumsum(-1);
            int length = sorted.size(0);
            // Rows 1..length hold the per-prediction points; rows 0 and length+1 are the endpoints.
            INDArray t = Nd4j.create(new int[]{length + 2, 1});
            t.put(rowRange(1, length + 1), sorted.getColumn(0));
            INDArray linspace = Nd4j.linspace(1, length, length);
            INDArray precision = cumSumPos.div(linspace.reshape(cumSumPos.shape()));
            INDArray prec = Nd4j.create(new int[]{length + 2, 1});
            prec.put(rowRange(1, length + 1), precision);
            INDArray rec = Nd4j.create(new int[]{length + 2, 1});
            rec.put(rowRange(1, length + 1), cumSumPos.div((Number) this.countActualPositive));
            t.putScalar(0, 0, 1.0);
            prec.putScalar(0, 0, 1.0);
            rec.putScalar(0, 0, 0.0);
            // NOTE(review): t[length + 1] is never written; it relies on Nd4j.create's initial value.
            prec.putScalar(length + 1, 0, cumSumPos.getDouble(cumSumPos.length() - 1) / (double) length);
            rec.putScalar(length + 1, 0, 1.0);
            thresholdOut = t.data().asDouble();
            precisionOut = prec.data().asDouble();
            recallOut = rec.data().asDouble();
            tpCountOut = new int[thresholdOut.length];
            fpCountOut = new int[thresholdOut.length];
            fnCountOut = new int[thresholdOut.length];
            for (int i = 1; i < tpCountOut.length - 1; ++i) {
                tpCountOut[i] = cumSumPos.getInt(new int[]{i - 1});
                fpCountOut[i] = i - tpCountOut[i];
                fnCountOut[i] = (int) this.countActualPositive - tpCountOut[i];
            }
            int last = tpCountOut.length - 1;
            tpCountOut[last] = (int) this.countActualPositive;
            fpCountOut[last] = (int) ((long) this.exampleCount - this.countActualPositive);
            fnCountOut[last] = 0;
            tpCountOut[0] = 0;
            fpCountOut[0] = 0;
            fnCountOut[0] = (int) this.countActualPositive;
            ArrayUtils.reverse(thresholdOut);
            ArrayUtils.reverse(precisionOut);
            ArrayUtils.reverse(recallOut);
            ArrayUtils.reverse(tpCountOut);
            ArrayUtils.reverse(fpCountOut);
            ArrayUtils.reverse(fnCountOut);
            if (this.rocRemoveRedundantPts) {
                Pair<double[][], int[][]> pair = ROC.removeRedundant(thresholdOut, precisionOut, recallOut, tpCountOut, fpCountOut, fnCountOut);
                double[][] doubles = (double[][]) pair.getFirst();
                int[][] ints = (int[][]) pair.getSecond();
                thresholdOut = doubles[0];
                precisionOut = doubles[1];
                recallOut = doubles[2];
                tpCountOut = ints[0];
                fpCountOut = ints[1];
                fnCountOut = ints[2];
            }
        } else {
            int n = this.counts.size();
            thresholdOut = new double[n];
            precisionOut = new double[n];
            recallOut = new double[n];
            tpCountOut = new int[n];
            fpCountOut = new int[n];
            fnCountOut = new int[n];
            int i = 0;
            for (CountsForThreshold c : this.counts.values()) {
                long tpCount = c.getCountTruePositive();
                long fpCount = c.getCountFalsePositive();
                // Conventions for the degenerate cases: no predicted positives -> precision 1,
                // no actual positives -> recall 1.
                double precision = tpCount == 0L && fpCount == 0L ? 1.0 : (double) tpCount / (double) (tpCount + fpCount);
                double recall = this.countActualPositive == 0L ? 1.0 : (double) tpCount / (double) this.countActualPositive;
                thresholdOut[i] = c.getThreshold();
                precisionOut[i] = precision;
                recallOut[i] = recall;
                tpCountOut[i] = (int) tpCount;
                fpCountOut[i] = (int) fpCount;
                fnCountOut[i] = (int) (this.countActualPositive - tpCount);
                ++i;
            }
        }
        this.prCurve = new PrecisionRecallCurve(thresholdOut, precisionOut, recallOut, tpCountOut, fpCountOut, fnCountOut, this.exampleCount);
        return this.prCurve;
    }

    /**
     * Returns the ROC curve (threshold, false positive rate, true positive rate).
     *
     * <p>In exact mode the curve is built from the sorted stored predictions, with endpoints (0,0) and
     * (1,1) added, and cached. In thresholded mode a new curve is built from the per-threshold counts
     * on every call.
     *
     * @throws IllegalStateException in exact mode when no examples have been evaluated
     */
    public RocCurve getRocCurve() {
        if (this.rocCurve != null) {
            return this.rocCurve;
        }
        if (this.isExact) {
            INDArray pl = this.getProbAndLabelUsed();
            if (pl == null) {
                throw new IllegalStateException("Cannot calculate ROC curve: no examples have been evaluated");
            }
            INDArray sorted = Nd4j.sortRows(pl, 0, false);
            INDArray isPositive = sorted.getColumn(1);
            INDArray isNegative = isPositive.rsub((Number) 1.0);
            INDArray cumSumPos = isPositive.cumsum(-1);
            INDArray cumSumNeg = isNegative.cumsum(-1);
            int length = sorted.size(0);
            INDArray t = Nd4j.create(new int[]{length + 2, 1});
            t.put(rowRange(1, length + 1), sorted.getColumn(0));
            INDArray fpr = Nd4j.create(new int[]{length + 2, 1});
            fpr.put(rowRange(1, length + 1), cumSumNeg.div((Number) this.countActualNegative));
            INDArray tpr = Nd4j.create(new int[]{length + 2, 1});
            tpr.put(rowRange(1, length + 1), cumSumPos.div((Number) this.countActualPositive));
            t.putScalar(0, 0, 1.0);
            fpr.putScalar(0, 0, 0.0);
            tpr.putScalar(0, 0, 0.0);
            fpr.putScalar(length + 1, 0, 1.0);
            tpr.putScalar(length + 1, 0, 1.0);
            double[] xFprOut = fpr.data().asDouble();
            double[] yTprOut = tpr.data().asDouble();
            double[] tOut = t.data().asDouble();
            if (this.rocRemoveRedundantPts) {
                Pair<double[][], int[][]> p = ROC.removeRedundant(tOut, xFprOut, yTprOut, null, null, null);
                double[][] doubles = (double[][]) p.getFirst();
                tOut = doubles[0];
                xFprOut = doubles[1];
                yTprOut = doubles[2];
            }
            this.rocCurve = new RocCurve(tOut, xFprOut, yTprOut);
            return this.rocCurve;
        }
        double[][] out = new double[3][this.thresholdSteps + 1];
        int i = 0;
        for (CountsForThreshold c : this.counts.values()) {
            double tpr = (double) c.getCountTruePositive() / (double) this.countActualPositive;
            double fpr = (double) c.getCountFalsePositive() / (double) this.countActualNegative;
            out[0][i] = c.getThreshold();
            out[1][i] = fpr;
            out[2][i] = tpr;
            ++i;
        }
        return new RocCurve(out[0], out[1], out[2]);
    }

    /**
     * Removes interior points whose x values, or whose y values, equal those of both neighbours.
     * The first and last points are always kept.
     *
     * @param threshold threshold per point
     * @param x         x coordinate per point
     * @param y         y coordinate per point
     * @param tpCount   optional (may be null) TP count per point; if non-null, fpCount and fnCount must be too
     * @param fpCount   optional FP count per point
     * @param fnCount   optional FN count per point
     * @return the compacted {threshold, x, y} arrays, and the compacted {tp, fp, fn} arrays (null if tpCount was null)
     */
    private static Pair<double[][], int[][]> removeRedundant(double[] threshold, double[] x, double[] y, int[] tpCount, int[] fpCount, int[] fnCount) {
        boolean hasInts = tpCount != null;
        double[] tCompacted = new double[threshold.length];
        double[] xCompacted = new double[x.length];
        double[] yCompacted = new double[y.length];
        int[] tpCompacted = hasInts ? new int[tpCount.length] : null;
        int[] fpCompacted = hasInts ? new int[fpCount.length] : null;
        int[] fnCompacted = hasInts ? new int[fnCount.length] : null;

        int lastOutPos = -1;
        for (int i = 0; i < threshold.length; ++i) {
            boolean keep;
            if (i == 0 || i == threshold.length - 1) {
                keep = true;
            } else {
                boolean omitSameY = y[i - 1] == y[i] && y[i] == y[i + 1];
                boolean omitSameX = x[i - 1] == x[i] && x[i] == x[i + 1];
                keep = !omitSameX && !omitSameY;
            }
            if (!keep) {
                continue;
            }
            ++lastOutPos;
            tCompacted[lastOutPos] = threshold[i];
            yCompacted[lastOutPos] = y[i];
            xCompacted[lastOutPos] = x[i];
            if (hasInts) {
                tpCompacted[lastOutPos] = tpCount[i];
                fpCompacted[lastOutPos] = fpCount[i];
                fnCompacted[lastOutPos] = fnCount[i];
            }
        }
        int outLength = lastOutPos + 1;
        if (outLength < x.length) {
            tCompacted = Arrays.copyOfRange(tCompacted, 0, outLength);
            xCompacted = Arrays.copyOfRange(xCompacted, 0, outLength);
            yCompacted = Arrays.copyOfRange(yCompacted, 0, outLength);
            if (hasInts) {
                tpCompacted = Arrays.copyOfRange(tpCompacted, 0, outLength);
                fpCompacted = Arrays.copyOfRange(fpCompacted, 0, outLength);
                fnCompacted = Arrays.copyOfRange(fnCompacted, 0, outLength);
            }
        }
        double[][] doubles = new double[][]{tCompacted, xCompacted, yCompacted};
        int[][] ints = hasInts ? new int[][]{tpCompacted, fpCompacted, fnCompacted} : null;
        return new Pair<double[][], int[][]>(doubles, ints);
    }

    /**
     * Returns the area under the ROC curve, caching it.
     *
     * @return the AUC, or NaN if no examples have been evaluated (NaN is not cached)
     */
    public double calculateAUC() {
        if (this.auc != null) {
            return this.auc;
        }
        if (this.exampleCount == 0) {
            return Double.NaN;
        }
        this.auc = this.getRocCurve().calculateAUC();
        return this.auc;
    }

    /**
     * Returns the area under the precision/recall curve, caching it.
     *
     * @return the AUPRC, or NaN if no examples have been evaluated (NaN is not cached)
     */
    public double calculateAUCPR() {
        if (this.auprc != null) {
            return this.auprc;
        }
        if (this.exampleCount == 0) {
            return Double.NaN;
        }
        this.auprc = this.getPrecisionRecallCurve().calculateAUPRC();
        return this.auprc;
    }

    /**
     * Adds the data of {@code other} into this instance. {@code other} is not modified.
     *
     * @throws UnsupportedOperationException if the two instances use different threshold steps
     */
    @Override
    public void merge(ROC other) {
        if (this.thresholdSteps != other.thresholdSteps) {
            throw new UnsupportedOperationException("Cannot merge ROC instances with different numbers of threshold steps (" + this.thresholdSteps + " vs. " + other.thresholdSteps + ")");
        }
        this.countActualPositive += other.countActualPositive;
        this.countActualNegative += other.countActualNegative;
        this.auc = null;
        this.auprc = null;
        this.rocCurve = null;
        this.prCurve = null;
        if (this.isExact) {
            if (other.exampleCount == 0) {
                return;
            }
            if (this.exampleCount == 0) {
                this.exampleCount = other.exampleCount;
                // Copy rather than share: both instances write into their buffers on later eval/merge calls.
                this.probAndLabel = other.probAndLabel.dup();
                return;
            }
            if (this.exampleCount + other.exampleCount > this.probAndLabel.size(0)) {
                int newSize = this.probAndLabel.size(0) + Math.max(other.probAndLabel.size(0), this.exactAllocBlockSize);
                INDArray newProbAndLabel = Nd4j.create(newSize, 2);
                newProbAndLabel.put(rowRange(0, this.exampleCount), this.probAndLabel.get(rowRange(0, this.exampleCount)));
                this.probAndLabel = newProbAndLabel;
            }
            INDArray toPut = other.probAndLabel.get(rowRange(0, other.exampleCount));
            this.probAndLabel.put(rowRange(this.exampleCount, this.exampleCount + other.exampleCount), toPut);
        } else {
            for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
                CountsForThreshold mine = entry.getValue();
                CountsForThreshold theirs = other.counts.get(entry.getKey());
                mine.incrementTruePositive(theirs.getCountTruePositive());
                mine.incrementFalsePositive(theirs.getCountFalsePositive());
            }
        }
        this.exampleCount += other.exampleCount;
    }

    /**
     * Compares configuration, label totals, threshold counts and example count.
     * Stored exact-mode predictions and cached metrics/curves are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ROC)) {
            return false;
        }
        ROC other = (ROC) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (this.getThresholdSteps() != other.getThresholdSteps()) {
            return false;
        }
        if (this.getCountActualPositive() != other.getCountActualPositive()) {
            return false;
        }
        if (this.getCountActualNegative() != other.getCountActualNegative()) {
            return false;
        }
        Map<Double, CountsForThreshold> thisCounts = this.getCounts();
        Map<Double, CountsForThreshold> otherCounts = other.getCounts();
        if (thisCounts == null ? otherCounts != null : !thisCounts.equals(otherCounts)) {
            return false;
        }
        if (this.isExact() != other.isExact()) {
            return false;
        }
        if (this.getExampleCount() != other.getExampleCount()) {
            return false;
        }
        return this.isRocRemoveRedundantPts() == other.isRocRemoveRedundantPts();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof ROC;
    }

    /** Hash over the same fields used by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + this.getThresholdSteps();
        long positives = this.getCountActualPositive();
        result = result * prime + (int) (positives >>> 32 ^ positives);
        long negatives = this.getCountActualNegative();
        result = result * prime + (int) (negatives >>> 32 ^ negatives);
        Map<Double, CountsForThreshold> thresholdCounts = this.getCounts();
        result = result * prime + (thresholdCounts == null ? 43 : thresholdCounts.hashCode());
        result = result * prime + (this.isExact() ? 79 : 97);
        result = result * prime + this.getExampleCount();
        result = result * prime + (this.isRocRemoveRedundantPts() ? 79 : 97);
        return result;
    }

    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    public long getCountActualPositive() {
        return this.countActualPositive;
    }

    public long getCountActualNegative() {
        return this.countActualNegative;
    }

    /** Returns the live (mutable) per-threshold counts map. */
    public Map<Double, CountsForThreshold> getCounts() {
        return this.counts;
    }

    /** Returns the cached precision/recall curve, or null if not computed since the last change. */
    public PrecisionRecallCurve getPrCurve() {
        return this.prCurve;
    }

    public boolean isExact() {
        return this.isExact;
    }

    /** Returns the whole exact-mode buffer, including unused rows beyond {@code exampleCount}. */
    public INDArray getProbAndLabel() {
        return this.probAndLabel;
    }

    public int getExampleCount() {
        return this.exampleCount;
    }

    public boolean isRocRemoveRedundantPts() {
        return this.rocRemoveRedundantPts;
    }

    public int getExactAllocBlockSize() {
        return this.exactAllocBlockSize;
    }

    public void setThresholdSteps(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
    }

    public void setCountActualPositive(long countActualPositive) {
        this.countActualPositive = countActualPositive;
    }

    public void setCountActualNegative(long countActualNegative) {
        this.countActualNegative = countActualNegative;
    }

    public void setAuc(Double auc) {
        this.auc = auc;
    }

    public void setAuprc(Double auprc) {
        this.auprc = auprc;
    }

    public void setRocCurve(RocCurve rocCurve) {
        this.rocCurve = rocCurve;
    }

    public void setPrCurve(PrecisionRecallCurve prCurve) {
        this.prCurve = prCurve;
    }

    public void setExact(boolean isExact) {
        this.isExact = isExact;
    }

    public void setProbAndLabel(INDArray probAndLabel) {
        this.probAndLabel = probAndLabel;
    }

    public void setExampleCount(int exampleCount) {
        this.exampleCount = exampleCount;
    }

    public void setRocRemoveRedundantPts(boolean rocRemoveRedundantPts) {
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    public void setExactAllocBlockSize(int exactAllocBlockSize) {
        this.exactAllocBlockSize = exactAllocBlockSize;
    }

    /** Note: computes (and caches) AUC and AUPRC as a side effect. */
    @Override
    public String toString() {
        return "ROC(thresholdSteps=" + this.getThresholdSteps() + ", countActualPositive=" + this.getCountActualPositive() + ", countActualNegative=" + this.getCountActualNegative() + ", counts=" + this.getCounts() + ", auc=" + this.getAuc() + ", auprc=" + this.getAuprc() + ", isExact=" + this.isExact() + ", exampleCount=" + this.getExampleCount() + ", rocRemoveRedundantPts=" + this.isRocRemoveRedundantPts() + ")";
    }

    /** Mutable true/false positive counters for one threshold (thresholded mode). */
    public static class CountsForThreshold implements Serializable, Cloneable {
        private double threshold;
        private long countTruePositive;
        private long countFalsePositive;

        /** Creates zeroed counters for the given threshold. */
        public CountsForThreshold(double threshold) {
            this(threshold, 0L, 0L);
        }

        public CountsForThreshold(double threshold, long countTruePositive, long countFalsePositive) {
            this.threshold = threshold;
            this.countTruePositive = countTruePositive;
            this.countFalsePositive = countFalsePositive;
        }

        /** No-arg constructor (e.g. for deserialization); all fields start at zero. */
        public CountsForThreshold() {
        }

        public void incrementTruePositive(long count) {
            this.countTruePositive += count;
        }

        public void incrementFalsePositive(long count) {
            this.countFalsePositive += count;
        }

        /** Returns an independent copy with the same threshold and counts. */
        @Override
        public CountsForThreshold clone() {
            return new CountsForThreshold(this.threshold, this.countTruePositive, this.countFalsePositive);
        }

        public double getThreshold() {
            return this.threshold;
        }

        public long getCountTruePositive() {
            return this.countTruePositive;
        }

        public long getCountFalsePositive() {
            return this.countFalsePositive;
        }

        public void setThreshold(double threshold) {
            this.threshold = threshold;
        }

        public void setCountTruePositive(long countTruePositive) {
            this.countTruePositive = countTruePositive;
        }

        public void setCountFalsePositive(long countFalsePositive) {
            this.countFalsePositive = countFalsePositive;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof CountsForThreshold)) {
                return false;
            }
            CountsForThreshold other = (CountsForThreshold) o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(this.getThreshold(), other.getThreshold()) != 0) {
                return false;
            }
            if (this.getCountTruePositive() != other.getCountTruePositive()) {
                return false;
            }
            return this.getCountFalsePositive() == other.getCountFalsePositive();
        }

        protected boolean canEqual(Object other) {
            return other instanceof CountsForThreshold;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * prime + (int) (thresholdBits >>> 32 ^ thresholdBits);
            long tp = this.getCountTruePositive();
            result = result * prime + (int) (tp >>> 32 ^ tp);
            long fp = this.getCountFalsePositive();
            result = result * prime + (int) (fp >>> 32 ^ fp);
            return result;
        }

        @Override
        public String toString() {
            return "ROC.CountsForThreshold(threshold=" + this.getThreshold() + ", countTruePositive=" + this.getCountTruePositive() + ", countFalsePositive=" + this.getCountFalsePositive() + ")";
        }
    }
}

