/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.classification.classifiers;

import com.hankcs.hanlp.classification.classifiers.AbstractClassifier;
import com.hankcs.hanlp.classification.corpus.Document;
import com.hankcs.hanlp.classification.corpus.IDataSet;
import com.hankcs.hanlp.classification.features.BaseFeatureData;
import com.hankcs.hanlp.classification.features.ChiSquareFeatureExtractor;
import com.hankcs.hanlp.classification.models.AbstractModel;
import com.hankcs.hanlp.classification.models.NaiveBayesModel;
import com.hankcs.hanlp.classification.utilities.MathUtility;
import com.hankcs.hanlp.classification.utilities.Predefine;
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import java.util.Map;
import java.util.TreeMap;

/**
 * Naive Bayes text classifier.
 *
 * <p>Training selects features with a chi-square test ({@link #selectFeatures}), then estimates
 * per-category log priors and per-(feature, category) log likelihoods with add-one (Laplace)
 * smoothing. Prediction sums the log prior and the term-frequency-weighted log likelihoods of
 * every known feature in the document.
 */
public class NaiveBayesClassifier extends AbstractClassifier {
    private NaiveBayesModel model;

    /**
     * Creates a classifier backed by an existing model.
     *
     * @param naiveBayesModel a trained model, or {@code null} if {@link #train} will be called later
     */
    public NaiveBayesClassifier(NaiveBayesModel naiveBayesModel) {
        this.model = naiveBayesModel;
    }

    /** Creates an untrained classifier; {@link #train} must be called before predicting. */
    public NaiveBayesClassifier() {
        this(null);
    }

    /**
     * Returns the underlying Naive Bayes model.
     *
     * @return the model, or {@code null} if the classifier has not been trained or given a model
     */
    public NaiveBayesModel getNaiveBayesModel() {
        return this.model;
    }

    /**
     * Trains a new model on {@code dataSet}, replacing any existing model.
     *
     * <p>The new model is fully built before it is published to {@code this.model}, so a failure
     * during training leaves the previous model (if any) untouched.
     *
     * @param dataSet the labelled training documents
     */
    @Override
    public void train(IDataSet dataSet) {
        Predefine.logger.out("\u539f\u59cb\u6570\u636e\u96c6\u5927\u5c0f:%d\n", dataSet.size());
        BaseFeatureData featureData = this.selectFeatures(dataSet);
        int[][] jointCount = featureData.featureCategoryJointCount;
        int[] categoryCounts = featureData.categoryCounts;
        int featureCount = jointCount.length;
        int categoryCount = categoryCounts.length;

        NaiveBayesModel newModel = new NaiveBayesModel();
        newModel.n = featureData.n;
        newModel.d = featureCount;
        newModel.c = categoryCount;

        // Log prior per category: log(count(category) / n).
        // NOTE(review): n is presumably the number of training documents — confirm in BaseFeatureData.
        newModel.logPriors = new TreeMap<Integer, Double>();
        for (int category = 0; category < categoryCount; ++category) {
            newModel.logPriors.put(category, Math.log((double) categoryCounts[category] / (double) newModel.n));
        }

        // Total feature occurrences per category: the denominator of the likelihood estimate.
        double[] occurrencesInCategory = new double[categoryCount];
        for (int category = 0; category < categoryCount; ++category) {
            double sum = 0.0;
            for (int feature = 0; feature < featureCount; ++feature) {
                sum += jointCount[feature][category];
            }
            occurrencesInCategory[category] = sum;
        }

        // Add-one smoothed log likelihood: log((count + 1) / (occurrences(category) + d)).
        for (int feature = 0; feature < featureCount; ++feature) {
            int[] countsByCategory = jointCount[feature];
            TreeMap<Integer, Double> likelihoods = new TreeMap<Integer, Double>();
            for (int category = 0; category < categoryCount; ++category) {
                double numerator = (double) countsByCategory[category] + 1.0;
                double denominator = occurrencesInCategory[category] + (double) newModel.d;
                likelihoods.put(category, Math.log(numerator / denominator));
            }
            newModel.logLikelihoods.put(feature, likelihoods);
        }
        Predefine.logger.out("\u8d1d\u53f6\u65af\u7edf\u8ba1\u7ed3\u675f\n", new Object[0]);

        newModel.catalog = dataSet.getCatalog().toArray();
        newModel.tokenizer = dataSet.getTokenizer();
        newModel.wordIdTrie = featureData.wordIdTrie;
        this.model = newModel;
    }

    /**
     * Returns the current model (same object as {@link #getNaiveBayesModel()}).
     *
     * @return the model, possibly {@code null} if untrained
     */
    @Override
    public AbstractModel getModel() {
        return this.model;
    }

