/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.SortedSet;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Accumulates classification results and derives summary metrics
 * (precision, recall, F1, accuracy) from them.
 *
 * <p>Each call to {@link #eval(INDArray, INDArray)} compares, row by row, the
 * column index holding the largest value in the real outcomes with the column
 * index holding the largest value in the guesses, and updates a confusion
 * matrix plus per-class true/false positive/negative counters.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class Evaluation {
    private final Counter<Integer> truePositives = new Counter<Integer>();
    private final Counter<Integer> falsePositives = new Counter<Integer>();
    private final Counter<Integer> trueNegative = new Counter<Integer>();
    private final Counter<Integer> falseNegatives = new Counter<Integer>();
    private final ConfusionMatrix<Integer> confusion = new ConfusionMatrix<Integer>();

    /**
     * Records one batch of predictions.
     *
     * <p>For every row, the label is taken to be the index of the largest
     * column value (the first such index on ties).
     *
     * @param realOutcomes the expected outputs, one example per row
     * @param guesses the predicted outputs, one example per row
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public void eval(INDArray realOutcomes, INDArray guesses) {
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        for (int i = 0; i < realOutcomes.rows(); ++i) {
            int actual = argMax(realOutcomes.getRow(i));
            int predicted = argMax(guesses.getRow(i));
            this.addToConfusion(actual, predicted);
            if (actual == predicted) {
                this.incrementTruePositives(predicted);
            } else {
                this.incrementFalseNegatives(actual);
                this.incrementFalsePositives(predicted);
            }
            // Every class that is neither the actual nor the predicted label was
            // correctly "not chosen" for this example, whether or not the
            // prediction itself was right.
            this.recordTrueNegatives(actual, predicted);
        }
    }

    /** Returns the index of the largest value in {@code row}; the first index wins on ties. */
    private static int argMax(INDArray row) {
        int best = 0;
        double bestValue = row.getDouble(0);
        for (int col = 1; col < row.columns(); ++col) {
            double value = row.getDouble(col);
            if (value > bestValue) {
                bestValue = value;
                best = col;
            }
        }
        return best;
    }

    /**
     * Increments the true-negative count of every class currently reported by the
     * confusion matrix, except {@code actual} and {@code predicted}.
     */
    private void recordTrueNegatives(int actual, int predicted) {
        // NOTE(review): only classes the confusion matrix already knows about are
        // counted, so totals depend on the order in which classes are first seen.
        for (Integer clazz : this.confusion.getClasses()) {
            if (clazz != actual && clazz != predicted) {
                this.trueNegative.incrementCount(clazz, 1.0);
            }
        }
    }

    /**
     * Builds a human-readable report: one line per non-zero (actual, predicted)
     * confusion-matrix cell, followed by the overall F1 score.
     *
     * @return the formatted report
     */
    public String stats() {
        StringBuilder builder = new StringBuilder().append("\n");
        SortedSet<Integer> classes = this.confusion.getClasses();
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                int count = this.confusion.getCount(actual, predicted);
                if (count == 0) {
                    continue;
                }
                builder.append("\nActual Class ").append(actual)
                        .append(" was predicted with Predicted ").append(predicted)
                        .append(" with count ").append(count).append(" times\n");
            }
        }
        builder.append("\n==========================F1 Scores========================================");
        builder.append("\n ").append(this.f1());
        builder.append("\n===========================================================================");
        return builder.toString();
    }

    /**
     * Adds one (actual, predicted) observation to the confusion matrix.
     *
     * @param real the actual class index
     * @param guess the predicted class index
     */
    public void addToConfusion(int real, int guess) {
        this.confusion.add(real, guess);
    }

    /**
     * @param i the class index
     * @return the confusion matrix's actual total for class {@code i}
     */
    public int classCount(int i) {
        return this.confusion.getActualTotal(i);
    }

    /**
     * @param label the class index
     * @return the confusion matrix's predicted total for class {@code label}
     */
    public int numtimesPredicted(int label) {
        return this.confusion.getPredictedTotal(label);
    }

    /**
     * @param actual the actual class index
     * @param predicted the predicted class index
     * @return the confusion-matrix count for the (actual, predicted) pair
     */
    public int numTimesPredicted(int actual, int predicted) {
        return this.confusion.getCount(actual, predicted);
    }

    /**
     * Computes the unweighted mean of {@link #precision(int)} over all classes
     * reported by the confusion matrix.
     *
     * @return the averaged precision; {@code NaN} when no class has been recorded
     */
    public double precision() {
        double sum = 0.0;
        SortedSet<Integer> classes = this.confusion.getClasses();
        for (Integer clazz : classes) {
            sum += this.precision(clazz);
        }
        return sum / (double) classes.size();
    }

    /** @return the total true-negative count summed over all classes */
    public double trueNegatives() {
        return this.trueNegative.totalCount();
    }

    /** @return the total false-positive count summed over all classes */
    public double falsePositive() {
        return this.falsePositives.totalCount();
    }

    /**
     * @return the total number of actual negatives, i.e. true negatives plus
     *         false positives
     */
    public double negative() {
        // Actual negatives are TN + FP. eval() always records one false positive
        // per false negative, so the two totals only differ when the public
        // increment methods are called directly.
        return this.trueNegatives() + this.falsePositives.totalCount();
    }

    /**
     * @return the total number of actual positives, i.e. true positives plus
     *         false negatives
     */
    public double positive() {
        return this.truePositives.totalCount() + this.falseNegatives.totalCount();
    }

    /**
     * @return (TP + TN) / (positives + negatives) over all classes; {@code NaN}
     *         when nothing has been recorded
     */
    public double accuracy() {
        return (this.truePositives.totalCount() + this.trueNegatives()) / (this.positive() + this.negative());
    }

    /**
     * @return the harmonic mean of the averaged {@link #precision()} and
     *         {@link #recall()}, or {@code 0.0} when either is zero
     */
    public double f1() {
        return harmonicMean(this.precision(), this.recall());
    }

    /**
     * Computes the F1 score of a single class.
     *
     * @param i the class index
     * @return the harmonic mean of {@link #precision(int)} and {@link #recall(int)}
     *         for class {@code i}, or {@code 0.0} when either is zero
     */
    public double f1(int i) {
        // Both terms must be per-class values for class i.
        return harmonicMean(this.precision(i), this.recall(i));
    }

    /** Returns 2pr / (p + r), or {@code 0.0} when either input equals zero. */
    private static double harmonicMean(double precision, double recall) {
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }

    /**
     * Computes the unweighted mean of {@link #recall(int)} over all classes
     * reported by the confusion matrix.
     *
     * @return the averaged recall; {@code NaN} when no class has been recorded
     */
    public double recall() {
        double sum = 0.0;
        SortedSet<Integer> classes = this.confusion.getClasses();
        for (Integer clazz : classes) {
            sum += this.recall(clazz);
        }
        return sum / (double) classes.size();
    }

    /**
     * @param i the class index
     * @return TP / (TP + FN) for class {@code i}, or {@code 0.0} when the class
     *         has no true positives
     */
    public double recall(int i) {
        double tp = this.truePositives.getCount(i);
        if (tp == 0.0) {
            return 0.0;
        }
        return tp / (tp + this.falseNegatives.getCount(i));
    }

    /**
     * @param i the class index
     * @return TP / (TP + FP) for class {@code i}, or {@code 0.0} when the class
     *         has no true positives
     */
    public double precision(int i) {
        double tp = this.truePositives.getCount(i);
        if (tp == 0.0) {
            return 0.0;
        }
        return tp / (tp + this.falsePositives.getCount(i));
    }

    /**
     * Adds one true positive for class {@code i}.
     *
     * @param i the class index
     */
    public void incrementTruePositives(int i) {
        this.truePositives.incrementCount(i, 1.0);
    }

    /**
     * Adds one false negative for class {@code i}.
     *
     * @param i the class index
     */
    public void incrementFalseNegatives(int i) {
        this.falseNegatives.incrementCount(i, 1.0);
    }

    /**
     * Adds one false positive for class {@code i}.
     *
     * @param i the class index
     */
    public void incrementFalsePositives(int i) {
        this.falsePositives.incrementCount(i, 1.0);
    }
}

