/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Static helpers for evaluation: scalar classification metrics computed from confusion-matrix
 * counts, and utilities that flatten rank-3 time-series arrays (and their optional rank-2 masks)
 * to rank 2 so that per-time-step predictions can be evaluated like ordinary rows.
 */
public class EvaluationUtils {

    /**
     * Computes precision, {@code tp / (tp + fp)}.
     *
     * @param tpCount  number of true positives
     * @param fpCount  number of false positives
     * @param edgeCase value returned when both {@code tpCount} and {@code fpCount} are zero
     * @return the precision, or {@code edgeCase} when both counts are zero
     */
    public static double precision(long tpCount, long fpCount, double edgeCase) {
        if (tpCount == 0L && fpCount == 0L) {
            return edgeCase;
        }
        return (double) tpCount / (double) (tpCount + fpCount);
    }

    /**
     * Computes recall (true positive rate), {@code tp / (tp + fn)}.
     *
     * @param tpCount  number of true positives
     * @param fnCount  number of false negatives
     * @param edgeCase value returned when both {@code tpCount} and {@code fnCount} are zero
     * @return the recall, or {@code edgeCase} when both counts are zero
     */
    public static double recall(long tpCount, long fnCount, double edgeCase) {
        if (tpCount == 0L && fnCount == 0L) {
            return edgeCase;
        }
        return (double) tpCount / (double) (tpCount + fnCount);
    }

    /**
     * Computes the false positive rate, {@code fp / (fp + tn)}.
     *
     * @param fpCount  number of false positives
     * @param tnCount  number of true negatives
     * @param edgeCase value returned when both {@code fpCount} and {@code tnCount} are zero
     * @return the false positive rate, or {@code edgeCase} when both counts are zero
     */
    public static double falsePositiveRate(long fpCount, long tnCount, double edgeCase) {
        if (fpCount == 0L && tnCount == 0L) {
            return edgeCase;
        }
        return (double) fpCount / (double) (fpCount + tnCount);
    }

    /**
     * Computes the false negative rate, {@code fn / (fn + tp)}.
     *
     * @param fnCount  number of false negatives
     * @param tpCount  number of true positives
     * @param edgeCase value returned when both {@code fnCount} and {@code tpCount} are zero
     * @return the false negative rate, or {@code edgeCase} when both counts are zero
     */
    public static double falseNegativeRate(long fnCount, long tpCount, double edgeCase) {
        if (fnCount == 0L && tpCount == 0L) {
            return edgeCase;
        }
        return (double) fnCount / (double) (fnCount + tpCount);
    }

    /**
     * Computes the F-beta score directly from confusion-matrix counts.
     *
     * <p>An undefined precision or recall (0/0) is treated as {@code NaN}. Combined with
     * {@link #fBeta(double, double, double)}, for non-negative counts this gives {@code 0.0} when
     * {@code tp == 0} and at least one of {@code fp}, {@code fn} is positive, and {@code NaN} when
     * all three counts are zero.
     *
     * @param beta weighting of recall relative to precision
     * @param tp   number of true positives
     * @param fp   number of false positives
     * @param fn   number of false negatives
     * @return the F-beta score
     */
    public static double fBeta(double beta, long tp, long fp, long fn) {
        // NaN edge case reproduces the plain 0/0 division this method has always performed.
        double prec = precision(tp, fp, Double.NaN);
        double rec = recall(tp, fn, Double.NaN);
        return fBeta(beta, prec, rec);
    }

    /**
     * Computes the F-beta score,
     * {@code (1 + beta^2) * precision * recall / (beta^2 * precision + recall)}.
     *
     * @param beta      weighting of recall relative to precision
     * @param precision precision value
     * @param recall    recall value
     * @return the F-beta score, or {@code 0.0} if either precision or recall is exactly zero
     */
    public static double fBeta(double beta, double precision, double recall) {
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        double betaSq = beta * beta;
        double numerator = (1.0 + betaSq) * precision * recall;
        double denominator = betaSq * precision + recall;
        return numerator / denominator;
    }

    /**
     * Computes the G-measure, the geometric mean of precision and recall.
     *
     * @param precision precision value
     * @param recall    recall value
     * @return {@code sqrt(precision * recall)}
     */
    public static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    /**
     * Computes the Matthews correlation coefficient,
     * {@code (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))}.
     *
     * <p>For non-negative counts, a zero in any of the four sums makes the numerator zero too, so
     * the result is {@code NaN} (0/0) in that case.
     *
     * @param tp number of true positives
     * @param fp number of false positives
     * @param fn number of false negatives
     * @param tn number of true negatives
     * @return the Matthews correlation coefficient
     */
    public static double matthewsCorrelation(long tp, long fp, long fn, long tn) {
        double numerator = (double) tp * (double) tn - (double) fp * (double) fn;
        // Sum in double (not long) so every factor is computed the same way and cannot overflow.
        double denominator = Math.sqrt(
                ((double) tp + (double) fp)
                        * ((double) tp + (double) fn)
                        * ((double) tn + (double) fp)
                        * ((double) tn + (double) fn));
        return numerator / denominator;
    }

