package com.aliasi.classify;

/* loaded from: input_file:com/aliasi/classify/RankedClassifierEvaluator.class */
/**
 * Evaluator for ranked classifiers. In addition to whatever the base
 * evaluator records, it keeps a per-reference-category histogram of the
 * rank (0-based) at which the reference category appeared in each
 * ranked response, and derives rank-based summary statistics from it.
 *
 * <p>NOTE(review): this source was recovered by a decompiler; several
 * methods marked {@code public} here were package-private in the
 * compiled class. Their visibility is left as-is so existing callers
 * keep working.
 *
 * @param <E> the type of objects being classified
 */
public class RankedClassifierEvaluator<E> extends BaseClassifierEvaluator<E> {
    /**
     * Set to {@code true} once any handled classification ranks fewer
     * categories than this evaluator knows about.
     */
    boolean mDefectiveRanking;

    /**
     * Rank histogram indexed as {@code [referenceCategoryIndex][rank]}.
     * Both dimensions equal the number of categories given at construction.
     */
    private final int[][] mRankCounts;

    /**
     * Creates an evaluator for the given ranked classifier and categories.
     *
     * @param rankedClassifier classifier to evaluate; forwarded to the base evaluator
     * @param strArr the categories; determines the size of the rank histogram
     * @param z forwarded unchanged to the base evaluator
     *          (presumably the "store inputs" flag; confirm against the base class)
     */
    public RankedClassifierEvaluator(RankedClassifier<E> rankedClassifier, String[] strArr, boolean z) {
        super(rankedClassifier, strArr, z);
        this.mDefectiveRanking = false;
        int length = strArr.length;
        this.mRankCounts = new int[length][length];
    }

    /**
     * Replaces the classifier being evaluated. Only permitted when this
     * instance's runtime class is exactly {@code RankedClassifierEvaluator};
     * subclasses are rejected so that they cannot be handed a classifier
     * of a less specific type through this method.
     *
     * @param rankedClassifier the new classifier
     * @throws IllegalArgumentException if this instance is a subclass instance
     */
    public void setClassifier(RankedClassifier<E> rankedClassifier) {
        if (!getClass().equals(RankedClassifierEvaluator.class)) {
            // Build the message null-safely so a null argument still yields
            // the intended IllegalArgumentException rather than an NPE.
            Class<?> foundClass = (rankedClassifier == null) ? null : rankedClassifier.getClass();
            throw new IllegalArgumentException("Require appropriate classifier type. Evaluator class=" + getClass() + " Found classifier.class=" + foundClass);
        }
        this.mClassifier = rankedClassifier;
    }

    /**
     * Returns the classifier being evaluated, narrowed to the ranked type.
     *
     * @return the ranked classifier held by the base evaluator
     */
    @Override // com.aliasi.classify.BaseClassifierEvaluator
    public RankedClassifier<E> classifier() {
        // Safe as long as only ranked classifiers are installed, which the
        // constructor and setClassifier signatures require.
        @SuppressWarnings("unchecked")
        RankedClassifier<E> ranked = (RankedClassifier<E>) super.classifier();
        return ranked;
    }

    /**
     * Handles one reference-classified case: runs the classifier on the
     * case's object, records the response with the base evaluator, and
     * updates the rank histogram for the reference category.
     *
     * @param classified the object paired with its reference classification
     */
    @Override // com.aliasi.classify.BaseClassifierEvaluator, com.aliasi.corpus.ObjectHandler
    public void handle(Classified<E> classified) {
        E object = classified.getObject();
        String bestCategory = classified.getClassification().bestCategory();
        validateCategory(bestCategory);
        // The input is passed as-is; it must not be cast to the classifier type.
        RankedClassification classify = classifier().classify(object);
        addClassification(bestCategory, classify, object);
        addRanking(bestCategory, classify);
    }

    /**
     * Increments the histogram cell for the rank at which the reference
     * category appears in the response. If the category is not found
     * among the ranks inspected, the last rank
     * ({@code mCategories.length - 1}) is incremented instead.
     *
     * @param str the reference category
     * @param rankedClassification the classifier's ranked response
     * @throws IllegalArgumentException if the category is unknown
     */
    public void addRanking(String str, RankedClassification rankedClassification) {
        int categoryToIndex = categoryToIndex(str);
        int responseSize = rankedClassification.size();
        int categoryCount = numCategories();
        if (responseSize < categoryCount) {
            this.mDefectiveRanking = true;
        }
        // Never look past the histogram width, even if the response is longer.
        int limit = Math.min(categoryCount, responseSize);
        for (int rank = 0; rank < limit; rank++) {
            if (rankedClassification.category(rank).equals(str)) {
                this.mRankCounts[categoryToIndex][rank]++;
                return;
            }
        }
        // Reference category absent from the response: count it as ranked last.
        this.mRankCounts[categoryToIndex][this.mCategories.length - 1]++;
    }

    /**
     * Returns how many handled cases with the given reference category
     * had that category at the given rank in the response.
     *
     * @param str the reference category
     * @param i the 0-based rank
     * @return the recorded count
     * @throws IllegalArgumentException if the category is unknown
     * @throws ArrayIndexOutOfBoundsException if the rank is outside the histogram
     */
    public int rankCount(String str, int i) {
        validateCategory(str);
        return this.mRankCounts[categoryToIndex(str)][i];
    }

