/*
 * Decompiled with CFR 0.152.
 */
package stream.learner.evaluation;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import stream.learner.evaluation.TableOfConfusion;

/**
 * A multi-class confusion matrix over labels of type {@code T}.
 *
 * <p>Rows are indexed by the true label and columns by the predicted label:
 * {@code confusionMatrix[t][p]} counts how often an item whose truth is
 * {@code labels.get(t)} was predicted as {@code labels.get(p)}. Labels that are
 * unknown when passed to {@link #add} are appended and the matrix grows accordingly.
 *
 * <p>This class performs no synchronization.
 */
public final class ConfusionMatrix<T extends Serializable> {
    /** Distinct labels; a label's position is its row and column index in {@link #confusionMatrix}. */
    private final List<T> labels;
    /** Square count matrix, rows = truth, columns = prediction; always labels.size() x labels.size(). */
    private long[][] confusionMatrix;

    /** Creates an empty matrix; labels are registered on demand by {@link #add}. */
    public ConfusionMatrix() {
        this(new ArrayList<T>());
    }

    /**
     * Creates a matrix pre-populated with the given labels and all counts zero.
     *
     * <p>The list is copied: later changes to it do not affect this matrix, and fixed-size
     * or immutable lists are accepted. Duplicate labels are kept only once.
     *
     * @param labels initial labels, in row/column order
     */
    public ConfusionMatrix(List<T> labels) {
        this.labels = new ArrayList<T>(labels.size());
        for (T label : labels) {
            if (!this.labels.contains(label)) {
                this.labels.add(label);
            }
        }
        this.confusionMatrix = new long[this.labels.size()][this.labels.size()];
    }

    /**
     * Registers a single label; does nothing if it is already known.
     *
     * @param additionalLabel the label to add
     */
    public void addLabel(T additionalLabel) {
        this.addLabels(Collections.singletonList(additionalLabel));
    }

    /**
     * Registers every label in {@code additionalLabels} that is not already known,
     * growing the matrix and preserving all existing counts.
     *
     * @param additionalLabels labels to add; duplicates and already-known labels are ignored
     */
    public void addLabels(List<T> additionalLabels) {
        int oldSize = this.labels.size();
        for (T label : additionalLabels) {
            // Checking against the growing list also drops duplicates inside additionalLabels.
            if (!this.labels.contains(label)) {
                this.labels.add(label);
            }
        }
        int newSize = this.labels.size();
        if (newSize == oldSize) {
            return;
        }
        long[][] grown = new long[newSize][newSize];
        for (int row = 0; row < oldSize; ++row) {
            System.arraycopy(this.confusionMatrix[row], 0, grown[row], 0, oldSize);
        }
        this.confusionMatrix = grown;
    }

    /**
     * Returns the known labels in row/column order.
     *
     * @return an unmodifiable view; mutating it directly would desynchronize it from the counts
     */
    public List<T> getLabels() {
        return Collections.unmodifiableList(this.labels);
    }

    /**
     * Records one observation, registering either label first if it is unknown.
     *
     * @param truth the true label (row)
     * @param prediction the predicted label (column)
     */
    public void add(T truth, T prediction) {
        int truthIndex = this.indexOfOrAdd(truth);
        int predictionIndex = this.indexOfOrAdd(prediction);
        this.confusionMatrix[truthIndex][predictionIndex]++;
    }

    /**
     * Builds the one-vs-rest table of confusion for every label.
     *
     * @return a map from each label to its table, iterating in label order
     */
    public Map<T, TableOfConfusion> getTablesOfConfusion() {
        Map<T, TableOfConfusion> tablesOfConfusion = new LinkedHashMap<T, TableOfConfusion>();
        for (T label : this.labels) {
            tablesOfConfusion.put(label, this.getTableOfConfusion(label));
        }
        return tablesOfConfusion;
    }

    /**
     * Builds the one-vs-rest table of confusion for {@code label}.
     *
     * @param label a known label (an unknown label leads to an index-out-of-bounds failure)
     * @return a table filled with the label's TP, TN, FP and FN counts
     */
    public TableOfConfusion getTableOfConfusion(T label) {
        TableOfConfusion tableOfConfusion = new TableOfConfusion();
        tableOfConfusion.addTruePositive(this.getTruePositiveCount(label));
        tableOfConfusion.addTrueNegative(this.getTrueNegativeCount(label));
        tableOfConfusion.addFalsePositive(this.getFalsePositiveCount(label));
        tableOfConfusion.addFalseNegative(this.getFalseNegativeCount(label));
        return tableOfConfusion;
    }

    /**
     * Computes the overall accuracy: the diagonal total divided by the grand total.
     *
     * @return the accuracy in [0, 1], or {@code NaN} if no observation has been recorded
     */
    public double calculateAccuracy() {
        long correct = 0L;
        long total = 0L;
        for (int i = 0; i < this.labels.size(); ++i) {
            correct += this.confusionMatrix[i][i];
            for (int j = 0; j < this.labels.size(); ++j) {
                total += this.confusionMatrix[i][j];
            }
        }
        if (total == 0L) {
            return Double.NaN;
        }
        return (double) correct / (double) total;
    }

    /**
     * Sums the row of {@code label} without its diagonal cell, i.e. the number of items whose
     * truth is {@code label} that were predicted as some other label.
     *
     * @param label a known label (an unknown label fails with an index-out-of-bounds error
     *     unless the matrix is empty, in which case 0 is returned)
     * @return the off-diagonal row sum
     */
    public long getWeightForLabel(T label) {
        return this.offDiagonalRowSum(this.labels.indexOf(label));
    }

