package com.aliasi.matrix;

import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.util.Strings;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:com/aliasi/matrix/SvdMatrix.class */
public class SvdMatrix extends AbstractMatrix {
    private final double[][] mRowVectors;
    private final double[][] mColumnVectors;
    private final int mOrder;
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0];

    public SvdMatrix(double[][] dArr, double[][] dArr2, int i) {
        verifyDimensions("row", i, dArr);
        verifyDimensions("column", i, dArr2);
        this.mRowVectors = dArr;
        this.mColumnVectors = dArr2;
        this.mOrder = i;
    }

    public SvdMatrix(double[][] dArr, double[][] dArr2, double[] dArr3) {
        this.mOrder = dArr3.length;
        verifyDimensions("row", this.mOrder, dArr);
        verifyDimensions("column", this.mOrder, dArr2);
        this.mRowVectors = new double[dArr.length][this.mOrder];
        this.mColumnVectors = new double[dArr2.length][this.mOrder];
        double[] dArr4 = new double[dArr3.length];
        for (int i = 0; i < dArr4.length; i++) {
            dArr4[i] = Math.sqrt(dArr3[i]);
        }
        scale(this.mRowVectors, dArr, dArr4);
        scale(this.mColumnVectors, dArr2, dArr4);
    }

    @Override // com.aliasi.matrix.AbstractMatrix, com.aliasi.matrix.Matrix
    public int numRows() {
        return this.mRowVectors.length;
    }

    @Override // com.aliasi.matrix.AbstractMatrix, com.aliasi.matrix.Matrix
    public int numColumns() {
        return this.mColumnVectors.length;
    }

    public int order() {
        return this.mRowVectors[0].length;
    }

    @Override // com.aliasi.matrix.AbstractMatrix, com.aliasi.matrix.Matrix
    public double value(int i, int i2) {
        double[] dArr = this.mRowVectors[i];
        double[] dArr2 = this.mColumnVectors[i2];
        double d = 0.0d;
        for (int i3 = 0; i3 < dArr.length; i3++) {
            d += dArr[i3] * dArr2[i3];
        }
        return d;
    }

    public double[] singularValues() {
        double[] dArr = new double[this.mOrder];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = singularValue(i);
        }
        return dArr;
    }

    public double singularValue(int i) {
        if (i >= this.mOrder) {
            throw new IllegalArgumentException("Maximum order=" + (this.mOrder - 1) + " found order=" + i);
        }
        return columnLength(this.mRowVectors, i) * columnLength(this.mColumnVectors, i);
    }

    public double[][] leftSingularVectors() {
        return normalizeColumns(this.mRowVectors);
    }

    public double[][] rightSingularVectors() {
        return normalizeColumns(this.mColumnVectors);
    }

    /* JADX WARN: Type inference failed for: r0v15, types: [int[], int[][]] */
    public static SvdMatrix svd(double[][] dArr, int i, double d, double d2, double d3, double d4, Reporter reporter, double d5, int i2, int i3) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        int length = dArr.length;
        int length2 = dArr[0].length;
        reporter.info("Calculating SVD");
        reporter.info("#Rows=" + length + " #Cols=" + length2);
        for (int i4 = 1; i4 < length; i4++) {
            if (dArr[i4].length != length2) {
                String str = "All rows must be of same length. Found row[0].length=" + length2 + " row[" + i4 + "]=" + dArr[i4].length;
                reporter.fatal(str);
                throw new IllegalArgumentException(str);
            }
        }
        int[] iArr = new int[length2];
        for (int i5 = 0; i5 < length2; i5++) {
            iArr[i5] = i5;
        }
        ?? r0 = new int[length];
        for (int i6 = 0; i6 < length; i6++) {
            r0[i6] = iArr;
        }
        return partialSvd(r0, dArr, i, d, d2, d3, d4, reporter, d5, i2, i3);
    }

    public static SvdMatrix partialSvd(int[][] iArr, double[][] dArr, int i, double d, double d2, double d3, double d4, Reporter reporter, double d5, int i2, int i3) {
        return partialSvd(iArr, dArr, i, d, d2, d3, d4, new Random(), reporter, d5, i2, i3);
    }

    /* JADX WARN: Multi-variable type inference failed */
    static SvdMatrix partialSvd(int[][] iArr, double[][] dArr, int i, double d, double d2, double d3, double d4, Random random, Reporter reporter, double d5, int i2, int i3) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("Start");
        if (i < 1) {
            String str = "Max order must be >= 1. Found maxOrder=" + i;
            reporter.fatal(str);
            throw new IllegalArgumentException(str);
        }
        if (d5 < 0.0d || notFinite(d5)) {
            String str2 = "Min improvement must be finite and non-negative. Found minImprovement=" + d5;
            reporter.fatal(str2);
            throw new IllegalArgumentException(str2);
        }
        if (i2 <= 0 || i3 < i2) {
            String str3 = "Min epochs must be non-negative and less than or equal to max epochs. found minEpochs=" + i2 + " maxEpochs=" + i3;
            reporter.fatal(str3);
            throw new IllegalArgumentException(str3);
        }
        if (notFinite(d) || d == 0.0d) {
            String str4 = "Feature inits must be finite and non-zero. Found featureInit=" + d;
            reporter.fatal(str4);
            throw new IllegalArgumentException(str4);
        }
        if (notFinite(d2) || d2 < 0.0d) {
            String str5 = "Initial learning rate must be finite and non-negative. Found initialLearningRate=" + d2;
            reporter.fatal(str5);
            throw new IllegalArgumentException(str5);
        }
        if (notFinite(d4) || d4 < 0.0d) {
            String str6 = "Regularization must be finite and non-negative. Found regularization=" + d4;
            reporter.fatal(str6);
            throw new IllegalArgumentException(str6);
        }
        for (int i4 = 0; i4 < iArr.length; i4++) {
            if (iArr == null) {
                reporter.fatal("ColumnIds must not be null.");
                throw new IllegalArgumentException("ColumnIds must not be null.");
            }
            if (dArr == null) {
                reporter.fatal("Values must not be null");
                throw new IllegalArgumentException("Values must not be null");
            }
            if (iArr[i4] == null) {
                String str7 = "All column Ids must be non-null. Found null in row=" + i4;
                reporter.fatal(str7);
                throw new IllegalArgumentException(str7);
            }
            if (dArr[i4] == null) {
                String str8 = "All values must be non-null. Found null row=" + i4;
                reporter.fatal(str8);
                throw new IllegalArgumentException(str8);
            }
            if (iArr[i4].length != dArr[i4].length) {
                String str9 = "column Ids and values must be same length. For row=" + i4 + " Found columnIds[row].length=" + iArr[i4].length + " Found values[row].length=" + dArr[i4].length;
                reporter.fatal(str9);
                throw new IllegalArgumentException(str9);
            }
            for (int i5 = 0; i5 < iArr[i4].length; i5++) {
                if (iArr[i4][i5] < 0) {
                    String str10 = "Column ids must be non-negative. Found columnIds[" + i4 + "][" + i5 + "]=" + iArr[i4][i5];
                    reporter.fatal(str10);
                    throw new IllegalArgumentException(str10);
                }
                if (i5 > 0 && iArr[i4][i5 - 1] >= iArr[i4][i5]) {
                    String str11 = "All column Ids must be same length. At row=" + i4 + " Mismatch at rows " + i5 + " and " + (i5 - 1);
                    reporter.fatal(str11);
                    throw new IllegalArgumentException(str11);
                }
            }
        }
        if (d3 < 0.0d || notFinite(d3)) {
            reporter.fatal("Annealing rate must be finite and non-negative. Found rate=" + d3);
            throw new IllegalArgumentException("14");
        }
        int length = iArr.length;
        int i6 = 0;
        for (double[] dArr2 : dArr) {
            i6 += dArr2.length;
        }
        int i7 = 0;
        for (int[] iArr2 : iArr) {
            for (int i8 = 0; i8 < iArr2.length; i8++) {
                if (iArr2[i8] > i7) {
                    i7 = iArr2[i8];
                }
            }
        }
        int i9 = i7 + 1;
        int min = Math.min(i, Math.min(length, i9));
        double[] dArr3 = new double[dArr.length];
        for (int i10 = 0; i10 < length; i10++) {
            dArr3[i10] = new double[dArr[i10].length];
            Arrays.fill(dArr3[i10], 0.0d);
        }
        ArrayList arrayList = new ArrayList(min);
        ArrayList arrayList2 = new ArrayList(min);
        for (int i11 = 0; i11 < min; i11++) {
            reporter.info("  Factor=" + i11);
            double[] initArray = initArray(length, d, random);
            double[] initArray2 = initArray(i9, d, random);
            double d6 = Double.POSITIVE_INFINITY;
            int i12 = 0;
            while (true) {
                if (i12 >= i3) {
                    break;
                }
                double d7 = d2 / (1.0d + (i12 / d3));
                double d8 = 0.0d;
                for (int i13 = 0; i13 < length; i13++) {
                    int[] iArr3 = iArr[i13];
                    double[] dArr4 = dArr[i13];
                    Object[] objArr = dArr3[i13];
                    for (int i14 = 0; i14 < iArr3.length; i14++) {
                        int i15 = iArr3[i14];
                        double predict = dArr4[i14] - predict(i13, i15, initArray, initArray2, objArr[i14]);
                        d8 += predict * predict;
                        double d9 = initArray[i13];
                        double d10 = initArray2[i15];
                        int i16 = i13;
                        initArray[i16] = initArray[i16] + (d7 * ((predict * d10) - (d4 * d9)));
                        initArray2[i15] = initArray2[i15] + (d7 * ((predict * d9) - (d4 * d10)));
                    }
                }
                double sqrt = Math.sqrt(d8 / i6);
                reporter.info("    epoch=" + i12 + " rmse=" + sqrt);
                if (i12 >= i2 && relativeDifference(sqrt, d6) < d5) {
                    reporter.info("Converged in epoch=" + i12 + " rmse=" + sqrt + " relDiff=" + relativeDifference(sqrt, d6));
                    break;
                }
                d6 = sqrt;
                i12++;
            }
            reporter.info("Order=" + i11 + " RMSE=" + d6);
            arrayList.add(initArray);
            arrayList2.add(initArray2);
            for (int i17 = 0; i17 < dArr3.length; i17++) {
                double[] dArr5 = dArr3[i17];
                for (int i18 = 0; i18 < dArr5.length; i18++) {
                    dArr5[i18] = predict(i17, iArr[i17][i18], initArray, initArray2, dArr5[i18]);
                }
            }
        }
        return new SvdMatrix(transpose((double[][]) arrayList.toArray(EMPTY_DOUBLE_2D_ARRAY)), transpose((double[][]) arrayList2.toArray(EMPTY_DOUBLE_2D_ARRAY)), min);
    }

    static double relativeDifference(double d, double d2) {
        return Math.abs(d - d2) / (Math.abs(d) + Math.abs(d2));
    }

    static double[][] transpose(double[][] dArr) {
        double[][] dArr2 = new double[dArr[0].length][dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            for (int i2 = 0; i2 < dArr[i].length; i2++) {
                dArr2[i2][i] = dArr[i][i2];
            }
        }
        return dArr2;
    }

    static double predict(int i, int i2, double[] dArr, double[] dArr2, double d) {
        return d + (dArr[i] * dArr2[i2]);
    }

    static double[] initArray(int i, double d, Random random) {
        double[] dArr = new double[i];
        for (int i2 = 0; i2 < dArr.length; i2++) {
            dArr[i2] = random.nextGaussian() * d;
        }
        return dArr;
    }

    static boolean notFinite(double d) {
        return Double.isNaN(d) || Double.isInfinite(d);
    }

    static double columnLength(double[][] dArr, int i) {
        double d = 0.0d;
        for (int i2 = 0; i2 < dArr.length; i2++) {
            d += dArr[i2][i] * dArr[i2][i];
        }
        return Math.sqrt(d);
    }

    static void scale(double[][] dArr, double[][] dArr2, double[] dArr3) {
        for (int i = 0; i < dArr.length; i++) {
            for (int i2 = 0; i2 < dArr[i].length; i2++) {
                dArr[i][i2] = dArr2[i][i2] * dArr3[i2];
            }
        }
    }

    static void verifyDimensions(String str, int i, double[][] dArr) {
        for (int i2 = 0; i2 < dArr.length; i2++) {
            if (dArr[i2].length != i) {
                throw new IllegalArgumentException("All vectors must have length equal to order. order=" + i + Strings.SINGLE_SPACE_STRING + str + "Vectors[" + i2 + "].length=" + dArr[i2].length);
            }
        }
    }

    static double[][] normalizeColumns(double[][] dArr) {
        int length = dArr.length;
        int length2 = dArr[0].length;
        double[][] dArr2 = new double[length][length2];
        for (int i = 0; i < length2; i++) {
            double d = 0.0d;
            for (int i2 = 0; i2 < length; i2++) {
                double d2 = dArr[i2][i];
                dArr2[i2][i] = d2;
                d += d2 * d2;
            }
            double sqrt = Math.sqrt(d);
            for (int i3 = 0; i3 < length; i3++) {
                double[] dArr3 = dArr2[i3];
                int i4 = i;
                dArr3[i4] = dArr3[i4] / sqrt;
            }
        }
        return dArr2;
    }
}
