/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.classification.SoftClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.validation.Bag;
import smile.validation.ClassificationMetrics;
import smile.validation.ClassificationValidations;
import smile.validation.metric.AUC;
import smile.validation.metric.Accuracy;
import smile.validation.metric.ConfusionMatrix;
import smile.validation.metric.CrossEntropy;
import smile.validation.metric.Error;
import smile.validation.metric.FScore;
import smile.validation.metric.LogLoss;
import smile.validation.metric.MatthewsCorrelation;
import smile.validation.metric.Precision;
import smile.validation.metric.Sensitivity;
import smile.validation.metric.Specificity;

/**
 * The outcome of validating one classification model on a test set.
 *
 * <p>An instance keeps the model, the ground-truth labels, the predicted
 * labels, the optional posterior probabilities, the confusion matrix and a
 * {@link ClassificationMetrics} summary. The static {@code of} factories
 * train a model, time the fit and scoring phases, and build the result.
 *
 * @param <M> the type of the validated model.
 */
public class ClassificationValidation<M> implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The model that was validated. */
    public final M model;
    /** The true class labels of the validation data. */
    public final int[] truth;
    /** The class labels predicted by the model. */
    public final int[] prediction;
    /** The posterior probabilities of the predictions, or {@code null} if none were supplied. */
    public final double[][] posteriori;
    /** The confusion matrix built from {@link #truth} and {@link #prediction}. */
    public final ConfusionMatrix confusion;
    /** The metrics computed from the truth, the predictions and, when available, the posteriors. */
    public final ClassificationMetrics metrics;

    /**
     * Creates a validation result without posterior probabilities.
     *
     * @param model the validated model.
     * @param truth the true class labels.
     * @param prediction the predicted class labels.
     * @param fitTime the time spent fitting the model (the factories in this class pass milliseconds).
     * @param scoreTime the time spent scoring the test data (the factories in this class pass milliseconds).
     */
    public ClassificationValidation(M model, int[] truth, int[] prediction, double fitTime, double scoreTime) {
        this(model, truth, prediction, null, fitTime, scoreTime);
    }

    /**
     * Creates a validation result.
     *
     * <p>The set of metrics depends on the number of distinct labels found in
     * {@code truth} and on whether {@code posteriori} is available:
     * <ul>
     * <li>two distinct labels: error, accuracy, sensitivity, specificity,
     *     precision, F1 and Matthews correlation; plus AUC and log loss
     *     (computed from column 1 of {@code posteriori}) when posteriors are given;</li>
     * <li>otherwise: error and accuracy; plus cross entropy when posteriors are given.</li>
     * </ul>
     *
     * @param model the validated model.
     * @param truth the true class labels.
     * @param prediction the predicted class labels.
     * @param posteriori the posterior probabilities, one row per sample; may be {@code null}.
     * @param fitTime the time spent fitting the model (the factories in this class pass milliseconds).
     * @param scoreTime the time spent scoring the test data (the factories in this class pass milliseconds).
     */
    public ClassificationValidation(M model, int[] truth, int[] prediction, double[][] posteriori, double fitTime, double scoreTime) {
        this.model = model;
        this.truth = truth;
        this.prediction = prediction;
        this.posteriori = posteriori;
        this.confusion = ConfusionMatrix.of(truth, prediction);

        // The number of classes is inferred from the labels present in the truth array.
        int k = MathEx.unique(truth).length;
        if (k == 2) {
            if (posteriori == null) {
                this.metrics = new ClassificationMetrics(
                        fitTime, scoreTime, truth.length,
                        Error.of(truth, prediction),
                        Accuracy.of(truth, prediction),
                        Sensitivity.of(truth, prediction),
                        Specificity.of(truth, prediction),
                        Precision.of(truth, prediction),
                        FScore.F1.score(truth, prediction),
                        MatthewsCorrelation.of(truth, prediction));
            } else {
                // Column 1 of each posterior row is used as the score for AUC and log loss.
                double[] probability = Arrays.stream(posteriori).mapToDouble(p -> p[1]).toArray();
                this.metrics = new ClassificationMetrics(
                        fitTime, scoreTime, truth.length,
                        Error.of(truth, prediction),
                        Accuracy.of(truth, prediction),
                        Sensitivity.of(truth, prediction),
                        Specificity.of(truth, prediction),
                        Precision.of(truth, prediction),
                        FScore.F1.score(truth, prediction),
                        MatthewsCorrelation.of(truth, prediction),
                        AUC.of(truth, probability),
                        LogLoss.of(truth, probability));
            }
        } else if (posteriori == null) {
            this.metrics = new ClassificationMetrics(
                    fitTime, scoreTime, truth.length,
                    Error.of(truth, prediction),
                    Accuracy.of(truth, prediction));
        } else {
            this.metrics = new ClassificationMetrics(
                    fitTime, scoreTime, truth.length,
                    Error.of(truth, prediction),
                    Accuracy.of(truth, prediction),
                    CrossEntropy.of(truth, posteriori));
        }
    }

    /**
     * Returns the string form of {@link #metrics}.
     */
    @Override
    public String toString() {
        return this.metrics.toString();
    }

    /**
     * Trains a model on {@code (x, y)} and validates it on {@code (testx, testy)}.
     *
     * <p>If the trained model is a {@link SoftClassifier}, posterior
     * probabilities are collected as well; the posterior matrix has one
     * column per distinct label in {@code y}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param testx the test samples.
     * @param testy the test labels.
     * @param trainer the function that fits a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation result, with fit and score times in milliseconds.
     */
    public static <T, M extends Classifier<T>> ClassificationValidation<M> of(T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[], int[], M> trainer) {
        int k = MathEx.unique(y).length;

        long start = System.nanoTime();
        M model = trainer.apply(x, y);
        // Nanoseconds to milliseconds.
        double fitTime = (System.nanoTime() - start) / 1E6;

        if (model instanceof SoftClassifier) {
            // The instanceof test above guards the cast; the type argument is erased at runtime.
            @SuppressWarnings("unchecked")
            SoftClassifier<T> soft = (SoftClassifier<T>) model;

            start = System.nanoTime();
            double[][] posteriori = new double[testx.length][k];
            int[] prediction = soft.predict(testx, posteriori);
            double scoreTime = (System.nanoTime() - start) / 1E6;
            return new ClassificationValidation<>(model, testy, prediction, posteriori, fitTime, scoreTime);
        }

        start = System.nanoTime();
        int[] prediction = model.predict(testx);
        double scoreTime = (System.nanoTime() - start) / 1E6;
        return new ClassificationValidation<>(model, testy, prediction, fitTime, scoreTime);
    }

    /**
     * Trains and validates one model per bag: {@code bag.samples} selects
     * the training rows and {@code bag.oob} selects the test rows.
     *
     * @param bags the train/test index splits.
     * @param x the samples.
     * @param y the labels.
     * @param trainer the function that fits a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results of all rounds, in bag order.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> of(Bag[] bags, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        ArrayList<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            T[] trainx = MathEx.slice(x, bag.samples);
            int[] trainy = MathEx.slice(y, bag.samples);
            T[] testx = MathEx.slice(x, bag.oob);
            int[] testy = MathEx.slice(y, bag.oob);
            rounds.add(ClassificationValidation.of(trainx, trainy, testx, testy, trainer));
        }
        return new ClassificationValidations<>(rounds);
    }

    /**
     * Trains a data frame model on {@code train} and validates it on {@code test}.
     *
     * <p>The response values are extracted from both data frames with
     * {@code formula}. Test rows are scored one at a time. If the trained
     * model is a {@link SoftClassifier}, posterior probabilities are
     * collected as well; the posterior matrix has one column per distinct
     * label in the training response.
     *
     * @param formula the model formula.
     * @param train the training data.
     * @param test the test data.
     * @param trainer the function that fits a model from a formula and a data frame.
     * @param <M> the model type.
     * @return the validation result, with fit and score times in milliseconds.
     */
    public static <M extends DataFrameClassifier> ClassificationValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula, DataFrame, M> trainer) {
        int[] y = formula.y(train).toIntArray();
        int[] testy = formula.y(test).toIntArray();
        int k = MathEx.unique(y).length;

        long start = System.nanoTime();
        M model = trainer.apply(formula, train);
        // Nanoseconds to milliseconds.
        double fitTime = (System.nanoTime() - start) / 1E6;

        int n = test.nrows();
        if (model instanceof SoftClassifier) {
            // The instanceof test above guards the cast; the type argument is erased at runtime.
            @SuppressWarnings("unchecked")
            SoftClassifier<Tuple> soft = (SoftClassifier<Tuple>) model;

            start = System.nanoTime();
            int[] prediction = new int[n];
            double[][] posteriori = new double[n][k];
            for (int i = 0; i < n; i++) {
                prediction[i] = soft.predict(test.get(i), posteriori[i]);
            }
            double scoreTime = (System.nanoTime() - start) / 1E6;
            return new ClassificationValidation<>(model, testy, prediction, posteriori, fitTime, scoreTime);
        }

        start = System.nanoTime();
        int[] prediction = new int[n];
        for (int i = 0; i < n; i++) {
            prediction[i] = model.predict(test.get(i));
        }
        double scoreTime = (System.nanoTime() - start) / 1E6;
        return new ClassificationValidation<>(model, testy, prediction, fitTime, scoreTime);
    }

    /**
     * Trains and validates one data frame model per bag: {@code bag.samples}
     * selects the training rows and {@code bag.oob} selects the test rows.
     *
     * @param bags the train/test index splits.
     * @param formula the model formula.
     * @param data the data frame to split.
     * @param trainer the function that fits a model from a formula and a data frame.
     * @param <M> the model type.
     * @return the validation results of all rounds, in bag order.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        ArrayList<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            rounds.add(ClassificationValidation.of(formula, data.of(bag.samples), data.of(bag.oob), trainer));
        }
        return new ClassificationValidations<>(rounds);
    }
}