    /**
     * Tokenizes {@code text} with the model's tokenizer and predicts its category scores.
     *
     * @param text the raw text to classify
     * @return the result of {@code predict(Document)} for the tokenized text
     * @throws IllegalStateException if no model has been trained or supplied
     * @throws IllegalArgumentException if {@code text} is {@code null}
     */
    @Override
    public Map<String, Double> predict(String text) throws IllegalArgumentException, IllegalStateException {
        if (this.model == null) {
            throw new IllegalStateException("\u672a\u8bad\u7ec3\u6a21\u578b\uff01\u65e0\u6cd5\u6267\u884c\u9884\u6d4b\uff01");
        }
        if (text == null) {
            throw new IllegalArgumentException("\u53c2\u6570 text == null");
        }
        Document doc = new Document(this.model.wordIdTrie, this.model.tokenizer.segment(text));
        return this.predict(doc);
    }

    /**
     * Scores {@code document} against every category.
     *
     * <p>Each score is the category's log prior plus, for every document feature known to the
     * model, the feature's term frequency times its log likelihood in that category. Features
     * that are unknown to the model are ignored. If probability output is enabled, the scores
     * are passed through {@code MathUtility.normalizeExp} before being returned.
     *
     * @param document the tokenized document to score
     * @return one score per category, indexed by category id
     * @throws IllegalStateException if no model has been trained or supplied
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public double[] categorize(Document document) throws IllegalArgumentException, IllegalStateException {
        if (this.model == null) {
            throw new IllegalStateException("\u672a\u8bad\u7ec3\u6a21\u578b\uff01\u65e0\u6cd5\u6267\u884c\u9884\u6d4b\uff01");
        }
        if (document == null) {
            throw new IllegalArgumentException("\u53c2\u6570 document == null");
        }
        double[] predictionScores = new double[this.model.catalog.length];
        for (Map.Entry<Integer, Double> prior : this.model.logPriors.entrySet()) {
            int category = prior.getKey();
            double logProb = prior.getValue();
            for (Map.Entry<Integer, int[]> term : document.tfMap.entrySet()) {
                Map<Integer, Double> likelihoods = this.model.logLikelihoods.get(term.getKey());
                if (likelihoods == null) {
                    continue; // feature was not selected during training
                }
                int occurrences = term.getValue()[0];
                logProb += (double) occurrences * likelihoods.get(category);
            }
            predictionScores[category] = logProb;
        }
        if (this.configProbabilityEnabled) {
            MathUtility.normalizeExp(predictionScores);
        }
        return predictionScores;
    }

    /**
     * Extracts basic feature statistics from {@code dataSet} and keeps only the features chosen
     * by a chi-square test.
     *
     * <p>Selected features are renumbered densely from 0 in the iteration order of the selection
     * result. {@code featureCategoryJointCount} is rewritten to hold only those rows, and
     * {@code wordIdTrie} is rebuilt to map each selected word to its new index.
     *
     * @param dataSet the training data
     * @return the feature data restricted to the selected features
     */
    protected BaseFeatureData selectFeatures(IDataSet dataSet) {
        ChiSquareFeatureExtractor chiSquareFeatureExtractor = new ChiSquareFeatureExtractor();
        Predefine.logger.start("\u4f7f\u7528\u5361\u65b9\u68c0\u6d4b\u9009\u62e9\u7279\u5f81\u4e2d...", new Object[0]);
        BaseFeatureData featureData = ChiSquareFeatureExtractor.extractBasicFeatureData(dataSet);
        Map<Integer, Double> selectedFeatures = chiSquareFeatureExtractor.chi_square(featureData);
        int[][] selectedJointCount = new int[selectedFeatures.size()][];
        featureData.wordIdTrie = new BinTrie<Integer>();
        String[] wordIdArray = dataSet.getLexicon().getWordIdArray();
        int newId = 0;
        for (Integer feature : selectedFeatures.keySet()) {
            selectedJointCount[newId] = featureData.featureCategoryJointCount[feature];
            featureData.wordIdTrie.put(wordIdArray[feature], newId);
            ++newId;
        }
        int totalFeatures = featureData.featureCategoryJointCount.length;
        // Guard the ratio so an empty candidate set does not log "NaN%".
        double selectedPercent = totalFeatures == 0
                ? 0.0
                : (double) selectedJointCount.length / (double) totalFeatures * 100.0;
        Predefine.logger.finish(",\u9009\u4e2d\u7279\u5f81\u6570:%d / %d = %.2f%%\n", selectedJointCount.length, totalFeatures, selectedPercent);
        featureData.featureCategoryJointCount = selectedJointCount;
        return featureData;
    }
}

