package com.aliasi.stats;

import com.aliasi.util.Math;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/aliasi/stats/Statistics.class */
/**
 * Static utility methods for statistical computations.
 *
 * <p>The methods cover the following areas:
 * <ul>
 *   <li>Divergences between discrete distributions and between Dirichlet distributions.</li>
 *   <li>Random permutations and sampling.</li>
 *   <li>Chi-squared independence statistics.</li>
 *   <li>Simple linear and logistic regression.</li>
 *   <li>Descriptive statistics such as mean, variance and correlation.</li>
 * </ul>
 *
 * <p>Methods that validate their arguments report malformed input with an
 * {@link IllegalArgumentException}.
 *
 * <p>Implementation note: this file imports {@code com.aliasi.util.Math}, which shadows
 * {@code java.lang.Math}. JDK math functions are therefore referenced by their fully
 * qualified name ({@code java.lang.Math.log}, {@code java.lang.Math.sqrt}, ...).
 */
public class Statistics {

    /** Not instantiable; all members are static. */
    private Statistics() {
    }

    /**
     * Returns the Kullback-Leibler divergence {@code KL(Dir(xs) || Dir(ys))} between two
     * Dirichlet distributions. The result is computed with natural-log gamma, i.e. in nats.
     *
     * @param xs parameters of the first Dirichlet; every entry positive and finite
     * @param ys parameters of the second Dirichlet; same length as {@code xs}
     * @return the divergence of the second distribution from the first
     * @throws IllegalArgumentException if the arrays differ in length or contain a
     *     non-positive, infinite or NaN entry
     */
    public static double klDivergenceDirichlet(double[] xs, double[] ys) {
        verifyDivergenceDirichletArgs(xs, ys);
        double sumXs = sum(xs);
        double digammaSumXs = Math.digamma(sumXs);
        double divergence = logGamma(sumXs) - logGamma(sum(ys));
        for (int i = 0; i < xs.length; i++) {
            divergence += (logGamma(ys[i]) - logGamma(xs[i]))
                    + ((xs[i] - ys[i]) * (Math.digamma(xs[i]) - digammaSumXs));
        }
        return divergence;
    }

