package opennlp.tools.authorage;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import opennlp.tools.ml.authorage.AgeClassifyTrainerFactory;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/* loaded from: input_file:opennlp/tools/authorage/AgeClassifyME.class */
/**
 * Classifies the age category of a text's author using a trained {@link AgeClassifyModel}.
 *
 * <p>Text is tokenized with the model factory's tokenizer, turned into features by an
 * {@link AgeClassifyContextGenerator}, and scored by the model's maxent classifier.
 */
public class AgeClassifyME {
    protected AgeClassifyContextGenerator contextGenerator;
    private final AgeClassifyFactory factory;
    private final AgeClassifyModel model;

    /**
     * Creates a classifier backed by the given model.
     *
     * @param ageClassifyModel the trained model; its factory supplies the tokenizer and
     *     feature generators
     */
    public AgeClassifyME(AgeClassifyModel ageClassifyModel) {
        this.model = ageClassifyModel;
        this.factory = ageClassifyModel.getFactory();
        this.contextGenerator = new AgeClassifyContextGenerator(this.factory.getFeatureGenerators());
    }

    /**
     * Returns the outcome the underlying maxent model ranks highest for the given scores.
     *
     * @param dArr scores indexed by outcome, typically from {@link #getProbabilities}
     * @return the best outcome label
     */
    public String getBestCategory(double[] dArr) {
        return this.model.getMaxentModel().getBestOutcome(dArr);
    }

    /** Returns the number of outcomes (age categories) known to the model. */
    public int getNumCategories() {
        return this.model.getMaxentModel().getNumOutcomes();
    }

    /**
     * Returns the outcome label at the given index.
     *
     * @param i outcome index in {@code [0, getNumCategories())}
     * @return the outcome label
     */
    public String getCategory(int i) {
        return this.model.getMaxentModel().getOutcome(i);
    }

    /**
     * Returns the index of the given outcome label, as reported by the maxent model.
     *
     * @param str outcome label
     * @return the outcome's index
     */
    public int getIndex(String str) {
        return this.model.getMaxentModel().getIndex(str);
    }

    /**
     * Scores already-tokenized text against every outcome.
     *
     * @param strArr the tokens of the text
     * @return one score per outcome, indexed like {@link #getCategory(int)}
     */
    public double[] getProbabilities(String[] strArr) {
        return this.model.getMaxentModel().eval(this.contextGenerator.getContext(strArr));
    }

    /**
     * Tokenizes the text with the factory's tokenizer and scores it against every outcome.
     *
     * @param str raw text
     * @return one score per outcome, indexed like {@link #getCategory(int)}
     */
    public double[] getProbabilities(String str) {
        return getProbabilities(this.factory.getTokenizer().tokenize(str));
    }

    /**
     * Returns the highest-scoring age category for the given text.
     *
     * @param str raw text
     * @return the best outcome label
     */
    public String predict(String str) {
        return getBestCategory(getProbabilities(str));
    }

    /**
     * Scores the text and returns a mutable map from each outcome label to its score.
     *
     * @param str raw text
     * @return outcome label to score
     */
    public Map<String, Double> scoreMap(String str) {
        double[] probabilities = getProbabilities(str);
        int numCategories = getNumCategories();
        Map<String, Double> scores = new HashMap<>();
        for (int i = 0; i < numCategories; i++) {
            String category = getCategory(i);
            scores.put(category, probabilities[getIndex(category)]);
        }
        return scores;
    }

    /**
     * Scores the text and groups outcome labels by score.
     *
     * <p>Keys are in ascending natural order (the {@link TreeMap} default). Outcomes with
     * exactly equal scores share a single set.
     *
     * @param str raw text
     * @return score to the set of outcome labels with that score
     */
    public SortedMap<Double, Set<String>> sortedScoreMap(String str) {
        double[] probabilities = getProbabilities(str);
        int numCategories = getNumCategories();
        SortedMap<Double, Set<String>> scores = new TreeMap<>();
        for (int i = 0; i < numCategories; i++) {
            String category = getCategory(i);
            double score = probabilities[getIndex(category)];
            scores.computeIfAbsent(score, key -> new HashSet<>()).add(category);
        }
        return scores;
    }

    /**
     * Trains a new age-classification model.
     *
     * @param str language code stored in the model
     * @param objectStream training samples
     * @param trainingParameters trainer settings
     * @param ageClassifyFactory factory providing the context generator, and stored in the model
     * @return the trained model
     * @throws IOException if reading the sample stream fails
     */
    public static AgeClassifyModel train(String str, ObjectStream<AuthorAgeSample> objectStream,
            TrainingParameters trainingParameters, AgeClassifyFactory ageClassifyFactory) throws IOException {
        // The trainer reports its settings into this map, so the same map must reach the model.
        // Previously a fresh empty map was passed to the model and these entries were lost.
        Map<String, String> manifestInfoEntries = new HashMap<>();
        // Return value unused; the call is kept in case it validates the trainer settings.
        AgeClassifyTrainerFactory.getTrainerType(trainingParameters.getSettings());
        // Java evaluates arguments left to right, so training finishes (and fills the map)
        // before the map is handed to the model constructor.
        return new AgeClassifyModel(str,
                AgeClassifyTrainerFactory.getEventTrainer(trainingParameters.getSettings(), manifestInfoEntries)
                        .train(new AgeClassifyEventStream(objectStream, ageClassifyFactory.createContextGenerator())),
                manifestInfoEntries, ageClassifyFactory);
    }
}
