package com.aliasi.classify;

import com.aliasi.lm.LanguageModel;
import com.aliasi.stats.MultivariateDistribution;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

/* loaded from: input_file:com/aliasi/classify/LMClassifier.class */
/**
 * Joint classifier over character sequences that pairs one language model
 * with each category and combines the language model's log2 estimate of the
 * input with the category distribution's log2 probability of the category.
 *
 * <p>The category array and the language model array are parallel: the model
 * at index {@code i} belongs to the category at index {@code i}.
 *
 * @param <L> type of the per-category language models
 * @param <M> type of the distribution over categories
 */
public class LMClassifier<L extends LanguageModel, M extends MultivariateDistribution> implements JointClassifier<CharSequence> {
    /** Language models, parallel to {@link #mCategories}. */
    final L[] mLanguageModels;
    /** Distribution queried for the log2 probability of each category. */
    final M mCategoryDistribution;
    /** Category name to language model mapping, filled by the constructor. */
    final HashMap<String, L> mCategoryToModel;
    /** Category names, parallel to {@link #mLanguageModels}. */
    final String[] mCategories;

    /**
     * Builds a classifier from parallel arrays of categories and language
     * models plus a distribution over the categories.
     *
     * <p>The supplied arrays are stored as given, without copying.
     *
     * @param strArr category names; must be unique and at least two
     * @param lArr language models, one per category, in the same order
     * @param m category distribution; its number of dimensions must equal
     *     the number of categories
     * @throws IllegalArgumentException if a category is duplicated, fewer
     *     than two categories are given, the distribution's dimensionality
     *     differs from the number of categories, or the two arrays differ
     *     in length
     */
    public LMClassifier(String[] strArr, L[] lArr, M m) {
        HashSet<String> seenCategories = new HashSet<String>();
        for (String category : strArr) {
            // add() returns false when the element was already present
            if (!seenCategories.add(category)) {
                throw new IllegalArgumentException("Duplicate category=" + category);
            }
        }
        if (strArr.length < 2) {
            throw new IllegalArgumentException("Require at least two categories. Found categories.length=" + strArr.length);
        }
        if (strArr.length != m.numDimensions()) {
            throw new IllegalArgumentException("Require same number of categories as dimensions. Found categories.length=" + strArr.length + " Found categoryDistribution.numDimensions()=" + m.numDimensions());
        }
        this.mCategories = strArr;
        if (strArr.length != lArr.length) {
            throw new IllegalArgumentException("Categories and language models must be same length. Found categories length=" + strArr.length + " Found language models length=" + lArr.length);
        }
        this.mLanguageModels = lArr;
        this.mCategoryDistribution = m;
        this.mCategoryToModel = new HashMap<>();
        for (int i = 0; i < strArr.length; i++) {
            this.mCategoryToModel.put(strArr[i], lArr[i]);
        }
    }

    /**
     * Returns the categories of this classifier.
     *
     * @return a fresh copy of the category array; modifying it does not
     *     affect this classifier
     */
    public String[] categories() {
        return this.mCategories.clone();
    }

    /**
     * Returns the language model paired with the given category.
     *
     * <p>The lookup scans the parallel arrays rather than the category map,
     * so it always reflects the current contents of the model array.
     *
     * @param str category name
     * @return the language model at the category's index
     * @throws IllegalArgumentException if the category is not one of this
     *     classifier's categories
     */
    public L languageModel(String str) {
        for (int i = 0; i < this.mCategories.length; i++) {
            if (str.equals(this.mCategories[i])) {
                return this.mLanguageModels[i];
            }
        }
        throw new IllegalArgumentException("Category not known.  Category=" + str);
    }

    /**
     * Returns the distribution over categories used by this classifier.
     *
     * @return the category distribution supplied at construction
     */
    public M categoryDistribution() {
        return this.mCategoryDistribution;
    }

    /**
     * Classifies the whole character sequence by delegating to
     * {@link #classifyJoint(char[], int, int)}.
     *
     * @param charSequence text to classify
     * @return joint classification of the text
     * @throws IllegalArgumentException if the input is {@code null}
     */
    @Override // com.aliasi.classify.BaseClassifier
    public JointClassification classify(CharSequence charSequence) {
        // The parameter is statically a CharSequence, so null is the only
        // value that can be rejected here; the message matches what the
        // former instanceof-based check produced for null.
        if (charSequence == null) {
            throw new IllegalArgumentException("LM Classification requires CharSequence input. Found class=null");
        }
        return classifyJoint(Strings.toCharArray(charSequence), 0, charSequence.length());
    }

    /**
     * Classifies the slice of characters from {@code i} (inclusive) to
     * {@code i2} (exclusive).
     *
     * <p>Each category is scored as the sum of its language model's log2
     * estimate of the slice and the category distribution's log2
     * probability of the category.
     *
     * @param cArr underlying characters
     * @param i index of the first character of the slice
     * @param i2 index one past the last character of the slice
     * @return joint classification built from the per-category scores
     */
    public JointClassification classifyJoint(char[] cArr, int i, int i2) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        // categories() clones the array on every call, so take one snapshot
        // instead of cloning several times per loop iteration.
        String[] categories = categories();
        // The text is the same for every category; build it once.
        String text = new String(cArr, i, i2 - i);
        @SuppressWarnings("unchecked") // generic arrays cannot be created directly
        ScoredObject<String>[] scored = new ScoredObject[categories.length];
        for (int k = 0; k < categories.length; k++) {
            String category = categories[k];
            double jointLog2 = this.mLanguageModels[k].log2Estimate(text) + this.mCategoryDistribution.log2Probability(category);
            scored[k] = new ScoredObject<String>(category, jointLog2);
        }
        // Normalizer is the slice length plus two.
        return toJointClassification(scored, (i2 - i) + 2);
    }

    /**
     * Sorts the scored categories with {@code ScoredObject.reverseComparator()}
     * and packs them into a {@link JointClassification}.
     *
     * <p>The input array is sorted in place. For each entry the raw score is
     * passed as the third constructor argument and the score divided by
     * {@code d} as the second.
     *
     * @param scoredObjectArr category/score pairs; reordered by this call
     * @param d divisor applied to every score
     * @return classification holding the sorted categories, the divided
     *     scores and the raw scores
     */
    static JointClassification toJointClassification(ScoredObject<String>[] scoredObjectArr, double d) {
        Arrays.sort(scoredObjectArr, ScoredObject.reverseComparator());
        int count = scoredObjectArr.length;
        String[] rankedCategories = new String[count];
        double[] rawScores = new double[count];
        double[] normalizedScores = new double[count];
        for (int rank = 0; rank < count; rank++) {
            rankedCategories[rank] = scoredObjectArr[rank].getObject();
            rawScores[rank] = scoredObjectArr[rank].score();
            normalizedScores[rank] = rawScores[rank] / d;
        }
        return new JointClassification(rankedCategories, normalizedScores, rawScores);
    }
}