    /** Returns the label's index, registering it as a new (last) label if it is unknown. */
    private int indexOfOrAdd(T label) {
        int index = this.labels.indexOf(label);
        if (index == -1) {
            this.addLabel(label);
            index = this.labels.size() - 1;
        }
        return index;
    }

    /** Diagonal cell: truth == label and prediction == label. */
    private long getTruePositiveCount(T label) {
        int indexOfLabel = this.labels.indexOf(label);
        return this.confusionMatrix[indexOfLabel][indexOfLabel];
    }

    /** All cells whose row and column both differ from the label: neither truth nor prediction is the label. */
    private long getTrueNegativeCount(T label) {
        int indexOfLabel = this.labels.indexOf(label);
        long trueNegativeCount = 0L;
        for (int i = 0; i < this.labels.size(); ++i) {
            if (i == indexOfLabel) continue;
            for (int j = 0; j < this.labels.size(); ++j) {
                if (j == indexOfLabel) continue;
                trueNegativeCount += this.confusionMatrix[i][j];
            }
        }
        return trueNegativeCount;
    }

    /** Predicted as the label although the truth is another label: the label's column minus the diagonal. */
    private long getFalsePositiveCount(T label) {
        return this.offDiagonalColumnSum(this.labels.indexOf(label));
    }

    /** Truth is the label but predicted as another label: the label's row minus the diagonal. */
    private long getFalseNegativeCount(T label) {
        return this.offDiagonalRowSum(this.labels.indexOf(label));
    }

    /** Sums row {@code index} (fixed truth) over all columns except the diagonal. */
    private long offDiagonalRowSum(int index) {
        long sum = 0L;
        for (int column = 0; column < this.labels.size(); ++column) {
            if (column == index) continue;
            sum += this.confusionMatrix[index][column];
        }
        return sum;
    }

    /** Sums column {@code index} (fixed prediction) over all rows except the diagonal. */
    private long offDiagonalColumnSum(int index) {
        long sum = 0L;
        for (int row = 0; row < this.labels.size(); ++row) {
            if (row == index) continue;
            sum += this.confusionMatrix[row][index];
        }
        return sum;
    }

    /**
     * Renders the raw counts followed by the per-label tables of confusion.
     *
     * @return a multi-line, human-readable dump using the platform line separator
     */
    @Override
    public String toString() {
        String lineSeparator = System.getProperty("line.separator");
        StringBuilder sb = new StringBuilder("ConfusionMatrix (rows=truth,columns=prediction)")
                .append(lineSeparator).append("values:").append(lineSeparator);
        sb.append(lineSeparator);
        sb.append("     ");
        for (T label : this.labels) {
            sb.append(" | pred: ").append(label);
        }
        sb.append(lineSeparator);
        for (int i = 0; i < this.labels.size(); ++i) {
            sb.append("true:").append(this.labels.get(i)).append("   |  ");
            for (int j = 0; j < this.labels.size(); ++j) {
                sb.append(" ").append(this.confusionMatrix[i][j]);
            }
            sb.append(lineSeparator);
        }
        sb.append(lineSeparator).append("results:").append(lineSeparator);
        for (T label : this.labels) {
            sb.append(label).append(lineSeparator).append(this.getTableOfConfusion(label));
        }
        return sb.toString();
    }

    /**
     * Renders the matrix as an HTML table with one extra column holding each label's precision.
     *
     * <p>Precision of a label is its diagonal count divided by its column total (everything
     * predicted as that label). It is formatted with {@code DecimalFormat("0.00 %")} and is
     * {@code NaN} when nothing was predicted as the label. Label text is HTML-escaped.
     *
     * @return the HTML markup
     */
    public String toHtml() {
        int size = this.labels.size();
        StringBuilder b = new StringBuilder("<table class=\"confusionMatrix\">");
        b.append("<tr>");
        b.append("<td colspan=\"2\" rowspan=\"2\" style=\"border: none;\"></td><th colspan=\"")
                .append(size).append("\">prediction</th>");
        b.append("</tr>");
        b.append("<tr>");
        for (T label : this.labels) {
            b.append("<th>").append(escapeHtml(String.valueOf(label))).append("</th>");
        }
        b.append("<th>Precision</th>");
        b.append("</tr>");
        DecimalFormat fmt = new DecimalFormat("0.00 %");
        for (int i = 0; i < size; ++i) {
            b.append("<tr>");
            if (i == 0) {
                b.append("<th rowspan=\"").append(size).append("\">true</th>");
            }
            b.append("<th>").append(escapeHtml(String.valueOf(this.labels.get(i)))).append("</th>");
            for (int j = 0; j < size; ++j) {
                b.append(" <td>").append(this.confusionMatrix[i][j]).append("</td>");
            }
            long truePositives = this.confusionMatrix[i][i];
            long predictedAsLabel = truePositives + this.offDiagonalColumnSum(i);
            // 0.0 / 0.0 yields NaN when no item was predicted as this label.
            double precision = (double) truePositives / (double) predictedAsLabel;
            b.append("<td><nobr>").append(fmt.format(precision)).append("</nobr></td>");
            b.append("</tr>\n");
        }
        b.append("</table>");
        return b.toString();
    }

    /** Escapes the characters that are significant in HTML text and attribute values. */
    private static String escapeHtml(String text) {
        StringBuilder out = new StringBuilder(text.length());
        for (int k = 0; k < text.length(); ++k) {
            char c = text.charAt(k);
            switch (c) {
                case '<':
                    out.append("&lt;");
                    break;
                case '>':
                    out.append("&gt;");
                    break;
                case '&':
                    out.append("&amp;");
                    break;
                case '"':
                    out.append("&quot;");
                    break;
                default:
                    out.append(c);
            }
        }
        return out.toString();
    }
}

