/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.Builder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.Outputs;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;

/**
 * A perceptron-based {@link Classifier} that assigns a {@code Boolean} class to a text.
 *
 * <p>Every term of the text field has a non-negative integer weight, kept in an {@link FST}. A text
 * is classified as {@code true} when the sum of the weights of its analyzed tokens is greater than or
 * equal to the threshold. Training starts each term's weight at its total term frequency. For every
 * misclassified training document, training then adds (expected {@code true}) or subtracts (expected
 * {@code false}) that document's in-document term frequencies, clamping at zero.
 *
 * <p>Weight updates are accumulated in memory. The FST used for scoring is rebuilt when a
 * misclassified document's position in the training sequence is a multiple of {@code batchSize}, and
 * once more at the end of training if updates are still pending.
 *
 * <p>Training requires term vectors for the text field. The text and class fields must also be stored
 * so they can be read back from each matching document.
 */
public class BooleanPerceptronClassifier
implements Classifier<Boolean> {
    private Double threshold;
    private final int batchSize;
    private Terms textTerms;
    private Analyzer analyzer;
    private String textFieldName;
    private FST<Long> fst;

    /**
     * Creates a classifier with an explicit threshold and batch size.
     *
     * @param threshold the score at or above which a text is classified as {@code true}; {@code null}
     *     or {@code 0} means it is derived from the index when {@code train} is called
     * @param batchSize how often (in training documents) the scoring FST may be rebuilt; {@code null}
     *     means 1
     * @throws IllegalArgumentException if {@code batchSize} is less than 1
     */
    public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
        if (batchSize != null && batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1, got " + batchSize);
        }
        this.threshold = threshold;
        this.batchSize = batchSize == null ? 1 : batchSize;
    }

    /** Creates a classifier whose threshold is derived at training time, with a batch size of 1. */
    public BooleanPerceptronClassifier() {
        this(null, 1);
    }

    /**
     * Classifies {@code text} by summing the learned weights of its analyzed tokens.
     *
     * @param text the text to classify
     * @return {@code true} when the summed weight is at least the threshold; the score is the sum
     * @throws IOException if the classifier has not been trained, or if analysis fails
     */
    @Override
    public ClassificationResult<Boolean> assignClass(String text) throws IOException {
        if (this.textTerms == null || this.fst == null) {
            throw new IOException("You must first call Classifier#train");
        }
        long output = 0L;
        try (TokenStream tokenStream = this.analyzer.tokenStream(this.textFieldName, text)) {
            CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
            tokenStream.reset();
            while (tokenStream.incrementToken()) {
                Long weight = Util.get(this.fst, new BytesRef(charTermAttribute));
                // Tokens that never occurred in the training field have no weight.
                if (weight != null) {
                    output += weight;
                }
            }
            tokenStream.end();
        }
        return new ClassificationResult<Boolean>(output >= this.threshold, (double) output);
    }

    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.train(leafReader, textFieldName, classFieldName, analyzer, null);
    }

    /**
     * Trains the perceptron on every document that has a value in {@code classFieldName} and, if
     * given, also matches {@code query}.
     *
     * @throws IOException if the text field has no terms, the threshold cannot be derived, or a
     *     misclassified document has no term vector for the text field
     */
    @Override
    public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        Terms terms = MultiFields.getTerms(leafReader, textFieldName);
        if (terms == null) {
            throw new IOException("term vectors need to be available for field " + textFieldName);
        }
        // Resolve the threshold before assigning any other state, so a failure here leaves the
        // classifier as it was.
        if (this.threshold == null || this.threshold == 0.0) {
            this.threshold = computeDefaultThreshold(leafReader, textFieldName);
        }
        this.textTerms = terms;
        this.analyzer = analyzer;
        this.textFieldName = textFieldName;

        SortedMap<BytesRef, Long> weights = initialWeights(terms);
        this.updateFST(weights);

        IndexSearcher indexSearcher = new IndexSearcher(leafReader);
        BooleanQuery q = new BooleanQuery();
        q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
        if (query != null) {
            q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
        }

        int batchCount = 0;
        boolean pendingUpdates = false;
        for (ScoreDoc scoreDoc : indexSearcher.search(q, Integer.MAX_VALUE).scoreDocs) {
            Document doc = indexSearcher.doc(scoreDoc.doc);
            IndexableField textField = doc.getField(textFieldName);
            IndexableField classField = doc.getField(classFieldName);
            String text = textField == null ? null : textField.stringValue();
            if (text == null || classField == null) {
                // Indexed but not stored (or not a string) for this document: nothing to learn from.
                continue;
            }
            boolean assignedClass = this.assignClass(text).getAssignedClass();
            boolean correctClass = Boolean.parseBoolean(classField.stringValue());
            if (assignedClass != correctClass) {
                // Raise weights when the document should have been true, lower them otherwise.
                long modifier = correctClass ? 1L : -1L;
                this.updateWeights(leafReader, scoreDoc.doc, weights, modifier);
                if (batchCount % this.batchSize == 0) {
                    this.updateFST(weights);
                    pendingUpdates = false;
                } else {
                    pendingUpdates = true;
                }
            }
            ++batchCount;
        }
        // Without this flush, updates made after the last rebuild would be silently dropped.
        if (pendingUpdates) {
            this.updateFST(weights);
        }
        weights.clear();
    }

    @Override
    public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
        throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
    }

    /** Derives a threshold as half of the field's sum of document frequencies. */
    private static double computeDefaultThreshold(LeafReader leafReader, String textFieldName) throws IOException {
        long sumDocFreq = leafReader.getSumDocFreq(textFieldName);
        if (sumDocFreq == -1L) {
            throw new IOException("threshold cannot be assigned since term vectors for field " + textFieldName + " do not exist");
        }
        return (double) sumDocFreq / 2.0;
    }

    /** Seeds one weight per field term with its total term frequency (clamped to be non-negative). */
    private static SortedMap<BytesRef, Long> initialWeights(Terms terms) throws IOException {
        // BytesRef keys sort in unsigned byte order, which is the order the FST Builder requires;
        // String keys would sort by UTF-16 code units, which differs for supplementary characters.
        SortedMap<BytesRef, Long> weights = new TreeMap<BytesRef, Long>();
        TermsEnum termsEnum = terms.iterator(null);
        for (BytesRef term = termsEnum.next(); term != null; term = termsEnum.next()) {
            // totalTermFreq() may be -1 when unavailable; FST outputs must be non-negative.
            weights.put(BytesRef.deepCopyOf(term), Math.max(0L, termsEnum.totalTermFreq()));
        }
        return weights;
    }

    /**
     * Adds {@code modifier * freq} to the pending weight of every term in the document's term vector.
     * The result is clamped at zero.
     */
    private void updateWeights(LeafReader leafReader, int docId, SortedMap<BytesRef, Long> weights, long modifier) throws IOException {
        Terms terms = leafReader.getTermVector(docId, this.textFieldName);
        if (terms == null) {
            throw new IOException("term vectors must be stored for field " + this.textFieldName);
        }
        TermsEnum termsEnum = terms.iterator(null);
        for (BytesRef term = termsEnum.next(); term != null; term = termsEnum.next()) {
            // Read from the pending map, not the FST, so several updates within one batch accumulate.
            Long previous = weights.get(term);
            long base = previous == null ? 0L : previous;
            long updated = Math.max(0L, base + modifier * termsEnum.totalTermFreq());
            weights.put(BytesRef.deepCopyOf(term), updated);
        }
    }

    /** Rebuilds the scoring FST from the weights, which must be in unsigned byte order. */
    private void updateFST(SortedMap<BytesRef, Long> weights) throws IOException {
        Builder<Long> fstBuilder = new Builder<Long>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton());
        IntsRefBuilder scratchInts = new IntsRefBuilder();
        for (Map.Entry<BytesRef, Long> entry : weights.entrySet()) {
            fstBuilder.add(Util.toIntsRef(entry.getKey(), scratchInts), entry.getValue());
        }
        this.fst = fstBuilder.finish();
    }

    /** Ranked class lists are not supported by this classifier; always returns {@code null}. */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        return null;
    }

    /** Ranked class lists are not supported by this classifier; always returns {@code null}. */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        return null;
    }
}