    /**
     * Returns the mean 0-based rank of the reference category over all
     * recorded cases.
     *
     * @return the average rank, or {@code NaN} if no cases have been recorded
     */
    public double averageRankReference() {
        double rankSum = 0.0d;
        long caseCount = 0L;
        int categoryCount = numCategories();
        for (int cat = 0; cat < categoryCount; cat++) {
            for (int rank = 0; rank < categoryCount; rank++) {
                int count = this.mRankCounts[cat][rank];
                if (count != 0) {
                    caseCount += count;
                    // Multiply in double so rank * count cannot overflow int.
                    rankSum += (double) rank * count;
                }
            }
        }
        return rankSum / caseCount;
    }

    /**
     * Returns the mean of {@code 1 / (1 + rank)} of the reference category
     * over all recorded cases.
     *
     * @return the mean reciprocal rank, or {@code NaN} if no cases have been recorded
     */
    public double meanReciprocalRank() {
        double reciprocalSum = 0.0d;
        long caseCount = 0L;
        int categoryCount = numCategories();
        for (int cat = 0; cat < categoryCount; cat++) {
            for (int rank = 0; rank < categoryCount; rank++) {
                int count = this.mRankCounts[cat][rank];
                if (count != 0) {
                    caseCount += count;
                    reciprocalSum += count / (1.0d + rank);
                }
            }
        }
        return reciprocalSum / caseCount;
    }

    /**
     * Returns the mean 0-based rank of {@code str2} in the stored responses
     * for cases whose reference category is {@code str}. A response that
     * does not contain {@code str2} contributes the last rank.
     *
     * @param str the reference category selecting the cases
     * @param str2 the response category whose rank is averaged
     * @return the average rank, or {@code NaN} if no stored case has
     *         reference category {@code str}
     */
    public double averageRank(String str, String str2) {
        validateCategory(str);
        validateCategory(str2);
        double rankSum = 0.0d;
        int caseCount = 0;
        int stored = this.mReferenceCategories.size();
        for (int k = 0; k < stored; k++) {
            if (this.mReferenceCategories.get(k).equals(str)) {
                rankSum += getRank((RankedClassification) this.mClassifications.get(k), str2);
                caseCount++;
            }
        }
        return rankSum / caseCount;
    }

    /**
     * Maps a category name to its index in the confusion matrix.
     *
     * @param str the category name
     * @return the non-negative index
     * @throws IllegalArgumentException if the confusion matrix reports a negative index
     */
    public int categoryToIndex(String str) {
        int index = confusionMatrix().getIndex(str);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category=" + str);
        }
        return index;
    }

    /**
     * Returns the 0-based rank of the category in the classification, or
     * {@code mCategories.length - 1} if the category does not occur in it.
     */
    int getRank(RankedClassification rankedClassification, String str) {
        int size = rankedClassification.size();
        for (int rank = 0; rank < size; rank++) {
            if (rankedClassification.category(rank).equals(str)) {
                return rank;
            }
        }
        return this.mCategories.length - 1;
    }

    /**
     * Appends the base report followed by the average reference rank line.
     *
     * @param sb the buffer to append to
     */
    @Override // com.aliasi.classify.BaseClassifierEvaluator
    public void baseToString(StringBuilder sb) {
        super.baseToString(sb);
        sb.append("Average Reference Rank=").append(averageRankReference()).append("\n");
    }

    /**
     * Appends the base one-versus-all report for a category, followed by
     * that category's rank histogram and its per-category average ranks.
     *
     * @param sb the buffer to append to
     * @param str the category being reported
     * @param i the category's row index in the rank histogram
     */
    @Override // com.aliasi.classify.BaseClassifierEvaluator
    public void oneVsAllToString(StringBuilder sb, String str, int i) {
        super.oneVsAllToString(sb, str, i);
        int categoryCount = numCategories();

        sb.append("Rank Histogram=\n");
        appendCategoryLine(sb);
        for (int rank = 0; rank < categoryCount; rank++) {
            if (rank > 0) {
                sb.append(',');
            }
            sb.append(this.mRankCounts[i][rank]);
        }
        sb.append("\n");

        sb.append("Average Rank Histogram=\n");
        appendCategoryLine(sb);
        for (int cat = 0; cat < categoryCount; cat++) {
            if (cat > 0) {
                sb.append(',');
            }
            sb.append(averageRank(str, categories()[cat]));
        }
        sb.append("\n");
    }

    /**
     * Appends an indented, comma-separated line of category names,
     * followed by a newline and the indent for the next line.
     *
     * @param sb the buffer to append to
     */
    public void appendCategoryLine(StringBuilder sb) {
        sb.append("  ");
        int categoryCount = numCategories();
        for (int cat = 0; cat < categoryCount; cat++) {
            if (cat > 0) {
                sb.append(',');
            }
            sb.append(categories()[cat]);
        }
        sb.append("\n  ");
    }
}
