/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.Closeable;
import java.io.IOException;
import java.util.LinkedList;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;

/**
 * A simple naive Bayes {@link Classifier} that works directly on top of an index.
 *
 * <p>Classes are the terms of the configured class field. For an input document, every
 * class is scored as {@code prior(class) * likelihood(tokens | class)} and the class with
 * the highest score is returned. All statistics are read from the index at classification
 * time; {@code train} only stores the configuration and counts the documents that carry
 * a class.
 */
public class SimpleNaiveBayesClassifier
implements Classifier<BytesRef> {
    /** Reader over the training index; {@code null} until one of the {@code train} methods is called. */
    private AtomicReader atomicReader;
    /** Names of the fields whose text is used for tokenization and word counting. */
    private String[] textFieldNames;
    /** Name of the field whose terms are the candidate classes. */
    private String classFieldName;
    /** Number of documents having a value in the class field, computed by {@code train}. */
    private int docsWithClassSize;
    /** Analyzer used to tokenize input documents. */
    private Analyzer analyzer;
    /** Searcher over {@link #atomicReader}, used for hit counting. */
    private IndexSearcher indexSearcher;
    /** Optional filter query added as a MUST clause to counting queries; may be {@code null}. */
    private Query query;

    /**
     * Trains on a single text field by delegating to the multi-field variant.
     *
     * @param atomicReader reader over the training index
     * @param textFieldName name of the field holding the text
     * @param classFieldName name of the field holding the class
     * @param analyzer analyzer used to tokenize input documents
     * @param query optional query restricting the training documents, may be {@code null}
     * @throws IOException if reading index statistics fails
     */
    @Override
    public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query);
    }

    /**
     * Trains on a single text field without any filter query.
     *
     * @param atomicReader reader over the training index
     * @param textFieldName name of the field holding the text
     * @param classFieldName name of the field holding the class
     * @param analyzer analyzer used to tokenize input documents
     * @throws IOException if reading index statistics fails
     */
    @Override
    public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.train(atomicReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * Stores the configuration and counts the documents that carry a class.
     *
     * @param atomicReader reader over the training index
     * @param textFieldNames names of the fields holding the text
     * @param classFieldName name of the field holding the class
     * @param analyzer analyzer used to tokenize input documents
     * @param query optional query restricting the training documents, may be {@code null}
     * @throws IOException if reading index statistics fails
     */
    @Override
    public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        this.atomicReader = atomicReader;
        this.indexSearcher = new IndexSearcher(this.atomicReader);
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        // The query must be assigned before counting: countDocsWithClass() reads it in its
        // fallback path, and would otherwise see null or a query left from a previous train call.
        this.query = query;
        this.docsWithClassSize = this.countDocsWithClass();
    }

    /**
     * Counts the documents that have at least one term in the class field.
     *
     * <p>Uses the doc count reported by the field's {@link Terms} when available. That
     * statistic does not take {@link #query} into account. If the field has no terms or
     * the statistic is reported as {@code -1}, the count falls back to a wildcard search
     * on the class field, combined with {@link #query} when one is set.
     *
     * @return the number of documents with a class
     * @throws IOException if the index cannot be read
     */
    private int countDocsWithClass() throws IOException {
        Terms classTerms = MultiFields.getTerms(this.atomicReader, this.classFieldName);
        // A null Terms (no indexed terms for the field) is handled like a missing statistic.
        int docCount = classTerms == null ? -1 : classTerms.getDocCount();
        if (docCount == -1) {
            TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
            BooleanQuery q = new BooleanQuery();
            q.add(new BooleanClause(new WildcardQuery(new Term(this.classFieldName, "*")), BooleanClause.Occur.MUST));
            if (this.query != null) {
                q.add(this.query, BooleanClause.Occur.MUST);
            }
            this.indexSearcher.search(q, totalHitCountCollector);
            docCount = totalHitCountCollector.getTotalHits();
        }
        return docCount;
    }

    /**
     * Tokenizes the given text once per configured text field, using the analysis chain
     * of that field, and concatenates the resulting tokens in field order.
     *
     * @param doc the text to tokenize
     * @return all produced tokens
     * @throws IOException if the token stream fails
     */
    private String[] tokenizeDoc(String doc) throws IOException {
        LinkedList<String> result = new LinkedList<String>();
        for (String textFieldName : this.textFieldNames) {
            TokenStream tokenStream = this.analyzer.tokenStream(textFieldName, doc);
            try {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    result.add(charTermAttribute.toString());
                }
                tokenStream.end();
            } finally {
                // Always release the stream; exceptions thrown by close() are suppressed.
                IOUtils.closeWhileHandlingException(new Closeable[]{tokenStream});
            }
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Assigns a class to the given text.
     *
     * <p>If no class scores above {@code 0.0} (for example when the class field has no
     * terms), the result carries an empty {@link BytesRef} and a score of {@code 0.0}.
     *
     * @param inputDocument the text to classify
     * @return the best scoring class together with its score
     * @throws IOException if the classifier has not been trained or the index cannot be read
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
        if (this.atomicReader == null) {
            throw new IOException("You must first call Classifier#train");
        }
        double max = 0.0;
        BytesRef foundClass = new BytesRef();
        Terms terms = MultiFields.getTerms(this.atomicReader, this.classFieldName);
        // No terms in the class field means there is no candidate class to score.
        if (terms != null) {
            TermsEnum termsEnum = terms.iterator(null);
            String[] tokenizedDoc = this.tokenizeDoc(inputDocument);
            BytesRef next;
            while ((next = termsEnum.next()) != null) {
                double clVal = this.calculatePrior(next) * this.calculateLikelihood(tokenizedDoc, next);
                if (clVal > max) {
                    max = clVal;
                    // The enum may reuse the returned BytesRef, so keep a private copy.
                    foundClass = BytesRef.deepCopyOf(next);
                }
            }
        }
        return new ClassificationResult<BytesRef>(foundClass, max);
    }

    /**
     * Computes the product over all tokens of {@code (hits + 1) / denominator}, where
     * {@code hits} is the number of documents of class {@code c} containing the token.
     *
     * <p>NOTE: this is a plain product of values below one; for long inputs it can
     * underflow to {@code 0.0}, in which case the class can never beat the initial
     * maximum in {@link #assignClass(String)}.
     *
     * @param tokenizedDoc tokens of the document being classified
     * @param c the candidate class
     * @return the likelihood of the tokens given the class
     * @throws IOException if the index cannot be read
     */
    private double calculateLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
        if (tokenizedDoc.length == 0) {
            return 1.0;
        }
        // The denominator depends only on the class, so compute it once instead of per token.
        double den = this.getTextTermFreqForClass(c) + (double)this.docsWithClassSize;
        double result = 1.0;
        for (String word : tokenizedDoc) {
            int hits = this.getWordFreqForClass(word, c);
            // Add-one smoothing so that unseen words do not zero the product.
            double num = hits + 1;
            result *= num / den;
        }
        return result;
    }

    /**
     * Estimates the amount of text terms for a class as the summed per-field average of
     * postings per document, multiplied by the number of documents of the class.
     *
     * @param c the class
     * @return the estimate described above
     * @throws IOException if the index cannot be read
     */
    private double getTextTermFreqForClass(BytesRef c) throws IOException {
        double avgNumberOfUniqueTerms = 0.0;
        for (String textFieldName : this.textFieldNames) {
            Terms terms = MultiFields.getTerms(this.atomicReader, textFieldName);
            if (terms == null) {
                // A field without indexed terms has no postings and contributes nothing.
                continue;
            }
            long numPostings = terms.getSumDocFreq();
            avgNumberOfUniqueTerms += (double)numPostings / (double)terms.getDocCount();
        }
        int docsWithC = this.atomicReader.docFreq(new Term(this.classFieldName, c));
        return avgNumberOfUniqueTerms * (double)docsWithC;
    }

    /**
     * Counts the documents of class {@code c} that contain {@code word} in at least one
     * of the text fields, restricted by {@link #query} when one is set.
     *
     * @param word the token to look for
     * @param c the class
     * @return the number of matching documents
     * @throws IOException if the search fails
     */
    private int getWordFreqForClass(String word, BytesRef c) throws IOException {
        BooleanQuery subQuery = new BooleanQuery();
        for (String textFieldName : this.textFieldNames) {
            subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
        }
        BooleanQuery booleanQuery = new BooleanQuery();
        booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
        booleanQuery.add(new BooleanClause(new TermQuery(new Term(this.classFieldName, c)), BooleanClause.Occur.MUST));
        if (this.query != null) {
            booleanQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
        this.indexSearcher.search(booleanQuery, totalHitCountCollector);
        return totalHitCountCollector.getTotalHits();
    }

    /**
     * Computes the prior of a class as its document frequency divided by the number of
     * documents that carry a class.
     *
     * @param currentClass the class
     * @return the prior of the class
     * @throws IOException if the index cannot be read
     */
    private double calculatePrior(BytesRef currentClass) throws IOException {
        return (double)this.docCount(currentClass) / (double)this.docsWithClassSize;
    }

    /**
     * Returns the document frequency of the given class term in the class field.
     *
     * @param countedClass the class
     * @return the document frequency reported by the reader
     * @throws IOException if the index cannot be read
     */
    private int docCount(BytesRef countedClass) throws IOException {
        return this.atomicReader.docFreq(new Term(this.classFieldName, countedClass));
    }
}

