/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.Arrays;
import java.util.Objects;
import org.deeplearning4j.eval.BaseEvaluation;
import org.deeplearning4j.eval.curves.Histogram;
import org.deeplearning4j.eval.curves.ReliabilityDiagram;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.lossfunctions.serde.RowVectorDeserializer;
import org.nd4j.linalg.lossfunctions.serde.RowVectorSerializer;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.nd4j.shade.serde.jackson.shaded.NDArrayDeSerializer;
import org.nd4j.shade.serde.jackson.shaded.NDArraySerializer;

/**
 * Evaluation of the calibration of a probabilistic classifier.
 * <p>
 * Across all calls to {@link #eval(INDArray, INDArray, INDArray)} (and {@link #merge(EvaluationCalibration)}),
 * this class accumulates:
 * <ul>
 * <li>reliability-diagram statistics per class, using {@code reliabilityDiagNumBins} equal-width bins over
 * [0, 1]: the number of predictions falling in each bin, the sum of the labels of those predictions, and the
 * sum of the predicted probabilities in each bin;</li>
 * <li>per-class label totals and per-class counts of how often each class had the maximum predicted value;</li>
 * <li>histograms with {@code histogramNumBins} equal-width bins over [0, 1] of the absolute residuals
 * |label - prediction| and of the predicted probabilities. Each is kept both over all outputs and per class,
 * where the per-class version is weighted by the (masked) label value.</li>
 * </ul>
 * Masks may be a column vector (one value per example) or an array with the same shape as the labels; mask
 * entries are assumed to be 0 (exclude) or 1 (include).
 */
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
    public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
    public static final int DEFAULT_HISTOGRAM_NUM_BINS = 50;

    /** Offset added to masked-out entries so they fall below the lower bound (0) of every histogram bin. */
    private static final double MASKED_ENTRY_OFFSET = 10.0;

    private final int reliabilityDiagNumBins;
    private final int histogramNumBins;
    private final boolean excludeEmptyBins;

    /** [reliabilityDiagNumBins, nClasses]: sum of (masked) labels of the predictions falling in each bin. */
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinPosCount;
    /** [reliabilityDiagNumBins, nClasses]: number of (unmasked) predictions falling in each bin. */
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinTotalCount;
    /** [reliabilityDiagNumBins, nClasses]: sum of the predicted probabilities falling in each bin. */
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinSumPredictions;
    /** [1, nClasses]: sum of the (masked) labels for each class. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray labelCountsEachClass;
    /** [1, nClasses]: number of times each class had the maximum predicted value (masked). */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray predictionCountsEachClass;
    /** [1, histogramNumBins]: histogram of |label - prediction| over all outputs. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray residualPlotOverall;
    /** [histogramNumBins, nClasses]: histogram of |label - prediction|, weighted by the (masked) label. */
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray residualPlotByLabelClass;
    /** [1, histogramNumBins]: histogram of the predicted probabilities over all outputs. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray probHistogramOverall;
    /** [histogramNumBins, nClasses]: histogram of predicted probabilities, weighted by the (masked) label. */
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray probHistogramByLabelClass;

    /** Creates an evaluation with the default bin counts and empty reliability-diagram bins excluded. */
    public EvaluationCalibration() {
        this(DEFAULT_RELIABILITY_DIAG_NUM_BINS, DEFAULT_HISTOGRAM_NUM_BINS, true);
    }

    /**
     * Creates an evaluation with the given bin counts; empty reliability-diagram bins are excluded.
     *
     * @param reliabilityDiagNumBins number of equal-width bins over [0, 1] for the reliability diagram
     * @param histogramNumBins       number of equal-width bins over [0, 1] for the residual/probability histograms
     */
    public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins) {
        this(reliabilityDiagNumBins, histogramNumBins, true);
    }

    /**
     * @param reliabilityDiagNumBins number of equal-width bins over [0, 1] for the reliability diagram (> 0)
     * @param histogramNumBins       number of equal-width bins over [0, 1] for the histograms (> 0)
     * @param excludeEmptyBins       if true, bins with no predictions are omitted from reliability diagrams
     * @throws IllegalArgumentException if either bin count is not positive
     */
    public EvaluationCalibration(@JsonProperty(value="reliabilityDiagNumBins") int reliabilityDiagNumBins,
                                 @JsonProperty(value="histogramNumBins") int histogramNumBins,
                                 @JsonProperty(value="excludeEmptyBins") boolean excludeEmptyBins) {
        if (reliabilityDiagNumBins <= 0 || histogramNumBins <= 0) {
            throw new IllegalArgumentException("Number of bins must be positive: reliabilityDiagNumBins="
                    + reliabilityDiagNumBins + ", histogramNumBins=" + histogramNumBins);
        }
        this.reliabilityDiagNumBins = reliabilityDiagNumBins;
        this.histogramNumBins = histogramNumBins;
        this.excludeEmptyBins = excludeEmptyBins;
    }

    /**
     * Accumulates calibration statistics for one minibatch.
     *
     * @param labels             labels with one column per class (rank 2), or rank 3 data which is handed to
     *                           {@code evalTimeSeries} (defined in the base class; presumably reshapes to rank 2
     *                           and calls back into this method — confirm)
     * @param networkPredictions predicted probabilities, same shape as {@code labels}
     * @param maskArray          optional 0/1 mask: a column vector (per example) or the labels' shape; may be null
     */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
        if (labels.rank() == 3) {
            evalTimeSeries(labels, networkPredictions, maskArray);
            return;
        }
        int nClasses = labels.size(1);
        if (rDiagBinPosCount == null) {
            initializeArrays(nClasses);
        }

        // Labels with masked-out entries zeroed; used everywhere a label "weight" is accumulated
        INDArray maskedLabels = labels;
        if (maskArray != null) {
            maskedLabels = maskArray.isColumnVector() ? labels.mulColumnVector(maskArray) : labels.mul(maskArray);
        }

        updateReliabilityDiagram(maskedLabels, networkPredictions, maskArray);

        labelCountsEachClass.addi(maskedLabels.sum(0));

        // 1 at the arg-max class of each row, 0 elsewhere (computed on a copy so predictions are untouched)
        INDArray isPredictedClass = Nd4j.getExecutioner().execAndReturn(
                (TransformOp) new IsMax(networkPredictions.dup(), new int[] {1}));
        if (maskArray != null) {
            LossUtil.applyMask(isPredictedClass, maskArray);
        }
        predictionCountsEachClass.addi(isPredictedClass.sum(0));

        updateHistograms(labels, maskedLabels, networkPredictions, maskArray);
    }

    /** Allocates all accumulators (zero-filled) for the given number of classes. */
    private void initializeArrays(int nClasses) {
        rDiagBinPosCount = Nd4j.create(reliabilityDiagNumBins, nClasses);
        rDiagBinTotalCount = Nd4j.create(reliabilityDiagNumBins, nClasses);
        rDiagBinSumPredictions = Nd4j.create(reliabilityDiagNumBins, nClasses);
        labelCountsEachClass = Nd4j.create(1, nClasses);
        predictionCountsEachClass = Nd4j.create(1, nClasses);
        residualPlotOverall = Nd4j.create(1, histogramNumBins);
        residualPlotByLabelClass = Nd4j.create(histogramNumBins, nClasses);
        probHistogramOverall = Nd4j.create(1, histogramNumBins);
        probHistogramByLabelClass = Nd4j.create(histogramNumBins, nClasses);
    }

    /** Adds this minibatch's per-bin prediction counts, label sums and probability sums to the reliability data. */
    private void updateReliabilityDiagram(INDArray maskedLabels, INDArray predictions, INDArray maskArray) {
        double binSize = 1.0 / reliabilityDiagNumBins;
        for (int j = 0; j < reliabilityDiagNumBins; j++) {
            INDArray inBin = binMembership(predictions, j, binSize, j == reliabilityDiagNumBins - 1);
            if (maskArray != null) {
                if (maskArray.isColumnVector()) {
                    inBin.muliColumnVector(maskArray);
                } else {
                    inBin.muli(maskArray);
                }
            }
            rDiagBinSumPredictions.getRow(j).addi(predictions.mul(inBin).sum(0));
            rDiagBinPosCount.getRow(j).addi(maskedLabels.mul(inBin).sum(0));
            rDiagBinTotalCount.getRow(j).addi(inBin.sum(0));
        }
    }

    /** Adds this minibatch to the residual |label - prediction| and predicted-probability histograms. */
    private void updateHistograms(INDArray labels, INDArray maskedLabels, INDArray predictions, INDArray maskArray) {
        // Histogram bins span [0, 1] in histogramNumBins steps (not the reliability-diagram bin count)
        double binSize = 1.0 / histogramNumBins;

        INDArray absResiduals = labels.sub(predictions);
        Transforms.abs(absResiduals, false);
        INDArray probs = predictions.dup();
        if (maskArray != null) {
            // mask 1 -> offset 0 (kept); mask 0 -> offset -10, i.e. below every bin's lower bound (excluded)
            INDArray exclusionOffset = maskArray.sub(1.0).muli(MASKED_ENTRY_OFFSET);
            if (maskArray.isColumnVector()) {
                absResiduals.addiColumnVector(exclusionOffset);
                probs.addiColumnVector(exclusionOffset);
            } else {
                absResiduals.addi(exclusionOffset);
                probs.addi(exclusionOffset);
            }
        }

        for (int j = 0; j < histogramNumBins; j++) {
            boolean lastBin = j == histogramNumBins - 1;
            INDArray residualInBin = binMembership(absResiduals, j, binSize, lastBin);
            INDArray probInBin = binMembership(probs, j, binSize, lastBin);

            int residualTotal = residualPlotOverall.getInt(0, j) + residualInBin.sumNumber().intValue();
            residualPlotOverall.putScalar(0, j, (double) residualTotal);
            residualPlotByLabelClass.getRow(j).addi(maskedLabels.mul(residualInBin).sum(0));

            int probTotal = probHistogramOverall.getInt(0, j) + probInBin.sumNumber().intValue();
            probHistogramOverall.putScalar(0, j, (double) probTotal);
            probHistogramByLabelClass.getRow(j).addi(maskedLabels.mul(probInBin).sum(0));
        }
    }

    /**
     * Returns a new 0/1 array marking the entries of {@code values} in bin {@code j}: {@code [j*binSize, (j+1)*binSize)},
     * or {@code [j*binSize, 1.0]} for the last bin so that a value of exactly 1.0 is counted.
     */
    private static INDArray binMembership(INDArray values, int j, double binSize, boolean lastBin) {
        INDArray geqLower = values.gte(j * binSize);
        INDArray ltUpper = lastBin ? values.lte(1.0) : values.lt((j + 1) * binSize);
        return geqLower.muli(ltUpper);
    }

    @Override
    public void eval(INDArray labels, INDArray networkPredictions) {
        eval(labels, networkPredictions, null);
    }

    /**
     * Adds all statistics accumulated by {@code other} into this instance. {@code other} is never modified.
     *
     * @throws UnsupportedOperationException if the two instances use different bin counts
     */
    @Override
    public void merge(EvaluationCalibration other) {
        if (reliabilityDiagNumBins != other.reliabilityDiagNumBins || histogramNumBins != other.histogramNumBins) {
            throw new UnsupportedOperationException("Cannot merge EvaluationCalibration instances with different numbers of bins");
        }
        if (other.rDiagBinPosCount == null) {
            // The other instance has not evaluated any data
            return;
        }
        rDiagBinPosCount = mergeCounts(rDiagBinPosCount, other.rDiagBinPosCount);
        rDiagBinTotalCount = mergeCounts(rDiagBinTotalCount, other.rDiagBinTotalCount);
        rDiagBinSumPredictions = mergeCounts(rDiagBinSumPredictions, other.rDiagBinSumPredictions);
        labelCountsEachClass = mergeCounts(labelCountsEachClass, other.labelCountsEachClass);
        predictionCountsEachClass = mergeCounts(predictionCountsEachClass, other.predictionCountsEachClass);
        residualPlotOverall = mergeCounts(residualPlotOverall, other.residualPlotOverall);
        residualPlotByLabelClass = mergeCounts(residualPlotByLabelClass, other.residualPlotByLabelClass);
        probHistogramOverall = mergeCounts(probHistogramOverall, other.probHistogramOverall);
        probHistogramByLabelClass = mergeCounts(probHistogramByLabelClass, other.probHistogramByLabelClass);
    }

    /**
     * Adds {@code source} into {@code target} in place and returns {@code target}. If {@code target} is not yet
     * allocated, a copy of {@code source} is returned so the two evaluations never share (and co-mutate) an array.
     */
    private static INDArray mergeCounts(INDArray target, INDArray source) {
        if (source == null) {
            return target;
        }
        if (target == null) {
            return source.dup();
        }
        target.addi(source);
        return target;
    }

    /** Discards all accumulated statistics; the next {@code eval} call reallocates them. */
    @Override
    public void reset() {
        rDiagBinPosCount = null;
        rDiagBinTotalCount = null;
        rDiagBinSumPredictions = null;
        labelCountsEachClass = null;
        predictionCountsEachClass = null;
        residualPlotOverall = null;
        residualPlotByLabelClass = null;
        probHistogramOverall = null;
        probHistogramByLabelClass = null;
    }

    @Override
    public String stats() {
        return "EvaluationCalibration(nBins=" + reliabilityDiagNumBins + ")";
    }

    /** @return the number of classes seen so far, or -1 if no data has been evaluated */
    public int numClasses() {
        if (rDiagBinTotalCount == null) {
            return -1;
        }
        return rDiagBinTotalCount.size(1);
    }

    /**
     * Builds the reliability diagram for one class: for each bin, the mean predicted probability versus the
     * fraction of positive labels. Empty bins yield 0/0 (NaN) values and are dropped when
     * {@code excludeEmptyBins} is set.
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public ReliabilityDiagram getReliabilityDiagram(int classIdx) {
        INDArray totalCountBins = requireData(rDiagBinTotalCount).getColumn(classIdx);
        double[] meanPredictionBins = rDiagBinSumPredictions.getColumn(classIdx).div(totalCountBins).data().asDouble();
        double[] fracPositives = rDiagBinPosCount.getColumn(classIdx).div(totalCountBins).data().asDouble();
        if (excludeEmptyBins) {
            // dup() first: data() of a column view would expose the whole backing buffer
            double[] totalCounts = totalCountBins.dup().data().asDouble();
            int nonEmpty = 0;
            for (double count : totalCounts) {
                if (count != 0.0) {
                    nonEmpty++;
                }
            }
            if (nonEmpty != totalCounts.length) {
                double[] keptMeans = new double[nonEmpty];
                double[] keptFractions = new double[nonEmpty];
                int k = 0;
                for (int i = 0; i < totalCounts.length; i++) {
                    if (totalCounts[i] == 0.0) {
                        continue;
                    }
                    keptMeans[k] = meanPredictionBins[i];
                    keptFractions[k] = fracPositives[i];
                    k++;
                }
                meanPredictionBins = keptMeans;
                fracPositives = keptFractions;
            }
        }
        String title = "Reliability Diagram: Class " + classIdx;
        return new ReliabilityDiagram(title, meanPredictionBins, fracPositives);
    }

    /** @return per-class label totals, or null if no data has been evaluated */
    public int[] getLabelCountsEachClass() {
        return labelCountsEachClass == null ? null : labelCountsEachClass.data().asInt();
    }

    /** @return per-class counts of arg-max predictions, or null if no data has been evaluated */
    public int[] getPredictionCountsEachClass() {
        return predictionCountsEachClass == null ? null : predictionCountsEachClass.data().asInt();
    }

    /** @throws IllegalStateException if no data has been evaluated */
    public Histogram getResidualPlotAllClasses() {
        String title = "Residual Plot - All Predictions and Classes";
        int[] counts = requireData(residualPlotOverall).data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /** @throws IllegalStateException if no data has been evaluated */
    public Histogram getResidualPlot(int labelClassIdx) {
        String title = "Residual Plot - Predictions for Label Class " + labelClassIdx;
        int[] counts = requireData(residualPlotByLabelClass).getColumn(labelClassIdx).dup().data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /** @throws IllegalStateException if no data has been evaluated */
    public Histogram getProbabilityHistogramAllClasses() {
        String title = "Network Probabilities Histogram - All Predictions and Classes";
        int[] counts = requireData(probHistogramOverall).data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /** @throws IllegalStateException if no data has been evaluated */
    public Histogram getProbabilityHistogram(int labelClassIdx) {
        String title = "Network Probabilities Histogram - P(class " + labelClassIdx + ") - Data Labelled Class " + labelClassIdx + " Only";
        int[] counts = requireData(probHistogramByLabelClass).getColumn(labelClassIdx).dup().data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /** Returns {@code array}, or throws a descriptive exception if statistics have not been accumulated yet. */
    private static INDArray requireData(INDArray array) {
        if (array == null) {
            throw new IllegalStateException(
                    "No calibration data available: call eval(...) or merge(...) before requesting results");
        }
        return array;
    }

    public int getReliabilityDiagNumBins() {
        return reliabilityDiagNumBins;
    }

    public int getHistogramNumBins() {
        return histogramNumBins;
    }

    public boolean isExcludeEmptyBins() {
        return excludeEmptyBins;
    }

    public INDArray getRDiagBinPosCount() {
        return rDiagBinPosCount;
    }

    public INDArray getRDiagBinTotalCount() {
        return rDiagBinTotalCount;
    }

    public INDArray getRDiagBinSumPredictions() {
        return rDiagBinSumPredictions;
    }

    public INDArray getResidualPlotOverall() {
        return residualPlotOverall;
    }

    public INDArray getResidualPlotByLabelClass() {
        return residualPlotByLabelClass;
    }

    public INDArray getProbHistogramOverall() {
        return probHistogramOverall;
    }

    public INDArray getProbHistogramByLabelClass() {
        return probHistogramByLabelClass;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof EvaluationCalibration)) {
            return false;
        }
        EvaluationCalibration other = (EvaluationCalibration) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return getReliabilityDiagNumBins() == other.getReliabilityDiagNumBins()
                && getHistogramNumBins() == other.getHistogramNumBins()
                && isExcludeEmptyBins() == other.isExcludeEmptyBins()
                && Objects.equals(getRDiagBinPosCount(), other.getRDiagBinPosCount())
                && Objects.equals(getRDiagBinTotalCount(), other.getRDiagBinTotalCount())
                && Objects.equals(getRDiagBinSumPredictions(), other.getRDiagBinSumPredictions())
                && Arrays.equals(getLabelCountsEachClass(), other.getLabelCountsEachClass())
                && Arrays.equals(getPredictionCountsEachClass(), other.getPredictionCountsEachClass())
                && Objects.equals(getResidualPlotOverall(), other.getResidualPlotOverall())
                && Objects.equals(getResidualPlotByLabelClass(), other.getResidualPlotByLabelClass())
                && Objects.equals(getProbHistogramOverall(), other.getProbHistogramOverall())
                && Objects.equals(getProbHistogramByLabelClass(), other.getProbHistogramByLabelClass());
    }

    protected boolean canEqual(Object other) {
        return other instanceof EvaluationCalibration;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + getReliabilityDiagNumBins();
        result = result * prime + getHistogramNumBins();
        result = result * prime + (isExcludeEmptyBins() ? 79 : 97);
        result = result * prime + nullSafeHash(getRDiagBinPosCount());
        result = result * prime + nullSafeHash(getRDiagBinTotalCount());
        result = result * prime + nullSafeHash(getRDiagBinSumPredictions());
        result = result * prime + Arrays.hashCode(getLabelCountsEachClass());
        result = result * prime + Arrays.hashCode(getPredictionCountsEachClass());
        result = result * prime + nullSafeHash(getResidualPlotOverall());
        result = result * prime + nullSafeHash(getResidualPlotByLabelClass());
        result = result * prime + nullSafeHash(getProbHistogramOverall());
        result = result * prime + nullSafeHash(getProbHistogramByLabelClass());
        return result;
    }

    /** Hash of {@code o}, using the constant 43 for null (matches the previously generated hashCode). */
    private static int nullSafeHash(Object o) {
        return o == null ? 43 : o.hashCode();
    }
}

