/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;

/**
 * A naive Bayes classifier that keeps per-term, per-class hit counts in memory
 * so that repeated lookups of the same term do not have to query the index again.
 *
 * <p>The cache is (re)built by {@link #reInitCache(int, boolean)}, which every
 * {@code train} overload invokes after delegating to the superclass.
 *
 * <p>The cache is a plain {@link HashMap} that is written to while terms are being
 * looked up; no synchronization is performed by this class.
 */
public class CachingNaiveBayesClassifier
extends SimpleNaiveBayesClassifier {
    /** Every term of the class field, in the order returned by the index. */
    private final List<BytesRef> cclasses = new ArrayList<BytesRef>();
    /**
     * term text -> (class -> number of matching documents). A key that maps to an
     * empty map marks a cacheable term whose counts have not been stored yet.
     */
    private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<String, Map<BytesRef, Integer>>();
    /** class -> (sum over text fields of sumDocFreq / docCount) * docFreq of the class term. */
    private final Map<BytesRef, Double> classTermFreq = new HashMap<BytesRef, Double>();
    /** When {@code true}, terms without a cache slot are reported with no hits instead of being searched. */
    private boolean justCachedTerms;
    /** Value returned by {@code countDocsWithClass()} at the last cache rebuild. */
    private int docsWithClassSize;

    /**
     * Trains on a single text field without a filter query.
     *
     * @throws IOException if reading the index fails
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.train(leafReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * Trains on a single text field, optionally restricted by {@code query}.
     *
     * @throws IOException if reading the index fails
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
    }

    /**
     * Trains the superclass and then rebuilds the cache with a minimum term
     * occurrence of {@code 0} and {@code justCachedTerms == true}.
     *
     * @throws IOException if reading the index fails
     */
    @Override
    public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
        this.reInitCache(0, true);
    }

    /**
     * Scores {@code inputDocument} against every known class and normalizes the
     * log scores so that the returned scores sum to 1.
     *
     * <p>NOTE(review): this method is private and is not called from anywhere in
     * this class; whether it was meant to replace a superclass method should be
     * confirmed against the superclass.
     *
     * @param inputDocument the text to classify
     * @return one result per class, in sorted order; empty when there are no classes
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
        if (this.leafReader == null) {
            throw new IOException("You must first call Classifier#train");
        }
        String[] tokenizedDoc = this.tokenizeDoc(inputDocument);
        List<ClassificationResult<BytesRef>> dataList = this.calculateLogLikelihood(tokenizedDoc);
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<ClassificationResult<BytesRef>>(dataList.size());
        if (dataList.isEmpty()) {
            return returnList;
        }
        Collections.sort(dataList);
        // Pivot for the log-sum-exp computation: the score of the first element after
        // sorting (presumably the highest score; depends on ClassificationResult's ordering).
        double smax = dataList.get(0).getScore();
        double sumLog = 0.0;
        for (ClassificationResult<BytesRef> cr : dataList) {
            sumLog += Math.exp(cr.getScore() - smax);
        }
        // loga = smax + log(sum(exp(score - smax))) = log(sum(exp(score)))
        double loga = smax + Math.log(sumLog);
        for (ClassificationResult<BytesRef> cr : dataList) {
            returnList.add(new ClassificationResult<BytesRef>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
        }
        return returnList;
    }

    /**
     * Accumulates, for every class, the sum over all tokens of
     * {@code log((hits + 1) / (classTermFreq + docsWithClassSize))}.
     *
     * @param tokenizedDoc the tokens of the document being classified
     * @return one result per class, in the same order as the cached class list
     * @throws IOException if a term lookup has to search the index and that search fails
     */
    private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException {
        int classCount = this.cclasses.size();
        List<ClassificationResult<BytesRef>> ret = new ArrayList<ClassificationResult<BytesRef>>(classCount);
        for (BytesRef cclass : this.cclasses) {
            ret.add(new ClassificationResult<BytesRef>(cclass, 0.0));
        }
        for (String word : tokenizedDoc) {
            Map<BytesRef, Integer> hitsInClasses = this.getWordFreqForClassess(word);
            // ret was filled in the same order as cclasses, so position i of ret is the
            // result for cclasses.get(i); no search through ret is needed.
            for (int i = 0; i < classCount; i++) {
                BytesRef cclass = this.cclasses.get(i);
                Integer hitsI = hitsInClasses.get(cclass);
                int hits = hitsI == null ? 0 : hitsI.intValue();
                double num = hits + 1;
                double den = this.classTermFreq.get(cclass) + (double) this.docsWithClassSize;
                ClassificationResult<BytesRef> cr = ret.get(i);
                cr.setScore(cr.getScore() + Math.log(num / den));
            }
        }
        return ret;
    }

    /**
     * Returns, for {@code word}, the number of matching documents per class.
     *
     * <p>A non-empty cached entry is returned directly. Otherwise the index is
     * searched when the term has a cache slot or when {@code justCachedTerms} is
     * {@code false}; the result is stored only if the term has a cache slot.
     * Classes with zero hits are omitted from the returned map. Because an empty
     * map also marks an unfilled slot, a cacheable term with no hits in any class
     * is searched again on every call.
     *
     * @param word the term text to look up
     * @return class -> hit count; possibly empty, never {@code null}
     * @throws IOException if searching the index fails
     */
    private Map<BytesRef, Integer> getWordFreqForClassess(String word) throws IOException {
        Map<BytesRef, Integer> insertPoint = this.termCClassHitCache.get(word);
        if (insertPoint != null && !insertPoint.isEmpty()) {
            return insertPoint;
        }
        Map<BytesRef, Integer> searched = new ConcurrentHashMap<BytesRef, Integer>();
        boolean cacheable = insertPoint != null;
        if (!cacheable && this.justCachedTerms) {
            // Unknown term and only cached terms are to be considered: report no hits.
            return searched;
        }
        for (BytesRef cclass : this.cclasses) {
            // (word in any text field) AND (class field == cclass) [AND the training query]
            BooleanQuery subQuery = new BooleanQuery();
            for (String textFieldName : this.textFieldNames) {
                subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
            }
            BooleanQuery booleanQuery = new BooleanQuery();
            booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
            booleanQuery.add(new BooleanClause(new TermQuery(new Term(this.classFieldName, cclass)), BooleanClause.Occur.MUST));
            if (this.query != null) {
                booleanQuery.add(this.query, BooleanClause.Occur.MUST);
            }
            TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
            this.indexSearcher.search(booleanQuery, totalHitCountCollector);
            int hitCount = totalHitCountCollector.getTotalHits();
            if (hitCount != 0) {
                searched.put(cclass, hitCount);
            }
        }
        if (cacheable) {
            this.termCClassHitCache.put(word, searched);
        }
        return searched;
    }

    /**
     * Discards all cached data and rebuilds it from the current index.
     *
     * <p>A cache slot is created for every term whose document frequency, summed
     * over all text fields, is strictly greater than {@code minTermOccurrenceInCache}.
     * Text fields for which the index reports no terms are skipped.
     *
     * @param minTermOccurrenceInCache terms at or below this summed document frequency get no cache slot
     * @param justCachedTerms if {@code true}, terms without a cache slot are treated as having no hits
     *        instead of being searched in the index
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    public void reInitCache(int minTermOccurrenceInCache, boolean justCachedTerms) throws IOException {
        if (this.leafReader == null) {
            // Same failure mode as classification before training, instead of a bare NPE.
            throw new IOException("You must first call Classifier#train");
        }
        this.justCachedTerms = justCachedTerms;
        this.docsWithClassSize = this.countDocsWithClass();
        this.termCClassHitCache.clear();
        this.cclasses.clear();
        this.classTermFreq.clear();

        // Sum the document frequency of every term across all text fields.
        Map<String, Long> frequencyMap = new HashMap<String, Long>();
        for (String textFieldName : this.textFieldNames) {
            Terms fieldTerms = this.leafReader.terms(textFieldName);
            if (fieldTerms == null) {
                // The reader returns null for a field that is not in the index.
                continue;
            }
            TermsEnum termsEnum = fieldTerms.iterator();
            BytesRef term;
            while ((term = termsEnum.next()) != null) {
                String termText = term.utf8ToString();
                long frequency = termsEnum.docFreq();
                Long lastfreq = frequencyMap.get(termText);
                if (lastfreq != null) {
                    frequency += lastfreq.longValue();
                }
                frequencyMap.put(termText, frequency);
            }
        }

        // Reserve an (empty) slot for every term that is frequent enough.
        for (Map.Entry<String, Long> entry : frequencyMap.entrySet()) {
            if (entry.getValue().longValue() > (long) minTermOccurrenceInCache) {
                this.termCClassHitCache.put(entry.getKey(), new ConcurrentHashMap<BytesRef, Integer>());
            }
        }

        // Collect the classes; the enum may reuse its BytesRef, hence the deep copy.
        Terms classTerms = MultiFields.getTerms(this.leafReader, this.classFieldName);
        if (classTerms != null) {
            TermsEnum classEnum = classTerms.iterator();
            BytesRef classTerm;
            while ((classTerm = classEnum.next()) != null) {
                this.cclasses.add(BytesRef.deepCopyOf(classTerm));
            }
        }

        // This sum does not depend on the class, so compute it once rather than per class.
        double avgNumberOfUniqueTerms = 0.0;
        if (!this.cclasses.isEmpty()) {
            for (String textFieldName : this.textFieldNames) {
                Terms fieldTerms = MultiFields.getTerms(this.leafReader, textFieldName);
                if (fieldTerms == null) {
                    continue;
                }
                long numPostings = fieldTerms.getSumDocFreq();
                avgNumberOfUniqueTerms += (double) numPostings / (double) fieldTerms.getDocCount();
            }
        }
        for (BytesRef cclass : this.cclasses) {
            int docsWithC = this.leafReader.docFreq(new Term(this.classFieldName, cclass));
            this.classTermFreq.put(cclass, avgNumberOfUniqueTerms * (double) docsWithC);
        }
    }
}

