/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Accumulates classification statistics over one or more {@code eval(...)} calls: per-class
 * true/false positive/negative counts plus a {@link ConfusionMatrix}. Accuracy, precision,
 * recall, F1 and false positive/negative rates are derived from these counts.
 *
 * <p>Classes are identified by the integers {@code 0 .. nClasses-1}. {@link #labelsList}
 * optionally maps those indices to display names used by {@link #stats()} and
 * {@link #confusionToString()}.
 */
public class Evaluation
implements Serializable {
    // Per-class counters, keyed by class index.
    protected Counter<Integer> truePositives = new Counter<Integer>();
    protected Counter<Integer> falsePositives = new Counter<Integer>();
    protected Counter<Integer> trueNegatives = new Counter<Integer>();
    protected Counter<Integer> falseNegatives = new Counter<Integer>();
    // Rows are actual classes, columns are predicted classes (see addToConfusion(real, guess)).
    protected ConfusionMatrix<Integer> confusion;
    // Total number of examples evaluated so far (denominator of accuracy()).
    protected int numRowCounter = 0;
    // Display names indexed by class; may be null when constructed with a null list.
    protected List<String> labelsList = new ArrayList<String>();
    // NOTE: modifiers intentionally left as-is; changing them alters the default serialVersionUID.
    protected static Logger log = LoggerFactory.getLogger(Evaluation.class);
    // Value returned by the single-argument per-class metrics when their denominator is zero.
    protected static final double DEFAULT_EDGE_VALUE = 0.0;

    /**
     * Creates an evaluation whose class count is inferred from the number of label columns on
     * the first {@link #eval(INDArray, INDArray)} call.
     */
    public Evaluation() {
    }

    /**
     * Creates an evaluation for {@code numClasses} classes, labelled "0", "1", ...
     *
     * @param numClasses number of classes; a value of 1 is treated as 2 (binary)
     */
    public Evaluation(int numClasses) {
        this(Evaluation.createLabels(numClasses));
    }

    /**
     * Creates an evaluation with the given class names.
     *
     * @param labels class names indexed by class. If null, the confusion matrix is created
     *               lazily on the first {@link #eval(INDArray, INDArray)} call.
     */
    public Evaluation(List<String> labels) {
        this.labelsList = labels;
        if (labels != null) {
            this.createConfusion(labels.size());
        }
    }

    /**
     * Creates an evaluation from a class-index to class-name map.
     *
     * @param labels map whose keys must be exactly {@code 0 .. labels.size()-1}
     * @throws IllegalArgumentException if any key in that range is missing
     */
    public Evaluation(Map<Integer, String> labels) {
        this(Evaluation.createLabelsFromMap(labels));
    }

    /** Builds the default labels "0".."n-1"; a single class is widened to two (binary). */
    private static List<String> createLabels(int numClasses) {
        if (numClasses == 1) {
            numClasses = 2;
        }
        List<String> list = new ArrayList<String>(numClasses);
        for (int i = 0; i < numClasses; ++i) {
            list.add(String.valueOf(i));
        }
        return list;
    }

    /** Converts an index-to-name map into a list, requiring contiguous keys starting at 0. */
    private static List<String> createLabelsFromMap(Map<Integer, String> labels) {
        int size = labels.size();
        List<String> labelsList = new ArrayList<String>(size);
        for (int i = 0; i < size; ++i) {
            String str = labels.get(i);
            if (str == null) {
                throw new IllegalArgumentException("Invalid labels map: missing key for class " + i + " (expect integers 0 to " + (size - 1) + ")");
            }
            labelsList.add(str);
        }
        return labelsList;
    }

    /** Initializes an empty confusion matrix over classes 0..nClasses-1. */
    private void createConfusion(int nClasses) {
        List<Integer> classes = new ArrayList<Integer>(nClasses);
        for (int i = 0; i < nClasses; ++i) {
            classes.add(i);
        }
        this.confusion = new ConfusionMatrix<Integer>(classes);
    }

    /**
     * Runs {@code network} on {@code input} (inference mode) and evaluates its first output
     * against {@code trueLabels}.
     */
    public void eval(INDArray trueLabels, INDArray input, ComputationGraph network) {
        this.eval(trueLabels, network.output(false, input)[0]);
    }

    /** Runs {@code network} on {@code input} in TEST mode and evaluates against {@code trueLabels}. */
    public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network) {
        this.eval(trueLabels, network.output(input, Layer.TrainingMode.TEST));
    }

    /**
     * Evaluates a batch of predictions; each row is one example.
     *
     * <p>With a single column the problem is treated as binary:
     * <ul>
     *   <li>a guess &gt; 0.5 counts as a prediction of class 1;</li>
     *   <li>{@code realOutcomes} is expected to hold 0/1 values (counts are obtained by
     *       element-wise products, so non-binary labels give truncated counts).</li>
     * </ul>
     *
     * <p>With several columns, the predicted and actual class of each row are taken as the
     * argmax over the columns. Labels are presumably one-hot.
     *
     * @param realOutcomes actual labels, one row per example
     * @param guesses      network output, same length as {@code realOutcomes}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public void eval(INDArray realOutcomes, INDArray guesses) {
        // Validate before mutating any state so a rejected batch leaves counts untouched.
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        if (this.confusion == null) {
            this.labelsList = Evaluation.createLabels(realOutcomes.columns());
            this.createConfusion(this.labelsList.size());
        }
        this.numRowCounter += realOutcomes.shape()[0];
        int nRows = realOutcomes.rows();
        if (realOutcomes.columns() == 1) {
            this.evalBinary(realOutcomes, guesses, nRows);
        } else {
            this.evalMultiClass(realOutcomes, guesses, nRows);
        }
    }

    /** Updates counts for a single-column (binary) batch; class 1 is the positive class. */
    private void evalBinary(INDArray realOutcomes, INDArray guesses, int nRows) {
        INDArray predictedPositive = guesses.gt(0.5);
        INDArray predictedNegative = predictedPositive.mul(-1.0).addi(1.0);
        INDArray actualNegative = realOutcomes.mul(-1.0).addi(1.0);
        // tp: predicted 1, actual 1
        int tp = predictedPositive.mul(realOutcomes).sumNumber().intValue();
        // fp: predicted 1, actual 0
        int fp = predictedPositive.mul(actualNegative).sumNumber().intValue();
        // fn: predicted 0, actual 1
        int fn = predictedNegative.mul(realOutcomes).sumNumber().intValue();
        int tn = nRows - tp - fp - fn;
        // confusion.add(actual, predicted, count)
        this.confusion.add(1, 1, tp);
        this.confusion.add(1, 0, fn);
        this.confusion.add(0, 1, fp);
        this.confusion.add(0, 0, tn);
        // Class 1's view of the batch.
        this.truePositives.incrementCount(1, tp);
        this.falsePositives.incrementCount(1, fp);
        this.falseNegatives.incrementCount(1, fn);
        this.trueNegatives.incrementCount(1, tn);
        // Class 0's view is the mirror image: its positives are class 1's negatives.
        this.truePositives.incrementCount(0, tn);
        this.falsePositives.incrementCount(0, fn);
        this.falseNegatives.incrementCount(0, fp);
        this.trueNegatives.incrementCount(0, tp);
    }

    /** Updates counts for a multi-column batch using the argmax of each row. */
    private void evalMultiClass(INDArray realOutcomes, INDArray guesses, int nRows) {
        INDArray guessIndex = Nd4j.argMax(guesses, 1);
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; ++i) {
            this.confusion.add((int) realOutcomeIndex.getDouble(i), (int) guessIndex.getDouble(i));
        }
        int nCols = realOutcomes.columns();
        for (int col = 0; col < nCols; ++col) {
            // 1 where the model predicted this column's class, else 0.
            INDArray predicted = guessIndex.eps(col);
            INDArray actual = realOutcomes.getColumn(col);
            INDArray notPredicted = predicted.mul(-1.0).addi(1.0);
            INDArray notActual = actual.mul(-1.0).addi(1.0);
            int colTp = predicted.mul(actual).sumNumber().intValue();
            int colFp = predicted.mul(notActual).sumNumber().intValue();
            int colFn = notPredicted.mul(actual).sumNumber().intValue();
            int colTn = nRows - colTp - colFp - colFn;
            this.truePositives.incrementCount(col, colTp);
            this.falsePositives.incrementCount(col, colFp);
            this.falseNegatives.incrementCount(col, colFn);
            this.trueNegatives.incrementCount(col, colTn);
        }
    }

    /**
     * Evaluates time-series output.
     *
     * <p>Rank-3 inputs are indexed {@code [example, output, timeStep]}. Each (example, timeStep)
     * pair is flattened into its own row before evaluation. Rank-2 inputs are evaluated
     * directly.
     *
     * @throws IllegalArgumentException if labels are neither rank 2 nor rank 3, or if the
     *                                  shapes differ
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted) {
        if (labels.rank() == 2 && predicted.rank() == 2) {
            this.eval(labels, predicted);
            return;
        }
        if (labels.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: labels are not rank 3 (rank=" + labels.rank() + ")");
        }
        if (!Arrays.equals(labels.shape(), predicted.shape())) {
            throw new IllegalArgumentException("Labels and predicted have different shapes: labels=" + Arrays.toString(labels.shape()) + ", predicted=" + Arrays.toString(predicted.shape()));
        }
        if (labels.ordering() == 'f') {
            labels = Shape.toOffsetZeroCopy(labels, 'c');
        }
        if (predicted.ordering() == 'f') {
            predicted = Shape.toOffsetZeroCopy(predicted, 'c');
        }
        int[] shape = labels.shape();
        // Move the output dimension last so that each row after reshaping is one (example, step).
        labels = labels.permute(new int[]{0, 2, 1}).reshape(shape[0] * shape[2], shape[1]);
        predicted = predicted.permute(new int[]{0, 2, 1}).reshape(shape[0] * shape[2], shape[1]);
        this.eval(labels, predicted);
    }

    /**
     * Evaluates masked time-series output. Only (example, timeStep) pairs whose
     * {@code outputMask[example, timeStep]} is non-zero are included.
     *
     * @param labels     rank-3 labels indexed {@code [example, output, timeStep]}
     * @param predicted  predictions with the same layout as {@code labels}
     * @param outputMask mask indexed {@code [example, timeStep]}; null means "no masking"
     */
    public void evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask) {
        if (outputMask == null) {
            this.evalTimeSeries(labels, predicted);
            return;
        }
        int totalOutputExamples = outputMask.sumNumber().intValue();
        if (totalOutputExamples == 0) {
            // Fully masked batch: nothing to evaluate.
            return;
        }
        int outSize = labels.size(1);
        INDArray labels2d = Nd4j.create(totalOutputExamples, outSize);
        INDArray predicted2d = Nd4j.create(totalOutputExamples, outSize);
        int rowCount = 0;
        for (int ex = 0; ex < outputMask.size(0); ++ex) {
            for (int t = 0; t < outputMask.size(1); ++t) {
                if (outputMask.getDouble(ex, t) == 0.0) {
                    continue;
                }
                labels2d.putRow(rowCount, labels.get(NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)));
                predicted2d.putRow(rowCount, predicted.get(NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)));
                ++rowCount;
            }
        }
        this.eval(labels2d, predicted2d);
    }

    /**
     * Records a single example.
     *
     * @param predictedIdx predicted class index
     * @param actualIdx    actual class index
     * @throws UnsupportedOperationException if the confusion matrix has not been initialized
     */
    public void eval(int predictedIdx, int actualIdx) {
        if (this.confusion == null) {
            throw new UnsupportedOperationException("Cannot evaluate single example without initializing confusion matrix first");
        }
        ++this.numRowCounter;
        this.addToConfusion(actualIdx, predictedIdx);
        if (predictedIdx == actualIdx) {
            this.incrementTruePositives(predictedIdx);
        } else {
            this.incrementFalseNegatives(actualIdx);
            this.incrementFalsePositives(predictedIdx);
        }
        // Every class that is neither the prediction nor the actual label is a true negative.
        for (int clazz : this.confusion.getClasses()) {
            if (clazz == predictedIdx || clazz == actualIdx) {
                continue;
            }
            this.trueNegatives.incrementCount(clazz, 1.0);
        }
    }

    /** Returns a human-readable report including warnings; see {@link #stats(boolean)}. */
    public String stats() {
        return this.stats(false);
    }

    /**
     * Returns a human-readable report with the non-zero confusion entries, optional warnings
     * for classes excluded from the macro averages, and the accuracy/precision/recall/F1 scores.
     *
     * @param suppressWarnings if true, omits the per-class exclusion warnings
     */
    public String stats(boolean suppressWarnings) {
        StringBuilder builder = new StringBuilder().append("\n");
        StringBuilder warnings = new StringBuilder();
        List<Integer> classes = this.confusion.getClasses();
        for (Integer clazz : classes) {
            String actual = this.resolveLabelForClass(clazz);
            for (Integer clazz2 : classes) {
                int count = this.confusion.getCount(clazz, clazz2);
                if (count == 0) {
                    continue;
                }
                String expected = this.resolveLabelForClass(clazz2);
                builder.append(String.format("Examples labeled as %s classified by model as %s: %d times%n", actual, expected, count));
            }
            if (suppressWarnings || this.truePositives.getCount(clazz) != 0.0) {
                continue;
            }
            // Mirrors the edge cases that make precision()/recall() skip this class.
            if (this.falsePositives.getCount(clazz) == 0.0) {
                warnings.append(String.format("Warning: class %s was never predicted by the model. This class was excluded from the average precision%n", actual));
            }
            if (this.falseNegatives.getCount(clazz) == 0.0) {
                warnings.append(String.format("Warning: class %s has never appeared as a true label. This class was excluded from the average recall%n", actual));
            }
        }
        builder.append("\n");
        builder.append(warnings);
        DecimalFormat df = new DecimalFormat("#.####");
        double acc = this.accuracy();
        double prec = this.precision();
        double rec = this.recall();
        double f1 = this.f1();
        builder.append("\n==========================Scores========================================");
        builder.append("\n Accuracy:  ").append(Evaluation.format(df, acc));
        builder.append("\n Precision: ").append(Evaluation.format(df, prec));
        builder.append("\n Recall:    ").append(Evaluation.format(df, rec));
        builder.append("\n F1 Score:  ").append(Evaluation.format(df, f1));
        builder.append("\n========================================================================");
        return builder.toString();
    }

    /** Formats a finite number with {@code f}; NaN/Infinity are rendered via String.valueOf. */
    private static String format(DecimalFormat f, double num) {
        if (Double.isNaN(num) || Double.isInfinite(num)) {
            return String.valueOf(num);
        }
        return f.format(num);
    }

    /** Returns the display name for a class, falling back to its index when no label exists. */
    private String resolveLabelForClass(Integer clazz) {
        if (this.labelsList != null && this.labelsList.size() > clazz) {
            return this.labelsList.get(clazz);
        }
        return clazz.toString();
    }

    /** Precision for one class; returns {@link #DEFAULT_EDGE_VALUE} if tp + fp == 0. */
    public double precision(Integer classLabel) {
        return this.precision(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** Precision tp / (tp + fp) for one class, or {@code edgeCase} if tp + fp == 0. */
    public double precision(Integer classLabel, double edgeCase) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fpCount = this.falsePositives.getCount(classLabel);
        if (tpCount == 0.0 && fpCount == 0.0) {
            return edgeCase;
        }
        return tpCount / (tpCount + fpCount);
    }

    /**
     * Macro-averaged precision over the classes that were predicted at least once.
     * Returns NaN if no class qualifies.
     */
    public double precision() {
        double precisionAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            double precision = this.precision(classLabel, -1.0);
            if (precision == -1.0) {
                continue;
            }
            precisionAcc += precision;
            ++classCount;
        }
        return precisionAcc / (double) classCount;
    }

    /** Recall for one class; returns {@link #DEFAULT_EDGE_VALUE} if tp + fn == 0. */
    public double recall(Integer classLabel) {
        return this.recall(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** Recall tp / (tp + fn) for one class, or {@code edgeCase} if tp + fn == 0. */
    public double recall(Integer classLabel, double edgeCase) {
        double tpCount = this.truePositives.getCount(classLabel);
        double fnCount = this.falseNegatives.getCount(classLabel);
        if (tpCount == 0.0 && fnCount == 0.0) {
            return edgeCase;
        }
        return tpCount / (tpCount + fnCount);
    }

    /**
     * Macro-averaged recall over the classes that appeared as a true label at least once.
     * Returns NaN if no class qualifies.
     */
    public double recall() {
        double recallAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            double recall = this.recall(classLabel, -1.0);
            if (recall == -1.0) {
                continue;
            }
            recallAcc += recall;
            ++classCount;
        }
        return recallAcc / (double) classCount;
    }

    /** False positive rate for one class; returns {@link #DEFAULT_EDGE_VALUE} if fp + tn == 0. */
    public double falsePositiveRate(Integer classLabel) {
        return this.falsePositiveRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** False positive rate fp / (fp + tn) for one class, or {@code edgeCase} if fp + tn == 0. */
    public double falsePositiveRate(Integer classLabel, double edgeCase) {
        double fpCount = this.falsePositives.getCount(classLabel);
        double tnCount = this.trueNegatives.getCount(classLabel);
        if (fpCount == 0.0 && tnCount == 0.0) {
            return edgeCase;
        }
        return fpCount / (fpCount + tnCount);
    }

    /** Macro-averaged false positive rate over classes with fp + tn &gt; 0 (NaN if none). */
    public double falsePositiveRate() {
        double fprAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            double fpr = this.falsePositiveRate(classLabel, -1.0);
            if (fpr == -1.0) {
                continue;
            }
            fprAcc += fpr;
            ++classCount;
        }
        return fprAcc / (double) classCount;
    }

    /** False negative rate for one class; returns {@link #DEFAULT_EDGE_VALUE} if fn + tp == 0. */
    public double falseNegativeRate(Integer classLabel) {
        return this.falseNegativeRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** False negative rate fn / (fn + tp) for one class, or {@code edgeCase} if fn + tp == 0. */
    public double falseNegativeRate(Integer classLabel, double edgeCase) {
        double fnCount = this.falseNegatives.getCount(classLabel);
        double tpCount = this.truePositives.getCount(classLabel);
        if (fnCount == 0.0 && tpCount == 0.0) {
            return edgeCase;
        }
        return fnCount / (fnCount + tpCount);
    }

    /** Macro-averaged false negative rate over classes with fn + tp &gt; 0 (NaN if none). */
    public double falseNegativeRate() {
        double fnrAcc = 0.0;
        int classCount = 0;
        for (Integer classLabel : this.confusion.getClasses()) {
            double fnr = this.falseNegativeRate(classLabel, -1.0);
            if (fnr == -1.0) {
                continue;
            }
            fnrAcc += fnr;
            ++classCount;
        }
        return fnrAcc / (double) classCount;
    }

    /** Mean of the macro false positive rate and the macro false negative rate. */
    public double falseAlarmRate() {
        return (this.falsePositiveRate() + this.falseNegativeRate()) / 2.0;
    }

    /** F1 score for one class; 0 if its precision or recall is 0. */
    public double f1(Integer classLabel) {
        double precision = this.precision(classLabel);
        double recall = this.recall(classLabel);
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /** Harmonic mean of the macro precision and macro recall; 0 if either is 0. */
    public double f1() {
        double precision = this.precision();
        double recall = this.recall();
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /** Fraction of evaluated examples on the confusion-matrix diagonal (NaN if none evaluated). */
    public double accuracy() {
        int nClasses = this.confusion.getClasses().size();
        int countCorrect = 0;
        for (int i = 0; i < nClasses; ++i) {
            countCorrect += this.confusion.getCount(i, i);
        }
        return (double) countCorrect / (double) this.getNumRowCounter();
    }

    /** True-positive counts per class index. */
    public Map<Integer, Integer> truePositives() {
        return this.convertToMap(this.truePositives, this.confusion.getClasses().size());
    }

    /** True-negative counts per class index. */
    public Map<Integer, Integer> trueNegatives() {
        return this.convertToMap(this.trueNegatives, this.confusion.getClasses().size());
    }

    /** False-positive counts per class index. */
    public Map<Integer, Integer> falsePositives() {
        return this.convertToMap(this.falsePositives, this.confusion.getClasses().size());
    }

    /** False-negative counts per class index. */
    public Map<Integer, Integer> falseNegatives() {
        return this.convertToMap(this.falseNegatives, this.confusion.getClasses().size());
    }

    /** Per-class tn + fp, i.e. the number of examples whose actual label is not that class. */
    public Map<Integer, Integer> negative() {
        return this.addMapsByKey(this.trueNegatives(), this.falsePositives());
    }

    /** Per-class tp + fn, i.e. the number of examples whose actual label is that class. */
    public Map<Integer, Integer> positive() {
        return this.addMapsByKey(this.truePositives(), this.falseNegatives());
    }

    /** Copies counts for keys 0..maxCount-1 into a map, truncating them to int. */
    private Map<Integer, Integer> convertToMap(Counter<Integer> counter, int maxCount) {
        Map<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (int i = 0; i < maxCount; ++i) {
            map.put(i, (int) counter.getCount(i));
        }
        return map;
    }

    /** Sums two maps key-wise; a key missing from one map counts as 0. */
    private Map<Integer, Integer> addMapsByKey(Map<Integer, Integer> first, Map<Integer, Integer> second) {
        Map<Integer, Integer> out = new HashMap<Integer, Integer>();
        HashSet<Integer> keys = new HashSet<Integer>(first.keySet());
        keys.addAll(second.keySet());
        for (Integer key : keys) {
            Integer f = first.get(key);
            Integer s = second.get(key);
            out.put(key, (f == null ? 0 : f) + (s == null ? 0 : s));
        }
        return out;
    }

    /** Adds one true positive for {@code classLabel}. */
    public void incrementTruePositives(Integer classLabel) {
        this.truePositives.incrementCount(classLabel, 1.0);
    }

    /** Adds one true negative for {@code classLabel}. */
    public void incrementTrueNegatives(Integer classLabel) {
        this.trueNegatives.incrementCount(classLabel, 1.0);
    }

    /** Adds one false negative for {@code classLabel}. */
    public void incrementFalseNegatives(Integer classLabel) {
        this.falseNegatives.incrementCount(classLabel, 1.0);
    }

    /** Adds one false positive for {@code classLabel}. */
    public void incrementFalsePositives(Integer classLabel) {
        this.falsePositives.incrementCount(classLabel, 1.0);
    }

    /** Records one (actual, predicted) pair in the confusion matrix. */
    public void addToConfusion(Integer real, Integer guess) {
        this.confusion.add(real, guess);
    }

    /** Number of evaluated examples whose actual label is {@code clazz}. */
    public int classCount(Integer clazz) {
        return this.confusion.getActualTotal(clazz);
    }

    /** Total number of examples evaluated so far. */
    public int getNumRowCounter() {
        return this.numRowCounter;
    }

    /** Display name for {@code clazz}, or its index as a string if no label is known. */
    public String getClassLabel(Integer clazz) {
        return this.resolveLabelForClass(clazz);
    }

    /** Returns the underlying confusion matrix (not a copy); may be null before the first eval. */
    public ConfusionMatrix<Integer> getConfusionMatrix() {
        return this.confusion;
    }

    /**
     * Adds all counts from {@code other} into this evaluation. Labels are adopted from
     * {@code other} only if this evaluation has none.
     *
     * @param other evaluation to merge; ignored if null
     */
    public void merge(Evaluation other) {
        if (other == null) {
            return;
        }
        this.truePositives.incrementAll(other.truePositives);
        this.falsePositives.incrementAll(other.falsePositives);
        this.trueNegatives.incrementAll(other.trueNegatives);
        this.falseNegatives.incrementAll(other.falseNegatives);
        if (this.confusion == null) {
            if (other.confusion != null) {
                this.confusion = new ConfusionMatrix<Integer>(other.confusion);
            }
        } else if (other.confusion != null) {
            this.confusion.add(other.confusion);
        }
        this.numRowCounter += other.numRowCounter;
        boolean haveLabels = this.labelsList != null && !this.labelsList.isEmpty();
        if (!haveLabels && other.labelsList != null && !other.labelsList.isEmpty()) {
            // Copy rather than addAll: labelsList may be a caller-supplied (possibly immutable) list.
            this.labelsList = new ArrayList<String>(other.labelsList);
        }
    }

    /**
     * Renders the confusion matrix as a fixed-width table. Rows are actual classes (index and
     * label) and columns are predicted class indices.
     */
    public String confusionToString() {
        int nClasses = this.confusion.getClasses().size();
        int maxLabelSize = 0;
        for (int i = 0; i < nClasses; ++i) {
            maxLabelSize = Math.max(maxLabelSize, this.resolveLabelForClass(i).length());
        }
        int labelSize = Math.max(maxLabelSize + 5, 10);
        StringBuilder rowFormatBuilder = new StringBuilder();
        rowFormatBuilder.append("%-3d").append("%-").append(labelSize).append("s | ");
        StringBuilder headerFormat = new StringBuilder();
        headerFormat.append("   %-").append(labelSize).append("s   ");
        for (int i = 0; i < nClasses; ++i) {
            rowFormatBuilder.append("%7d");
            headerFormat.append("%7d");
        }
        String rowFormat = rowFormatBuilder.toString();
        StringBuilder out = new StringBuilder();
        Object[] headerArgs = new Object[nClasses + 1];
        headerArgs[0] = "Predicted:";
        for (int i = 0; i < nClasses; ++i) {
            headerArgs[i + 1] = i;
        }
        out.append(String.format(headerFormat.toString(), headerArgs)).append("\n");
        out.append("   Actual:\n");
        for (int i = 0; i < nClasses; ++i) {
            Object[] args = new Object[nClasses + 2];
            args[0] = i;
            // resolveLabelForClass tolerates a null or short labelsList.
            args[1] = this.resolveLabelForClass(i);
            for (int j = 0; j < nClasses; ++j) {
                args[j + 2] = this.confusion.getCount(i, j);
            }
            out.append(String.format(rowFormat, args));
            out.append("\n");
        }
        return out.toString();
    }
}

