/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/**
 * A naive Bayes {@link Classifier} that derives all of its statistics from an index at
 * classification time.
 *
 * <p>For every term found in the class field, the score of an input text is the sum of a log
 * prior (documents carrying that class over documents carrying any class) and a log likelihood
 * (add-one smoothed count of documents of that class containing each token of the input).
 * The per-class scores are then normalized so that they sum to one.
 *
 * <p>One of the {@code train} methods must be called before any classification method;
 * otherwise an {@link IOException} is thrown.
 */
public class SimpleNaiveBayesClassifier
implements Classifier<BytesRef> {
    /** Reader over the index used as training data; {@code null} until {@code train} is called. */
    protected LeafReader leafReader;
    /** Names of the fields whose terms are used as text features. */
    protected String[] textFieldNames;
    /** Name of the field whose terms are the candidate classes. */
    protected String classFieldName;
    /** Analyzer used to tokenize the text passed to the classification methods. */
    protected Analyzer analyzer;
    /** Searcher created over {@link #leafReader} by {@code train}. */
    protected IndexSearcher indexSearcher;
    /** Optional extra query added as a mandatory clause to the counting queries; may be {@code null}. */
    protected Query query;

    /**
     * Trains on a single text field without any additional filter query.
     *
     * @param leafReader reader over the training index
     * @param textFieldName name of the field holding the text features
     * @param classFieldName name of the field holding the class labels
     * @param analyzer analyzer used to tokenize input text
     * @throws IOException declared by the delegate overload
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.train(leafReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * Trains on a single text field, optionally restricted by a filter query.
     *
     * @param leafReader reader over the training index
     * @param textFieldName name of the field holding the text features
     * @param classFieldName name of the field holding the class labels
     * @param analyzer analyzer used to tokenize input text
     * @param query extra mandatory clause for the counting queries, or {@code null} for none
     * @throws IOException declared by the delegate overload
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
    }

    /**
     * Stores the training configuration and creates the {@link IndexSearcher} used for counting.
     * No statistics are precomputed here; they are read from the index on every classification.
     *
     * @param leafReader reader over the training index
     * @param textFieldNames names of the fields holding the text features
     * @param classFieldName name of the field holding the class labels
     * @param analyzer analyzer used to tokenize input text
     * @param query extra mandatory clause for the counting queries, or {@code null} for none
     * @throws IOException declared by the overridden method
     */
    @Override
    public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.leafReader = leafReader;
        this.indexSearcher = new IndexSearcher(this.leafReader);
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        this.query = query;
    }

    /**
     * Returns the class with the highest normalized score for the given text.
     *
     * @param inputDocument the text to classify
     * @return the best scoring result, or {@code null} when no class could be scored
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
        ClassificationResult<BytesRef> best = null;
        double maxScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : this.assignClassNormalizedList(inputDocument)) {
            // Strict comparison: on ties the first candidate encountered wins.
            if (candidate.getScore() > maxScore) {
                best = candidate;
                maxScore = candidate.getScore();
            }
        }
        return best;
    }

    /**
     * Returns every class with its normalized score, sorted by the natural ordering of
     * {@link ClassificationResult}.
     *
     * @param text the text to classify
     * @return the sorted list of results; empty when no class could be scored
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        List<ClassificationResult<BytesRef>> doclist = this.assignClassNormalizedList(text);
        Collections.sort(doclist);
        return doclist;
    }

    /**
     * Returns at most {@code max} classes with their normalized scores, taken from the head of
     * the list sorted by the natural ordering of {@link ClassificationResult}.
     *
     * @param text the text to classify
     * @param max maximum number of results to return; values larger than the number of classes
     *        return all of them (a negative value is still rejected by {@link List#subList})
     * @return a view over the first {@code min(max, size)} sorted results
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        List<ClassificationResult<BytesRef>> doclist = this.assignClassNormalizedList(text);
        Collections.sort(doclist);
        // Clamp the upper bound so asking for more results than there are classes does not throw.
        return doclist.subList(0, Math.min(max, doclist.size()));
    }

    /**
     * Scores the input against every term of the class field and normalizes the scores so that
     * they sum to one.
     *
     * @param inputDocument the text to classify
     * @return a mutable list with one result per class; empty when the class field has no terms
     * @throws IOException if {@code train} has not been called or the index cannot be read
     */
    private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
        if (this.leafReader == null) {
            throw new IOException("You must first call Classifier#train");
        }
        List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
        Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
        if (terms == null) {
            // The class field is absent from the index: there is nothing to score.
            return returnList;
        }
        TermsEnum termsEnum = terms.iterator(null);
        String[] tokenizedDoc = this.tokenizeDoc(inputDocument);
        int docsWithClassSize = this.countDocsWithClass();

        List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
        BytesRef next;
        while ((next = termsEnum.next()) != null) {
            double clVal = this.calculateLogPrior(next, docsWithClassSize)
                    + this.calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
            // Copy the term: the reference handed out by the enum is not kept beyond this iteration.
            dataList.add(new ClassificationResult<BytesRef>(BytesRef.deepCopyOf(next), clVal));
        }

        if (!dataList.isEmpty()) {
            // NOTE(review): taking element 0 as the maximum relies on ClassificationResult's
            // natural ordering placing the highest score first; that ordering is defined
            // outside this file - confirm.
            Collections.sort(dataList);
            double smax = dataList.get(0).getScore();
            // Log-sum-exp: subtract the maximum before exponentiating to avoid underflow.
            double sumExp = 0.0;
            for (ClassificationResult<BytesRef> scored : dataList) {
                sumExp += Math.exp(scored.getScore() - smax);
            }
            double logNormalizer = smax + Math.log(sumExp);
            for (ClassificationResult<BytesRef> scored : dataList) {
                returnList.add(new ClassificationResult<BytesRef>(
                        scored.getAssignedClass(), Math.exp(scored.getScore() - logNormalizer)));
            }
        }
        return returnList;
    }

    /**
     * Counts the documents that have a value in the class field.
     *
     * <p>The count is read from the field's term statistics when available. When the statistic
     * is reported as {@code -1}, or the field has no terms at all, a wildcard query on the class
     * field (combined with {@link #query} when set) is executed instead.
     *
     * @return the number of documents with a class
     * @throws IOException if the index cannot be read
     */
    protected int countDocsWithClass() throws IOException {
        Terms terms = MultiFields.getTerms(this.leafReader, this.classFieldName);
        // A missing field is routed through the query fallback instead of dereferencing null.
        // NOTE(review): the statistics path does not take this.query into account, while the
        // fallback does - confirm this asymmetry is intended.
        int docCount = terms == null ? -1 : terms.getDocCount();
        if (docCount == -1) {
            TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
            BooleanQuery q = new BooleanQuery();
            q.add(new BooleanClause(new WildcardQuery(new Term(this.classFieldName, "*")), BooleanClause.Occur.MUST));
            if (this.query != null) {
                q.add(this.query, BooleanClause.Occur.MUST);
            }
            this.indexSearcher.search(q, totalHitCountCollector);
            docCount = totalHitCountCollector.getTotalHits();
        }
        return docCount;
    }

    /**
     * Tokenizes the text once per configured text field and concatenates the resulting tokens.
     *
     * @param doc the text to tokenize
     * @return all tokens produced for all text fields, in production order
     * @throws IOException if the analyzer fails
     */
    protected String[] tokenizeDoc(String doc) throws IOException {
        List<String> result = new ArrayList<>();
        for (String textFieldName : this.textFieldNames) {
            try (TokenStream tokenStream = this.analyzer.tokenStream(textFieldName, doc)) {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    result.add(charTermAttribute.toString());
                }
                tokenStream.end();
            }
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Sums, over all tokens, the log of the add-one smoothed probability of the token given
     * the class.
     *
     * @param tokenizedDoc tokens of the input text
     * @param c the class term
     * @param docsWithClassSize number of documents that have a class
     * @return the log likelihood; {@code 0.0} for an empty token array
     * @throws IOException if the index cannot be read
     */
    private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
        if (tokenizedDoc.length == 0) {
            return 0.0;
        }
        // The denominator depends only on the class, so it is computed once rather than per token.
        double den = this.getTextTermFreqForClass(c) + (double)docsWithClassSize;
        double result = 0.0;
        for (String word : tokenizedDoc) {
            int hits = this.getWordFreqForClass(word, c);
            double num = hits + 1;
            result += Math.log(num / den);
        }
        return result;
    }

    /**
     * Estimates the number of term occurrences for a class as the average number of postings
     * per document, summed over the text fields, multiplied by the number of documents of the
     * class.
     *
     * @param c the class term
     * @return the estimated term count for the class
     * @throws IOException if the index cannot be read
     */
    private double getTextTermFreqForClass(BytesRef c) throws IOException {
        double avgNumberOfUniqueTerms = 0.0;
        for (String textFieldName : this.textFieldNames) {
            Terms terms = MultiFields.getTerms(this.leafReader, textFieldName);
            if (terms == null) {
                // A text field without any indexed terms contributes nothing to the average.
                continue;
            }
            long numPostings = terms.getSumDocFreq();
            avgNumberOfUniqueTerms += (double)numPostings / (double)terms.getDocCount();
        }
        return avgNumberOfUniqueTerms * (double)this.docCount(c);
    }

    /**
     * Counts the documents of the given class that contain the word in at least one text field,
     * additionally restricted by {@link #query} when set.
     *
     * @param word the token to look up
     * @param c the class term
     * @return the number of matching documents
     * @throws IOException if the search fails
     */
    private int getWordFreqForClass(String word, BytesRef c) throws IOException {
        BooleanQuery subQuery = new BooleanQuery();
        for (String textFieldName : this.textFieldNames) {
            subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
        }
        BooleanQuery booleanQuery = new BooleanQuery();
        booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
        booleanQuery.add(new BooleanClause(new TermQuery(new Term(this.classFieldName, c)), BooleanClause.Occur.MUST));
        if (this.query != null) {
            booleanQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
        this.indexSearcher.search(booleanQuery, totalHitCountCollector);
        return totalHitCountCollector.getTotalHits();
    }

    /**
     * Computes the log prior of a class: log(documents of the class) minus
     * log(documents that have a class).
     *
     * @param currentClass the class term
     * @param docsWithClassSize number of documents that have a class
     * @return the log prior
     * @throws IOException if the index cannot be read
     */
    private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
        return Math.log(this.docCount(currentClass)) - Math.log(docsWithClassSize);
    }

    /**
     * Returns the document frequency of the given term in the class field.
     *
     * @param countedClass the class term
     * @return the number of documents containing that class term
     * @throws IOException if the index cannot be read
     */
    private int docCount(BytesRef countedClass) throws IOException {
        return this.leafReader.docFreq(new Term(this.classFieldName, countedClass));
    }
}