    /**
     * Checks that two Dirichlet parameter arrays have equal length and contain only
     * positive, finite values.
     *
     * @throws IllegalArgumentException describing the first violation found
     */
    static void verifyDivergenceDirichletArgs(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            throw new IllegalArgumentException("Parameter arrays must be the same length. Found xs.length=" + xs.length + " ys.length=" + ys.length);
        }
        verifyPositiveFinite(xs, "xs");
        verifyPositiveFinite(ys, "ys");
    }

    /** Throws if any entry of {@code params} is non-positive, infinite or NaN. */
    private static void verifyPositiveFinite(double[] params, String arrayName) {
        for (int i = 0; i < params.length; i++) {
            double value = params[i];
            if (value <= 0.0d || Double.isInfinite(value) || Double.isNaN(value)) {
                throw new IllegalArgumentException("All parameters must be positive and finite. Found " + arrayName + "[" + i + "]=" + value);
            }
        }
    }

    /**
     * Returns the symmetrized Dirichlet KL divergence, the average of
     * {@code klDivergenceDirichlet(xs, ys)} and {@code klDivergenceDirichlet(ys, xs)}.
     *
     * @throws IllegalArgumentException under the same conditions as
     *     {@link #klDivergenceDirichlet(double[], double[])}
     */
    public static double symmetrizedKlDivergenceDirichlet(double[] xs, double[] ys) {
        return (klDivergenceDirichlet(xs, ys) + klDivergenceDirichlet(ys, xs)) / 2.0d;
    }

    /** Returns the natural log of the gamma function, converted from the base-2 value. */
    static double logGamma(double x) {
        // log2(Gamma(x)) / log2(e) == ln(Gamma(x))
        return Math.log2Gamma(x) / Math.log2(java.lang.Math.E);
    }

    /** Returns the plain left-to-right sum of the given values (0.0 for an empty array). */
    static double sum(double[] values) {
        double total = 0.0d;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Returns the Kullback-Leibler divergence {@code KL(p || q)} between two discrete
     * distributions, in bits (base-2 logarithm).
     *
     * <p>Terms with {@code p[i] == 0} contribute 0 (the usual convention). If some
     * {@code q[i] == 0} while {@code p[i] > 0}, the divergence is not finite.
     *
     * @param p the first distribution; entries in [0, 1]
     * @param q the second distribution; entries in [0, 1], same length as {@code p}
     * @return the divergence of {@code q} from {@code p}
     * @throws IllegalArgumentException if lengths differ or an entry is outside [0, 1] or NaN
     */
    public static double klDivergence(double[] p, double[] q) {
        verifyDivergenceArgs(p, q);
        double divergence = 0.0d;
        for (int i = 0; i < p.length; i++) {
            // Skip 0 * log(0 / q) == 0; equal entries contribute p * log2(1) == 0.
            if (p[i] > 0.0d && p[i] != q[i]) {
                divergence += p[i] * Math.log2(p[i] / q[i]);
            }
        }
        return divergence;
    }

    /**
     * Checks that two arrays have equal length and contain only values in [0, 1].
     * Does not check that either array sums to 1.
     *
     * @throws IllegalArgumentException describing the first violation found
     */
    static void verifyDivergenceArgs(double[] p, double[] q) {
        if (p.length != q.length) {
            throw new IllegalArgumentException("Input distributions must have same length. Found p.length=" + p.length + " q.length=" + q.length);
        }
        for (int i = 0; i < p.length; i++) {
            if (p[i] < 0.0d || p[i] > 1.0d || Double.isNaN(p[i]) || Double.isInfinite(p[i])) {
                throw new IllegalArgumentException("p[i] must be between 0.0 and 1.0 inclusive. found p[" + i + "]=" + p[i]);
            }
            if (q[i] < 0.0d || q[i] > 1.0d || Double.isNaN(q[i]) || Double.isInfinite(q[i])) {
                throw new IllegalArgumentException("q[i] must be between 0.0 and 1.0 inclusive. found q[" + i + "] =" + q[i]);
            }
        }
    }

    /**
     * Returns the symmetrized KL divergence, the average of {@code klDivergence(p, q)} and
     * {@code klDivergence(q, p)}, in bits.
     *
     * @throws IllegalArgumentException under the same conditions as
     *     {@link #klDivergence(double[], double[])}
     */
    public static double symmetrizedKlDivergence(double[] p, double[] q) {
        // klDivergence(p, q) validates (p, q) first, so no separate check is needed.
        return (klDivergence(p, q) + klDivergence(q, p)) / 2.0d;
    }

    /**
     * Returns the Jensen-Shannon divergence between two discrete distributions, in bits.
     * This is the average KL divergence of each distribution from their pointwise mean.
     *
     * @throws IllegalArgumentException under the same conditions as
     *     {@link #klDivergence(double[], double[])}
     */
    public static double jsDivergence(double[] p, double[] q) {
        verifyDivergenceArgs(p, q);
        double[] midpoint = new double[p.length];
        for (int i = 0; i < p.length; i++) {
            midpoint[i] = (p[i] + q[i]) / 2.0d;
        }
        return (klDivergence(p, midpoint) + klDivergence(q, midpoint)) / 2.0d;
    }

    /**
     * Returns a uniformly random permutation of {@code 0, ..., length - 1}, drawn from a
     * freshly constructed {@link Random}.
     *
     * @see #permutation(int, Random)
     */
    public static int[] permutation(int length) {
        return permutation(length, new Random());
    }

    /**
     * Returns a uniformly random permutation of {@code 0, ..., length - 1}. Uses the
     * Fisher-Yates shuffle driven by {@code random}.
     *
     * @param length number of elements to permute
     * @param random source of randomness
     * @return a new array containing each of {@code 0..length-1} exactly once
     * @throws NegativeArraySizeException if {@code length} is negative
     */
    public static int[] permutation(int length, Random random) {
        int[] perm = new int[length];
        for (int i = 0; i < length; i++) {
            perm[i] = i;
        }
        // Swap position i with a uniform position in [0, i], *including* i itself.
        // Drawing from [0, i) instead (Sattolo's algorithm) only produces single-cycle
        // permutations and never leaves an element in place, which is not uniform.
        for (int i = length - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = perm[j];
            perm[j] = perm[i];
            perm[i] = tmp;
        }
        return perm;
    }

    /**
     * Returns Pearson's chi-squared statistic for independence of two binary events. The
     * input is the four cell counts of their 2x2 contingency table.
     *
     * <p>The result is NaN if the total is zero or any expected cell count is zero.
     *
     * @param both count of cases where both events occur
     * @param oneOnly count of cases where only the first event occurs
     * @param twoOnly count of cases where only the second event occurs
     * @param neither count of cases where neither event occurs
     * @return the sum over cells of (observed - expected)^2 / expected
     * @throws IllegalArgumentException if any count is negative, infinite or NaN
     */
    public static double chiSquaredIndependence(double both, double oneOnly, double twoOnly, double neither) {
        assertNonNegative("both", both);
        assertNonNegative("oneOnly", oneOnly);
        assertNonNegative("twoOnly", twoOnly);
        assertNonNegative("neither", neither);
        double total = both + oneOnly + twoOnly + neither;
        double probOne = (both + oneOnly) / total;
        double probTwo = (both + twoOnly) / total;
        return csTerm(both, total * probOne * probTwo)
                + csTerm(oneOnly, total * probOne * (1.0d - probTwo))
                + csTerm(twoOnly, total * (1.0d - probOne) * probTwo)
                + csTerm(neither, total * (1.0d - probOne) * (1.0d - probTwo));
    }

    /**
     * Fits the least-squares line {@code y = beta0 + beta1 * x}.
     *
     * @param xs the x values
     * @param ys the y values, parallel to {@code xs}
     * @return a two-element array {@code {beta0, beta1}} (intercept, slope)
     * @throws IllegalArgumentException if the arrays differ in length, have fewer than two
     *     elements, or the x values are all equal (zero denominator for the slope)
     */
    public static double[] linearRegression(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            throw new IllegalArgumentException("Require parallel arrays of x and y values. Found xs.length=" + xs.length + " ys.length=" + ys.length);
        }
        if (xs.length < 2) {
            throw new IllegalArgumentException("Require arrays of length >= 2. Found xs.length=" + xs.length);
        }
        double n = xs.length;
        double sumX = 0.0d;
        double sumY = 0.0d;
        double sumXY = 0.0d;
        double sumXX = 0.0d;
        for (int i = 0; i < xs.length; i++) {
            double x = xs[i];
            double y = ys[i];
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }
        double denominator = (n * sumXX) - (sumX * sumX);
        if (denominator == 0.0d) {
            throw new IllegalArgumentException("Ill formed input. Denominator for beta1 is zero. Most likely cause is fewer than 2 distinct inputs.");
        }
        double slope = ((n * sumXY) - (sumX * sumY)) / denominator;
        return new double[]{(sumY - (slope * sumX)) / n, slope};
    }

    /**
     * Fits the curve {@code y = maxValue / (1 + exp(beta0 + beta1 * x))}.
     *
     * <p>Each y is first mapped to {@code z = ln((maxValue - y) / y)}. A least-squares
     * line is then fit to the (x, z) pairs with {@link #linearRegression(double[], double[])}.
     *
     * @param xs the x values
     * @param ys the y values, parallel to {@code xs}; each strictly between 0 and
     *     {@code maxValue}
     * @param maxValue the asymptotic maximum of the curve; positive and finite
     * @return a two-element array {@code {beta0, beta1}}
     * @throws IllegalArgumentException if {@code maxValue} is not positive and finite, if
     *     some y is not strictly inside (0, maxValue) (its logit would be infinite or NaN),
     *     or under any condition of {@link #linearRegression(double[], double[])}
     */
    public static double[] logisticRegression(double[] xs, double[] ys, double maxValue) {
        if (maxValue <= 0.0d || Double.isInfinite(maxValue) || Double.isNaN(maxValue)) {
            throw new IllegalArgumentException("Require finite max value > 0. Found maxValue=" + maxValue);
        }
        double[] logits = new double[ys.length];
        for (int i = 0; i < ys.length; i++) {
            double y = ys[i];
            // Written as !(in range) so that NaN is rejected too.
            if (!(y > 0.0d && y < maxValue)) {
                throw new IllegalArgumentException("Require 0 < ys[i] < maxValue. Found maxValue=" + maxValue + " ys[" + i + "]=" + y);
            }
            logits[i] = java.lang.Math.log((maxValue - y) / y);
        }
        return linearRegression(xs, logits);
    }

    /**
     * Returns Pearson's chi-squared statistic for independence of rows and columns in a
     * contingency table of counts.
     *
     * <p>The result is NaN if the table total is zero or any row or column total is zero,
     * because some expected count is then zero.
     *
     * @param matrix a rectangular table with at least two rows and two columns of finite,
     *     non-negative counts
     * @return the sum over cells of (observed - expected)^2 / expected, where
     *     expected = rowTotal * colTotal / total
     * @throws IllegalArgumentException if the table is too small, not rectangular, or
     *     contains a negative, infinite or NaN entry
     */
    public static double chiSquaredIndependence(double[][] matrix) {
        int numRows = matrix.length;
        if (numRows < 2) {
            throw new IllegalArgumentException("Require at least two rows. Found numRows=" + numRows);
        }
        int numCols = matrix[0].length;
        if (numCols < 2) {
            throw new IllegalArgumentException("Require at least two cols. Found numCols=" + numCols);
        }
        double[] rowTotals = new double[numRows];
        double[] colTotals = new double[numCols];
        double total = 0.0d;
        for (int row = 0; row < numRows; row++) {
            if (matrix[row].length != numCols) {
                throw new IllegalArgumentException("Matrix must be rectangular. Row 0 length=" + numCols + " Row " + row + " length=" + matrix[row].length);
            }
            for (int col = 0; col < numCols; col++) {
                double count = matrix[row][col];
                if (Double.isInfinite(count) || count < 0.0d || Double.isNaN(count)) {
                    throw new IllegalArgumentException("Values must be finite non-negative. Found matrix[" + row + "][" + col + "]=" + count);
                }
                rowTotals[row] += count;
                colTotals[col] += count;
                total += count;
            }
        }
        double statistic = 0.0d;
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                statistic += csTerm(matrix[row][col], (rowTotals[row] * colTotals[col]) / total);
            }
        }
        return statistic;
    }

    /**
     * Returns a new array holding the given non-negative ratios scaled to sum to 1.
     * The input array is not modified.
     *
     * @param probabilityRatios finite, non-negative values with a positive sum
     * @return the normalized copy
     * @throws IllegalArgumentException if an entry is negative, infinite or NaN, or the
     *     entries sum to zero (including an empty array)
     */
    public static double[] normalize(double[] probabilityRatios) {
        for (int i = 0; i < probabilityRatios.length; i++) {
            double ratio = probabilityRatios[i];
            if (ratio < 0.0d || Double.isInfinite(ratio) || Double.isNaN(ratio)) {
                throw new IllegalArgumentException("Probabilities must be finite non-negative. Found probabilityRatios[" + i + "]=" + ratio);
            }
        }
        double total = Math.sum(probabilityRatios);
        if (total <= 0.0d) {
            throw new IllegalArgumentException("Ratios must sum to number greater than zero. Found sum=" + total);
        }
        double[] normalized = new double[probabilityRatios.length];
        for (int i = 0; i < probabilityRatios.length; i++) {
            normalized[i] = probabilityRatios[i] / total;
        }
        return normalized;
    }

    /**
     * Returns the kappa statistic {@code (agreement - expectedAgreement) / (1 - expectedAgreement)}.
     * No validation is performed. An {@code expectedAgreement} of 1 divides by zero.
     */
    public static double kappa(double agreement, double expectedAgreement) {
        return (agreement - expectedAgreement) / (1.0d - expectedAgreement);
    }

    /** Returns the arithmetic mean of the values (NaN for an empty array). */
    public static double mean(double[] values) {
        return Math.sum(values) / values.length;
    }

    /**
     * Returns the population variance of the values. This is the mean squared deviation
     * from the mean, dividing by n rather than n - 1. The result is NaN for an empty array.
     */
    public static double variance(double[] values) {
        return variance(values, mean(values));
    }

    /** Returns the square root of the population {@link #variance(double[])}. */
    public static double standardDeviation(double[] values) {
        return java.lang.Math.sqrt(variance(values));
    }

    /**
     * Returns the magnitude of Pearson's correlation between two parallel arrays,
     * {@code |Sxy| / sqrt(Sxx * Syy)}.
     *
     * <p>The sign of the correlation is not preserved. Perfectly anti-correlated data
     * yields 1.0, not -1.0. The result is NaN when either array has zero variance or the
     * arrays are empty.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double correlation(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            throw new IllegalArgumentException("xs and ys must be the same length. Found xs.length=" + xs.length + " ys.length=" + ys.length);
        }
        double meanX = mean(xs);
        double meanY = mean(ys);
        double sxx = sumSquareDiffs(xs, meanX);
        double syy = sumSquareDiffs(ys, meanY);
        double sxy = sumSquareDiffs(xs, ys, meanX, meanY);
        return java.lang.Math.sqrt((sxy * sxy) / (sxx * syy));
    }

    /**
     * Samples an index from a discrete distribution given by cumulative weights.
     *
     * <p>{@code cumulativeProbRatios[i]} should be the (unnormalized) total weight of
     * outcomes {@code 0..i}, so the array is non-decreasing. A uniform value {@code u} in
     * {@code [0, last)} is drawn, and the smallest index {@code i} with
     * {@code cumulativeProbRatios[i] >= u} is returned by binary search.
     *
     * @param cumulativeProbRatios non-decreasing cumulative weights; must be non-empty
     * @param random source of randomness
     * @return the sampled index
     * @throws IllegalArgumentException if the array is empty
     */
    public static int sample(double[] cumulativeProbRatios, Random random) {
        if (cumulativeProbRatios.length == 0) {
            throw new IllegalArgumentException("Require at least one cumulative ratio. Found length=0");
        }
        int lo = 0;
        int hi = cumulativeProbRatios.length - 1;
        double target = random.nextDouble() * cumulativeProbRatios[hi];
        while (lo < hi) {
            int mid = (lo + hi) >>> 1; // overflow-safe midpoint
            if (target > cumulativeProbRatios[mid]) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /**
     * Returns the base-2 log density of {@code xs} under a symmetric Dirichlet with
     * concentration {@code alpha}:
     * {@code log2Gamma(K * alpha) - K * log2Gamma(alpha) + sum_i (alpha - 1) * log2(xs[i])}.
     *
     * <p>Entries are only checked to lie in [0, 1]; they are not checked to sum to 1.
     *
     * @throws IllegalArgumentException if {@code alpha} is not positive and finite or an
     *     entry of {@code xs} is outside [0, 1] or NaN
     */
    public static double dirichletLog2Prob(double alpha, double[] xs) {
        verifyAlpha(alpha);
        verifyDistro(xs);
        int dimension = xs.length;
        double logProb = Math.log2Gamma(dimension * alpha) - (dimension * Math.log2Gamma(alpha));
        double exponent = alpha - 1.0d;
        for (double x : xs) {
            logProb += exponent * Math.log2(x);
        }
        return logProb;
    }

    /**
     * Returns the base-2 log density of {@code xs} under a Dirichlet with parameters
     * {@code alphas}:
     * {@code log2Gamma(sum alphas) - sum_i log2Gamma(alphas[i]) + sum_i (alphas[i] - 1) * log2(xs[i])}.
     *
     * <p>Entries are only checked to lie in [0, 1]; they are not checked to sum to 1.
     *
     * @throws IllegalArgumentException if the arrays differ in length, any alpha is not
     *     positive and finite, or an entry of {@code xs} is outside [0, 1] or NaN
     */
    public static double dirichletLog2Prob(double[] alphas, double[] xs) {
        if (alphas.length != xs.length) {
            throw new IllegalArgumentException("Dirichlet prior alphas and distribution xs must be the same length. Found alphas.length=" + alphas.length + " xs.length=" + xs.length);
        }
        for (double alpha : alphas) {
            verifyAlpha(alpha);
        }
        verifyDistro(xs);
        double negLogGammaSum = 0.0d;
        double alphaSum = 0.0d;
        for (int i = 0; i < alphas.length; i++) {
            alphaSum += alphas[i];
            negLogGammaSum -= Math.log2Gamma(alphas[i]);
        }
        double logProb = negLogGammaSum + Math.log2Gamma(alphaSum);
        for (int i = 0; i < xs.length; i++) {
            logProb += (alphas[i] - 1.0d) * Math.log2(xs[i]);
        }
        return logProb;
    }

    /** Throws if the concentration parameter is NaN, infinite, or not positive. */
    static void verifyAlpha(double alpha) {
        if (Double.isNaN(alpha) || Double.isInfinite(alpha) || alpha <= 0.0d) {
            throw new IllegalArgumentException("Concentration parameter must be positive and finite. Found alpha=" + alpha);
        }
    }

    /** Throws if any entry is outside [0, 1] or NaN. Does not check that entries sum to 1. */
    static void verifyDistro(double[] xs) {
        for (int i = 0; i < xs.length; i++) {
            if (xs[i] < 0.0d || xs[i] > 1.0d || Double.isNaN(xs[i]) || Double.isInfinite(xs[i])) {
                throw new IllegalArgumentException("All xs must be between 0.0 and 1.0 inclusive. Found xs[" + i + "]=" + xs[i]);
            }
        }
    }

    /** Returns the sum of squared deviations of {@code values} from {@code center}. */
    static double sumSquareDiffs(double[] values, double center) {
        double total = 0.0d;
        for (double value : values) {
            double diff = value - center;
            total += diff * diff;
        }
        return total;
    }

    /**
     * Returns the sum of cross-products of deviations,
     * {@code sum_i (xs[i] - meanX) * (ys[i] - meanY)}. Iterates over {@code xs.length}.
     */
    static double sumSquareDiffs(double[] xs, double[] ys, double meanX, double meanY) {
        double total = 0.0d;
        for (int i = 0; i < xs.length; i++) {
            total += (xs[i] - meanX) * (ys[i] - meanY);
        }
        return total;
    }

    /** Returns the population variance of {@code values} about the given mean. */
    static double variance(double[] values, double mean) {
        return sumSquareDiffs(values, mean) / values.length;
    }

    /** Throws if {@code value} is infinite, NaN, or negative; {@code name} labels the message. */
    static void assertNonNegative(String name, double value) {
        if (Double.isInfinite(value) || Double.isNaN(value) || value < 0.0d) {
            throw new IllegalArgumentException("Require finite non-negative value. Found " + name + " =" + value);
        }
    }

    /** Returns one chi-squared cell term, {@code (observed - expected)^2 / expected}. */
    private static double csTerm(double observed, double expected) {
        double diff = observed - expected;
        return (diff * diff) / expected;
    }
}
