/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification.utils;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedThreadFactory;

/**
 * Utility class that evaluates a {@link Classifier} against the documents of an index and
 * collects the outcome in a {@link ConfusionMatrix}.
 */
public class ConfusionMatrixGenerator {

    /** Maximum time, in milliseconds, a single {@code assignClass} call may take before it is abandoned. */
    private static final long CLASSIFICATION_TIMEOUT_MILLIS = 5000L;

    private ConfusionMatrixGenerator() {
    }

    /**
     * Classifies the text of every document matched in {@code classFieldName} and counts, for each
     * correct class, how often each class was assigned.
     *
     * <p>Each classification runs on a single background thread and is abandoned after
     * {@value #CLASSIFICATION_TIMEOUT_MILLIS} ms. An abandoned classification is charged the full
     * timeout against the time budget and does not count as an evaluated document. Documents without
     * a class value, without text, or for which the classifier returns no assigned class are skipped.
     *
     * @param reader the reader over the documents to evaluate
     * @param classifier the classifier to evaluate
     * @param classFieldName the stored field holding the correct class(es) of each document
     * @param textFieldName the stored field holding the text to classify
     * @param timeoutMilliseconds budget for the cumulative classification time, checked before each
     *     document; {@code 0} or a negative value means no limit
     * @return the resulting confusion matrix
     * @throws IOException if reading documents from {@code reader} fails
     * @throws RuntimeException wrapping the {@link ExecutionException} if the classifier fails, or the
     *     {@link InterruptedException} if the calling thread is interrupted (its interrupt flag is restored)
     */
    public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName, String textFieldName, long timeoutMilliseconds) throws IOException {
        ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
        try {
            Map<String, Map<String, Long>> counts = new HashMap<>();
            IndexSearcher indexSearcher = new IndexSearcher(reader);
            TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
            double time = 0d;
            int counter = 0;
            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
                if (timeoutMilliseconds > 0L && time >= timeoutMilliseconds) {
                    break;
                }
                Document doc = reader.document(scoreDoc.doc);
                String[] correctAnswers = doc.getValues(classFieldName);
                if (correctAnswers == null || correctAnswers.length == 0) {
                    continue;
                }
                // Sorted so that the binarySearch below is valid.
                Arrays.sort(correctAnswers);
                String text = doc.get(textFieldName);
                if (text == null) {
                    continue;
                }

                long start = System.currentTimeMillis();
                Future<ClassificationResult<T>> future = executorService.submit(() -> classifier.assignClass(text));
                ClassificationResult<T> result;
                try {
                    result = future.get(CLASSIFICATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
                } catch (TimeoutException timeoutException) {
                    // Interrupt the abandoned task so that, if the classifier honours interruption, it
                    // stops occupying the single worker thread needed by the following documents.
                    future.cancel(true);
                    time += CLASSIFICATION_TIMEOUT_MILLIS;
                    continue;
                } catch (InterruptedException interruptedException) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(interruptedException);
                } catch (ExecutionException executionException) {
                    throw new RuntimeException(executionException);
                }
                time += System.currentTimeMillis() - start;

                if (result == null) {
                    continue;
                }
                T assignedClass = result.getAssignedClass();
                if (assignedClass == null) {
                    continue;
                }
                counter++;
                String classified = assignedClass instanceof BytesRef
                        ? ((BytesRef) assignedClass).utf8ToString()
                        : assignedClass.toString();
                // A multi-valued document is credited to the matching correct class when the prediction
                // is one of them; otherwise to its first (lowest sorted) correct class.
                String correctAnswer = Arrays.binarySearch(correctAnswers, classified) >= 0 ? classified : correctAnswers[0];
                counts.computeIfAbsent(correctAnswer, key -> new HashMap<>()).merge(classified, 1L, Long::sum);
            }
            // Avoid NaN/Infinity when no document was successfully classified.
            double avgClassificationTime = counter > 0 ? time / counter : 0d;
            return new ConfusionMatrix(counts, avgClassificationTime, counter);
        } finally {
            // No task result is needed past this point; interrupt anything still running so a stuck
            // classification does not keep the worker thread alive.
            executorService.shutdownNow();
        }
    }

    /**
     * Confusion matrix stored in linearized form: for each correct class, a map from assigned class
     * to the number of evaluated documents with that (correct, assigned) pair.
     */
    public static class ConfusionMatrix {
        private final Map<String, Map<String, Long>> linearizedMatrix;
        private final double avgClassificationTime;
        private final int numberOfEvaluatedDocs;
        /** Lazily computed accuracy; {@code -1} means "not computed yet". */
        private double accuracy = -1.0;

        private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
            this.linearizedMatrix = linearizedMatrix;
            this.avgClassificationTime = avgClassificationTime;
            this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
        }

        /**
         * Returns a read-only view of the matrix: correct class → (assigned class → count).
         * Both the outer map and the per-class maps are unmodifiable, so callers cannot alter the
         * counts this matrix (and its cached accuracy) are based on.
         *
         * @return the unmodifiable linearized matrix
         */
        public Map<String, Map<String, Long>> getLinearizedMatrix() {
            Map<String, Map<String, Long>> view = new HashMap<>();
            for (Map.Entry<String, Map<String, Long>> row : this.linearizedMatrix.entrySet()) {
                view.put(row.getKey(), Collections.unmodifiableMap(row.getValue()));
            }
            return Collections.unmodifiableMap(view);
        }

        /**
         * Returns the precision for a class: documents correctly assigned {@code klass} divided by all
         * documents assigned {@code klass}.
         *
         * @param klass the class
         * @return the precision, or {@code 0} if {@code klass} is not a correct class in the matrix or
         *     was never correctly assigned
         */
        public double getPrecision(String klass) {
            Map<String, Long> classifications = this.linearizedMatrix.get(klass);
            if (classifications == null) {
                return 0d;
            }
            double tp = classifications.getOrDefault(klass, 0L);
            double den = 0d; // tp + fp: column total for klass across every correct class
            for (Map<String, Long> row : this.linearizedMatrix.values()) {
                Long count = row.get(klass);
                if (count != null) {
                    den += count;
                }
            }
            return tp > 0d ? tp / den : 0d;
        }

        /**
         * Returns the recall for a class: documents of {@code klass} assigned {@code klass} divided by
         * all evaluated documents of {@code klass}.
         *
         * @param klass the class
         * @return the recall, or {@code 0} if {@code klass} has no evaluated documents
         */
        public double getRecall(String klass) {
            Map<String, Long> classifications = this.linearizedMatrix.get(klass);
            double tp = 0d;
            double fn = 0d;
            if (classifications != null) {
                for (Map.Entry<String, Long> entry : classifications.entrySet()) {
                    if (klass.equals(entry.getKey())) {
                        tp += entry.getValue();
                    } else {
                        fn += entry.getValue();
                    }
                }
            }
            return tp + fn > 0d ? tp / (tp + fn) : 0d;
        }

        /**
         * Returns the F1 measure (harmonic mean of precision and recall) for a class.
         *
         * @param klass the class
         * @return the F1 measure, or {@code 0} if precision or recall is {@code 0}
         */
        public double getF1Measure(String klass) {
            return harmonicMean(this.getPrecision(klass), this.getRecall(klass));
        }

        /**
         * Returns the F1 measure computed from the macro-averaged precision and recall.
         *
         * @return the F1 measure, or {@code 0} if precision or recall is {@code 0}
         */
        public double getF1Measure() {
            return harmonicMean(this.getPrecision(), this.getRecall());
        }

        /** Harmonic mean of precision and recall; {@code 0} unless both are strictly positive. */
        private static double harmonicMean(double precision, double recall) {
            return precision > 0d && recall > 0d ? 2d * precision * recall / (precision + recall) : 0d;
        }

        /**
         * Returns the accuracy of the classifier, computed once and cached.
         *
         * <p>NOTE(review): {@code tn} counts, for each class, the correct-class rows that never
         * received that class (one per row, not per document), so this is not the textbook
         * correct/total ratio; the formula is kept unchanged for compatibility.
         *
         * @return the accuracy, or {@code 0} for an empty matrix
         */
        public double getAccuracy() {
            if (this.accuracy == -1.0) {
                double tp = 0d;
                double tn = 0d;
                double tfp = 0d; // tp + fp
                double fn = 0d;
                for (Map.Entry<String, Map<String, Long>> classification : this.linearizedMatrix.entrySet()) {
                    String klass = classification.getKey();
                    for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
                        if (klass.equals(entry.getKey())) {
                            tp += entry.getValue();
                        } else {
                            fn += entry.getValue();
                        }
                    }
                    for (Map<String, Long> row : this.linearizedMatrix.values()) {
                        Long count = row.get(klass);
                        if (count != null) {
                            tfp += count;
                        } else {
                            tn += 1d;
                        }
                    }
                }
                double den = tfp + fn + tn;
                this.accuracy = den > 0d ? (tp + tn) / den : 0d;
            }
            return this.accuracy;
        }

        /**
         * Returns the macro-averaged precision over all correct classes in the matrix.
         *
         * @return the average precision, or {@code 0} for an empty matrix
         */
        public double getPrecision() {
            if (this.linearizedMatrix.isEmpty()) {
                return 0d;
            }
            double p = 0d;
            for (String klass : this.linearizedMatrix.keySet()) {
                p += this.getPrecision(klass);
            }
            return p / this.linearizedMatrix.size();
        }

        /**
         * Returns the macro-averaged recall over all correct classes in the matrix.
         *
         * @return the average recall, or {@code 0} for an empty matrix
         */
        public double getRecall() {
            if (this.linearizedMatrix.isEmpty()) {
                return 0d;
            }
            double r = 0d;
            for (String klass : this.linearizedMatrix.keySet()) {
                r += this.getRecall(klass);
            }
            return r / this.linearizedMatrix.size();
        }

        @Override
        public String toString() {
            return "ConfusionMatrix{linearizedMatrix=" + this.linearizedMatrix + ", avgClassificationTime=" + this.avgClassificationTime + ", numberOfEvaluatedDocs=" + this.numberOfEvaluatedDocs + "}";
        }

        /**
         * Returns the average classification time in milliseconds: total time (including the full
         * timeout for each abandoned classification) divided by the number of evaluated documents.
         *
         * @return the average time, or {@code 0} if no document was evaluated
         */
        public double getAvgClassificationTime() {
            return this.avgClassificationTime;
        }

        /**
         * Returns the number of documents for which the classifier assigned a class.
         *
         * @return the number of evaluated documents
         */
        public int getNumberOfEvaluatedDocs() {
            return this.numberOfEvaluatedDocs;
        }
    }
}