    /**
     * Flattens a rank-3 array of shape {@code [a, b, c]} (presumably
     * {@code [miniBatch, size, timeSeriesLength]}) into a rank-2 array of shape {@code [a*c, b]}.
     *
     * <p>In the general case the result is {@code permute(0, 2, 1)} reshaped in 'f' order, so
     * element {@code (i, k, t)} of the input ends up in row {@code i + a*t}. When {@code a == 1}
     * or {@code c == 1}, a tensor-along-dimension of the input is returned instead, which may share
     * data with the input.
     *
     * @param labels rank-3 array to flatten
     * @return the rank-2 view/copy described above
     * @throws IllegalArgumentException if {@code labels} is not rank 3
     */
    public static INDArray reshapeTimeSeriesTo2d(INDArray labels) {
        if (labels.rank() != 3) {
            throw new IllegalArgumentException("Invalid data: expected rank 3 time series array. Got array with shape "
                    + Arrays.toString(labels.shape()));
        }
        long[] labelsShape = labels.shape();
        if (labelsShape[0] == 1L) {
            // Size-1 first dimension: take the single [b, c] slice and transpose it to [c, b].
            return labels.tensorAlongDimension(0, 1, 2).permutei(1, 0);
        }
        if (labelsShape[2] == 1L) {
            // Size-1 last dimension: the single tensor along dimensions 1 and 0 is already 2d.
            return labels.tensorAlongDimension(0, 1, 0);
        }
        INDArray permuted = labels.permute(0, 2, 1);
        return permuted.reshape('f', labelsShape[0] * labelsShape[2], labelsShape[1]);
    }

    /**
     * Flattens rank-3 labels and predictions to rank 2 (see {@link #reshapeTimeSeriesTo2d}) and,
     * if a mask is given, keeps only the rows whose mask value is exactly {@code 1.0}.
     *
     * @param labels     rank-3 label array
     * @param predicted  rank-3 prediction array
     * @param outputMask rank-2 mask (see {@link #reshapeTimeSeriesMaskToVector}); may be {@code null}
     * @return a pair of (labels2d, predicted2d); all rows when {@code outputMask} is {@code null};
     *         {@code null} when a mask is given but no mask entry equals {@code 1.0}
     * @throws IllegalArgumentException if {@code labels} or {@code predicted} is not rank 3
     */
    public static Pair<INDArray, INDArray> extractNonMaskedTimeSteps(INDArray labels, INDArray predicted, INDArray outputMask) {
        if (labels.rank() != 3 || predicted.rank() != 3) {
            throw new IllegalArgumentException("Invalid data: expect rank 3 arrays. Got arrays with shapes labels=" + Arrays.toString(labels.shape()) + ", predictions=" + Arrays.toString(predicted.shape()));
        }
        // 'f'-ordered copies so the caller's arrays are never touched by the reshape/permute path.
        INDArray labels2d = reshapeTimeSeriesTo2d(labels.dup('f'));
        INDArray predicted2d = reshapeTimeSeriesTo2d(predicted.dup('f'));
        if (outputMask == null) {
            return new Pair<>(labels2d, predicted2d);
        }

        INDArray oneDMask = reshapeTimeSeriesMaskToVector(outputMask);
        // dup() presumably ensures data() holds exactly this array's elements, in order
        // (a view's underlying buffer may be offset or larger).
        float[] maskValues = oneDMask.dup().data().asFloat();
        int[] rowsToPull = new int[maskValues.length];
        int usedCount = 0;
        for (int i = 0; i < maskValues.length; ++i) {
            // Only entries exactly equal to 1.0 are treated as "present" time steps.
            if (maskValues[i] == 1.0f) {
                rowsToPull[usedCount++] = i;
            }
        }
        if (usedCount == 0) {
            return null;
        }
        rowsToPull = Arrays.copyOf(rowsToPull, usedCount);
        labels2d = Nd4j.pullRows(labels2d, 1, rowsToPull);
        predicted2d = Nd4j.pullRows(predicted2d, 1, rowsToPull);
        return new Pair<>(labels2d, predicted2d);
    }

    /**
     * Reshapes a rank-2 mask of shape {@code [a, c]} into a column vector of shape
     * {@code [a*c, 1]} using 'f' order, so entry {@code (i, t)} lands at row {@code i + a*t}.
     * This matches the row order produced by {@link #reshapeTimeSeriesTo2d} in its general case.
     *
     * @param timeSeriesMask rank-2 mask array
     * @return the mask as an {@code [a*c, 1]} column vector
     * @throws IllegalArgumentException if {@code timeSeriesMask} is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToVector(INDArray timeSeriesMask) {
        if (timeSeriesMask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2 (got shape "
                    + Arrays.toString(timeSeriesMask.shape()) + ")");
        }
        if (timeSeriesMask.ordering() != 'f') {
            timeSeriesMask = timeSeriesMask.dup('f');
        }
        return timeSeriesMask.reshape('f', timeSeriesMask.length(), 1L);
    }
}

