/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.AUC2;
import hex.ConfusionMatrix;
import hex.CustomMetric;
import hex.GainsLift;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import java.util.Arrays;
import java.util.Optional;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.C8DVolatileChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;

/**
 * Metrics for a binomial (two-class) classifier: AUC and threshold-based confusion matrices
 * (through {@link AUC2}), logloss, log-likelihood/AIC, mean per-class error and optional
 * gains/lift.
 */
public class ModelMetricsBinomial
extends ModelMetricsSupervised {
    /** AUC / threshold summary. May be {@code null}; {@link #cm()} and {@link #toString()} guard against that. */
    public final AUC2 _auc;
    /** Logloss as supplied to the constructor (the metric builder passes the weighted mean). */
    public final double _logloss;
    /** Log-likelihood; {@code NaN} when it was not computed. */
    public final double _loglikelihood;
    /** Akaike information criterion; {@code NaN} when it was not computed. */
    public final double _aic;
    /** Mean per-class error of the default confusion matrix, or {@code NaN} when there is none. */
    public double _mean_per_class_error;
    /** Gains/lift table; {@code null} when it was not computed. */
    public final GainsLift _gainsLift;

    /**
     * Creates binomial metrics with explicit log-likelihood and AIC values.
     *
     * @param model         the model the metrics belong to (may be {@code null})
     * @param frame         the frame the metrics were computed on
     * @param nobs          number of observations scored
     * @param mse           mean squared error
     * @param domain        the two class labels
     * @param sigma         weighted sigma of the response
     * @param auc           AUC summary (may be {@code null})
     * @param logloss       logloss
     * @param loglikelihood log-likelihood, or {@code NaN}
     * @param aic           AIC, or {@code NaN}
     * @param gainsLift     gains/lift table (may be {@code null})
     * @param customMetric  optional custom metric
     */
    public ModelMetricsBinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, AUC2 auc, double logloss, double loglikelihood, double aic, GainsLift gainsLift, CustomMetric customMetric) {
        super(model, frame, nobs, mse, domain, sigma, customMetric);
        this._auc = auc;
        this._logloss = logloss;
        this._loglikelihood = loglikelihood;
        this._aic = aic;
        this._gainsLift = gainsLift;
        // cm() allocates a new ConfusionMatrix on every call, so build it only once here.
        // NOTE(review): cm() is overridable and is invoked from the constructor.
        ConfusionMatrix defaultCm = this.cm();
        this._mean_per_class_error = defaultCm == null ? Double.NaN : defaultCm.mean_per_class_error();
    }

    /**
     * Creates binomial metrics without log-likelihood/AIC (both set to {@code NaN}).
     */
    public ModelMetricsBinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, AUC2 auc, double logloss, GainsLift gainsLift, CustomMetric customMetric) {
        this(model, frame, nobs, mse, domain, sigma, auc, logloss, Double.NaN, Double.NaN, gainsLift, customMetric);
    }

    /**
     * Looks up the metrics stored for {@code model} on {@code frame}.
     *
     * @throws H2OIllegalArgumentException if the stored metrics are missing or not binomial
     */
    public static ModelMetricsBinomial getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsBinomial)) {
            throw new H2OIllegalArgumentException("Expected to find a Binomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsBinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + (mm == null ? null : mm.getClass()));
        }
        return (ModelMetricsBinomial)mm;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        if (this._auc != null) {
            sb.append(" AUC: " + (float)this._auc._auc + "\n");
            sb.append(" pr_auc: " + (float)this._auc.pr_auc() + "\n");
        }
        sb.append(" logloss: " + (float)this._logloss + "\n");
        sb.append(" loglikelihood: " + (float)this._loglikelihood + "\n");
        sb.append(" AIC: " + (float)this._aic + "\n");
        sb.append(" mean_per_class_error: " + (float)this._mean_per_class_error + "\n");
        sb.append(" default threshold: " + (this._auc == null ? 0.5 : (double)((float)this._auc.defaultThreshold())) + "\n");
        ConfusionMatrix defaultCm = this.cm();
        if (defaultCm != null) {
            sb.append(" CM: " + defaultCm.toASCII());
        }
        if (this._gainsLift != null) {
            sb.append(this._gainsLift);
        }
        return sb.toString();
    }

    /** @return the logloss */
    public double logloss() {
        return this._logloss;
    }

    /** @return the log-likelihood, or {@code NaN} if it was not computed */
    public double loglikelihood() {
        return this._loglikelihood;
    }

    /** @return the AIC, or {@code NaN} if it was not computed */
    public double aic() {
        return this._aic;
    }

    /** @return the mean per-class error of the default confusion matrix */
    public double mean_per_class_error() {
        return this._mean_per_class_error;
    }

    @Override
    public AUC2 auc_obj() {
        return this._auc;
    }

    /**
     * @return the confusion matrix at the AUC object's default threshold, or {@code null}
     *         when there is no AUC object or it yields no matrix
     */
    @Override
    public ConfusionMatrix cm() {
        if (this._auc == null) {
            return null;
        }
        double[][] cm = this._auc.defaultCM();
        return cm == null ? null : new ConfusionMatrix(cm, this._domain);
    }

    /**
     * @param criterion the threshold criterion used to pick the confusion matrix
     * @return the confusion matrix for {@code criterion}, or {@code null} when unavailable
     */
    public ConfusionMatrix cm(AUC2.ThresholdCriterion criterion) {
        if (this._auc == null) {
            return null;
        }
        double[][] cm = this._auc.cmByCriterion(criterion);
        return cm == null ? null : new ConfusionMatrix(cm, this._domain);
    }

    /** @return the gains/lift table, possibly {@code null} */
    public GainsLift gainsLift() {
        return this._gainsLift;
    }

    /** @return the ROC AUC (throws NPE when {@link #auc_obj()} is {@code null}) */
    public double auc() {
        return this.auc_obj()._auc;
    }

    /** @return the precision-recall AUC (throws NPE when {@link #auc_obj()} is {@code null}) */
    public double pr_auc() {
        return this.auc_obj()._pr_auc;
    }

    /** Alias of {@link #pr_auc()}. */
    public double aucpr() {
        return this.pr_auc();
    }

    /**
     * Lift of the first gains/lift group: its response rate divided by the average response rate.
     *
     * @return the lift, or {@code NaN} when no gains/lift table (or no group) is available
     */
    public double lift_top_group() {
        GainsLift gl = this.gainsLift();
        // gains/lift is optional (e.g. disabled via _gainslift_bins == 0); don't NPE in that case
        if (gl == null || gl.response_rates == null || gl.response_rates.length == 0) {
            return Double.NaN;
        }
        return gl.response_rates[0] / gl.avg_response_rate;
    }

    /**
     * Computes metrics from user-given class-1 probabilities and actual labels, using the
     * labels' own domain.
     *
     * @throws IllegalArgumentException if an argument is missing or invalid
     */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels) {
        // Let the validating overload report a missing actualLabels instead of NPE-ing here.
        return ModelMetricsBinomial.make(targetClassProbs, actualLabels, actualLabels == null ? null : actualLabels.domain());
    }

    /**
     * Computes unweighted metrics from user-given class-1 probabilities and actual labels.
     *
     * @throws IllegalArgumentException if an argument is missing or invalid
     */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels, String[] domain) {
        return ModelMetricsBinomial.make(targetClassProbs, actualLabels, null, domain);
    }

    /**
     * Computes binomial metrics from user-given class-1 probabilities and actual labels.
     *
     * @param targetClassProbs numeric Vec of probabilities in [0, 1] for the second domain class
     * @param actualLabels     actual labels; converted to categorical and adapted to {@code domain}
     * @param weights          optional per-row weights (may be {@code null})
     * @param domain           the two class labels; defaults to the categorical labels' domain
     * @return the computed metrics (no model attached)
     * @throws IllegalArgumentException if an input is missing, the probabilities are not numeric
     *         or outside [0, 1], or the domain does not have exactly two labels
     */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels, Vec weights, String[] domain) {
        // Validate before dereferencing: actualLabels used to be converted before this check ran.
        if (actualLabels == null || targetClassProbs == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for binomial metrics!");
        }
        Scope.enter();
        try {
            Vec labels = actualLabels.toCategoricalVec();
            if (labels == null) {
                throw new IllegalArgumentException("Missing actualLabels or predictedProbs for binomial metrics!");
            }
            if (domain == null) {
                domain = labels.domain();
            }
            if (!targetClassProbs.isNumeric()) {
                throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for binomial metrics.");
            }
            if (targetClassProbs.min() < 0.0 || targetClassProbs.max() > 1.0) {
                throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for binomial metrics.");
            }
            if (domain.length != 2) {
                throw new IllegalArgumentException("Domain must have 2 class labels, but is " + Arrays.toString(domain) + " for binomial metrics.");
            }
            labels = labels.adaptTo(domain);
            if (labels.cardinality() != 2) {
                throw new IllegalArgumentException("Adapted domain must have 2 class labels, but is " + Arrays.toString(labels.domain()) + " for binomial metrics.");
            }
            Frame fr = new Frame(targetClassProbs);
            fr.add("labels", labels);
            if (weights != null) {
                fr.add("weights", weights);
            }
            MetricBuilderBinomial mb = new BinomialMetrics(labels.domain()).doAll(fr)._mb;
            Frame preds = new Frame(targetClassProbs);
            // fr.vec("labels") is the adapted labels Vec; it is used as the gains/lift response.
            ModelMetricsBinomial mm = (ModelMetricsBinomial)mb.makeModelMetrics(null, fr, preds, fr.vec("labels"), fr.vec("weights"));
            mm._description = "Computed on user-given predictions and labels, using F1-optimal threshold: " + mm.auc_obj().defaultThreshold() + ".";
            // Remove the temporary labels only after the metrics (including gains/lift) are built from them.
            labels.remove();
            return mm;
        }
        finally {
            Scope.exit();
        }
    }

    /**
     * Accumulates binomial metrics row by row (squared error, logloss, log-likelihood and an
     * AUC builder) and turns them into a {@link ModelMetricsBinomial}.
     */
    public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Weighted sum of per-row logloss. */
        protected double _logloss;
        /** Sum of per-row likelihood contributions (only accumulated for generic models). */
        protected double _loglikelihood;
        /** Incremental AUC builder (constructed with 400). */
        protected AUC2.AUCBuilder _auc = new AUC2.AUCBuilder(400);

        public MetricBuilderBinomial(String[] domain) {
            super(2, domain);
        }

        /** @return the ROC AUC of the rows accumulated so far */
        public double auc() {
            return new AUC2(this._auc)._auc;
        }

        /** @return the precision-recall AUC of the rows accumulated so far */
        public double pr_auc() {
            return new AUC2(this._auc)._pr_auc;
        }

        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return this.perRow(ds, yact, 1.0, 0.0, m);
        }

        /**
         * Adds one row. {@code ds} is presumably {@code [label, P(class 0), P(class 1)]} (the
         * layout {@link BinomialMetrics} builds); {@code ds[iact + 1]} is read as the probability
         * of the actual class. Rows with a NaN actual, NaN prediction, or zero/NaN weight are
         * ignored, as are non-0/1 actuals for non-quasibinomial models.
         *
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w, double o, Model m) {
            if (Float.isNaN(yact[0]) || ArrayUtils.hasNaNs(ds) || w == 0.0 || Double.isNaN(w)) {
                return ds;
            }
            int iact = (int)yact[0];
            final boolean quasibinomial = m != null && ((Model.Parameters)m._parms)._distribution == DistributionFamily.quasibinomial;
            if (quasibinomial) {
                // quasibinomial actuals are raw values: map them to a 0/1 class index via the domain
                if (yact[0] != 0.0f) {
                    iact = this._domain[0].equals(String.valueOf((int)yact[0])) ? 0 : 1;
                }
                this._wY += w * (double)yact[0];
                this._wYY += w * (double)yact[0] * (double)yact[0];
                double err = (double)yact[0] - ds[iact + 1];
                this._sumsqe += w * err * err;
                this._logloss += -w * ((double)yact[0] * Math.log(Math.max(1.0E-15, ds[2])) + (double)(1.0f - yact[0]) * Math.log(Math.max(1.0E-15, ds[1])));
            } else {
                if (iact != 0 && iact != 1) {
                    return ds;
                }
                this._wY += w * (double)iact;
                this._wYY += w * (double)iact * (double)iact;
                double err = 1.0 - ds[iact + 1];
                this._sumsqe += w * err * err;
                this._logloss += w * MathUtils.logloss(err);
            }
            if (m != null && m.isGeneric()) {
                this._loglikelihood += m.likelihood(w, yact[0], ds);
            }
            ++this._count;
            this._wcount += w;
            assert (!Double.isNaN(this._sumsqe));
            this._auc.perRow(ds[2], iact, w);
            return ds;
        }

        @Override
        public void reduce(T mb) {
            super.reduce(mb);
            MetricBuilderBinomial<?> other = mb;
            this._logloss += other._logloss;
            this._loglikelihood += other._loglikelihood;
            this._auc.reduce(other._auc);
        }

        /**
         * Builds the metrics, locating the response (and weights) column so gains/lift can be
         * computed. Without a model, the last column of the frame is used as the response only
         * if it is categorical; otherwise gains/lift is skipped.
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame frameWithWeights, Frame preds) {
            Vec resp = null;
            Vec weight = null;
            if (this._wcount > 0.0 && preds != null) {
                Frame source = frameWithWeights == null ? f : frameWithWeights;
                if (m == null) {
                    // No model => no response-column name; fall back to a categorical last column.
                    Vec last = source.vec(f.numCols() - 1);
                    if (last.isCategorical()) {
                        resp = last;
                    }
                } else {
                    Model.Parameters parms = (Model.Parameters)m._parms;
                    resp = source.vec(parms._response_column);
                    if (resp != null) {
                        weight = source.vec(parms._weights_column);
                    }
                }
            }
            return this.makeModelMetrics(m, f, preds, resp, weight);
        }

        /** Builds the metrics, computing gains/lift when rows, predictions and a response are available. */
        private ModelMetrics makeModelMetrics(Model m, Frame f, Frame preds, Vec resp, Vec weight) {
            GainsLift gl = null;
            if (this._wcount > 0.0 && preds != null && resp != null) {
                gl = this.calculateGainsLift(m, preds, resp, weight).orElse(null);
            }
            return this.makeModelMetrics(m, f, gl);
        }

        /** Finalizes the accumulated sums into a {@link ModelMetricsBinomial} and attaches it to {@code m}, if any. */
        private ModelMetrics makeModelMetrics(Model m, Frame f, GainsLift gl) {
            AUC2 auc;
            double mse = Double.NaN;
            double loglikelihood = Double.NaN;
            double aic = Double.NaN;
            double logloss = Double.NaN;
            double sigma = Double.NaN;
            if (this._wcount > 0.0) {
                sigma = this.weightedSigma();
                mse = this._sumsqe / this._wcount;
                logloss = this._logloss / this._wcount;
                // Same test perRow uses to decide whether _loglikelihood was accumulated.
                if (m != null && m.isGeneric()) {
                    loglikelihood = -1.0 * this._loglikelihood;
                    aic = m.aic(loglikelihood);
                }
                auc = new AUC2(this._auc);
            } else {
                auc = new AUC2();
            }
            ModelMetricsBinomial mm = new ModelMetricsBinomial(m, f, this._count, mse, this._domain, sigma, auc, logloss, loglikelihood, aic, gl, this._customMetric);
            if (m != null) {
                m.addModelMetrics(mm);
            }
            return mm;
        }

        /**
         * Computes gains/lift on the last prediction column.
         *
         * @return empty when the model disables gains/lift ({@code _gainslift_bins == 0})
         * @throws IllegalArgumentException if the model's {@code _gainslift_bins} is below -1
         */
        private Optional<GainsLift> calculateGainsLift(Model m, Frame preds, Vec resp, Vec weights) {
            Model.Parameters parms = m == null ? null : (Model.Parameters)m._parms;
            if (parms != null) {
                if (parms._gainslift_bins < -1) {
                    throw new IllegalArgumentException("Number of G/L bins must be greater or equal than -1.");
                }
                if (parms._gainslift_bins == 0) {
                    return Optional.empty();
                }
            }
            GainsLift gl = new GainsLift(preds.lastVec(), resp, weights);
            if (parms != null) {
                // Remaining values are > 0 or exactly -1.
                gl._groups = parms._gainslift_bins;
            }
            gl.exec(m != null ? ((Model.Output)m._output)._job : null);
            return Optional.of(gl);
        }

        /** Caches a single double column (the last entry of the prediction distribution). */
        @Override
        public Frame makePredictionCache(Model m, Vec response) {
            return new Frame(response.makeVolatileDoubles(1));
        }

        @Override
        public void cachePrediction(double[] cdist, Chunk[] chks, int row, int cacheChunkIdx, Model m) {
            assert (cdist.length == 3);
            ((C8DVolatileChunk)chks[cacheChunkIdx]).getValues()[row] = cdist[cdist.length - 1];
        }

        @Override
        public String toString() {
            if (this._wcount == 0.0) {
                return "empty, no rows";
            }
            return "auc = " + MathUtils.roundToNDigits(this.auc(), 3) + ", logloss = " + this._logloss / this._wcount;
        }
    }

    /**
     * MRTask over a frame of {@code [class-1 probability, categorical labels, (weights)]} that
     * feeds each row into a {@link MetricBuilderBinomial}.
     */
    private static class BinomialMetrics
    extends MRTask<BinomialMetrics> {
        String[] domain;
        public MetricBuilderBinomial _mb;

        public BinomialMetrics(String[] domain) {
            this.domain = domain;
        }

        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderBinomial(this.domain);
            Chunk actuals = chks[1];
            Chunk weights = chks.length == 3 ? chks[2] : null;
            double[] ds = new double[3];
            float[] acts = new float[1];
            for (int i = 0; i < chks[0]._len; ++i) {
                ds[2] = chks[0].atd(i);   // user-given class-1 probability
                ds[1] = 1.0 - ds[2];      // implied class-0 probability
                ds[0] = GenModel.getPrediction(ds, null, ds, Double.NaN);
                acts[0] = (float)actuals.atd(i);
                double weight = weights != null ? weights.atd(i) : 1.0;
                this._mb.perRow(ds, acts, weight, 0.0, null);
            }
        }

        @Override
        public void reduce(BinomialMetrics mrt) {
            // Tolerate a side that never created a builder.
            if (mrt._mb == null) {
                return;
            }
            if (this._mb == null) {
                this._mb = mrt._mb;
                return;
            }
            this._mb.reduce(mrt._mb);
        }
    }
}

