/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.classification;

import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import java.util.ArrayList;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.ml.classification.ConfusionMatrix;

public class CategoryConfusionMatrix
implements ConfusionMatrix {
    private final com.google.common.collect.Table<Integer, Integer, Integer> table = TreeBasedTable.create();
    private SortedMap<Integer, String> labels = new TreeMap<Integer, String>();
    private CategoryColumn labelColumn;

    public CategoryConfusionMatrix(CategoryColumn labelColumn, SortedSet<String> labels) {
        this.labelColumn = labelColumn;
        int i = 0;
        for (String object : labels) {
            this.labels.put(i, object);
            ++i;
        }
    }

    @Override
    public void increment(Integer predicted, Integer actual) {
        Integer v = (Integer)this.table.get((Object)predicted, (Object)actual);
        if (v == null) {
            this.table.put((Object)predicted, (Object)actual, (Object)1);
        } else {
            this.table.put((Object)predicted, (Object)actual, (Object)(v + 1));
        }
    }

    @Override
    public String toString() {
        return this.toTable().toString();
    }

    @Override
    public Table toTable() {
        com.google.common.collect.Table<String, String, Integer> sortedTable = this.sortedTable();
        Table t = Table.create("Confusion Matrix");
        t.addColumn(new CategoryColumn(""));
        for (Object label : sortedTable.rowKeySet()) {
            t.addColumn(new IntColumn((String)label));
            t.column(0).appendCell("Predicted " + (String)label);
        }
        int n = 0;
        for (String rowLabel : sortedTable.rowKeySet()) {
            int c = 1;
            for (String colLabel : sortedTable.columnKeySet()) {
                Integer value = (Integer)sortedTable.get((Object)rowLabel, (Object)colLabel);
                if (value == null) {
                    t.intColumn(c).append(0);
                } else {
                    t.intColumn(c).append(value);
                    n += value.intValue();
                }
                ++c;
            }
        }
        t.column(0).setName("n = " + n);
        for (int col = 1; col <= sortedTable.columnKeySet().size(); ++col) {
            t.column(col).setName("Actual " + t.column(col).name());
        }
        return t;
    }

    private com.google.common.collect.Table<String, String, Integer> sortedTable() {
        Int2ObjectMap<String> labelKeys = this.labelColumn.dictionaryMap().keyToValueMap();
        TreeBasedTable sortedTable = TreeBasedTable.create();
        TreeSet allValues = new TreeSet();
        allValues.addAll(this.table.columnKeySet());
        allValues.addAll(this.table.rowKeySet());
        ArrayList valuesList = new ArrayList(allValues);
        for (int r = 0; r < valuesList.size(); ++r) {
            for (int c = 0; c < valuesList.size(); ++c) {
                Integer value = (Integer)this.table.get(valuesList.get(r), valuesList.get(c));
                if (value == null) {
                    sortedTable.put(labelKeys.get(r), labelKeys.get(c), (Object)0);
                    continue;
                }
                sortedTable.put(labelKeys.get(r), labelKeys.get(c), (Object)value);
            }
        }
        return sortedTable;
    }

    @Override
    public double accuracy() {
        int hits = 0;
        int misses = 0;
        for (Table.Cell cell : this.table.cellSet()) {
            if (((Integer)cell.getRowKey()).equals(cell.getColumnKey())) {
                hits += ((Integer)cell.getValue()).intValue();
                continue;
            }
            misses += ((Integer)cell.getValue()).intValue();
        }
        return (double)hits / ((double)(hits + misses) * 1.0);
    }
}

