/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dimensionalityreduction;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Principal component analysis (PCA) utilities based on a singular value decomposition (SVD) of
 * the data matrix.
 *
 * <p>Every method treats the rows of {@code A} as observations and its columns as features. When
 * {@code normalize} is {@code true}, the column means ({@code A.mean(0)}) are subtracted from every
 * row of {@code A} <em>in place</em> ({@code subiRowVector}) before the decomposition, so the
 * caller's array is modified. The SVD itself runs on a copy, so apart from that optional centering
 * {@code A} keeps its values.
 */
public class PCA {
    private PCA() {
        // Static utility class; not instantiable.
    }

    /**
     * Projects {@code A} onto its first {@code nDims} principal directions.
     *
     * @param A data matrix of shape {@code m x n} (rows = observations, columns = features);
     *     centered in place when {@code normalize} is {@code true}
     * @param nDims number of principal components to keep, in {@code [1, n]}
     * @param normalize whether to subtract the column means from {@code A} first
     * @return the {@code m x nDims} product {@code A.mmul(factor)}
     * @throws IllegalArgumentException if {@code nDims} is outside {@code [1, n]}
     */
    public static INDArray pca(INDArray A, int nDims, boolean normalize) {
        INDArray factor = PCA.pca_factor(A, nDims, normalize);
        // pca_factor has already centered A (if requested), so this projects the centered data.
        return A.mmul(factor);
    }

    /**
     * Computes the projection matrix onto the first {@code nDims} principal directions of {@code A}.
     *
     * @param A data matrix of shape {@code m x n}; centered in place when {@code normalize} is
     *     {@code true}, otherwise left unchanged
     * @param nDims number of principal components to keep, in {@code [1, n]}
     * @param normalize whether to subtract the column means from {@code A} first
     * @return an {@code n x nDims} column-major matrix whose columns are the first {@code nDims}
     *     right singular vectors of {@code A} (rows of {@code V^T} from {@code A = U S V^T})
     * @throws IllegalArgumentException if {@code nDims} is outside {@code [1, n]}
     */
    public static INDArray pca_factor(INDArray A, int nDims, boolean normalize) {
        int n = A.columns();
        // Validate before touching A, so a bad argument does not leave the caller's data centered.
        if (nDims < 1 || nDims > n) {
            throw new IllegalArgumentException("nDims must be in [1, " + n + "], got " + nDims);
        }
        INDArray s = Nd4j.create(Math.min(A.rows(), n));
        INDArray VT = PCA.centerAndDecompose(A, s, normalize);
        return PCA.leadingColumns(VT, n, nDims);
    }

    /**
     * Projects {@code A} onto the smallest number of leading principal directions that retain at
     * least the requested fraction of the total variance.
     *
     * @param A data matrix of shape {@code m x n}; centered in place when {@code normalize} is
     *     {@code true}
     * @param variance fraction of the total variance to retain; must not exceed {@code 1.0}
     * @param normalize whether to subtract the column means from {@code A} first
     * @return the {@code m x k} product {@code A.mmul(factor)}, with {@code k} chosen as in
     *     {@link #pca_factor(INDArray, double, boolean)}
     * @throws IllegalArgumentException if {@code variance} is greater than {@code 1.0} or NaN, or
     *     no number of components reaches the target
     */
    public static INDArray pca(INDArray A, double variance, boolean normalize) {
        INDArray factor = PCA.pca_factor(A, variance, normalize);
        return A.mmul(factor);
    }

    /**
     * Computes the projection matrix onto the smallest number {@code k} of leading principal
     * directions whose covariance eigenvalues sum to at least {@code variance} times their total.
     *
     * @param A data matrix of shape {@code m x n}; centered in place when {@code normalize} is
     *     {@code true}, otherwise left unchanged
     * @param variance fraction of the total variance to retain; must not exceed {@code 1.0}
     *     (values {@code <= 0} yield {@code k = 1})
     * @param normalize whether to subtract the column means from {@code A} first
     * @return an {@code n x k} column-major matrix whose columns are the first {@code k} right
     *     singular vectors of {@code A}
     * @throws IllegalArgumentException if {@code variance} is greater than {@code 1.0} or NaN, or
     *     no prefix of components reaches the target (e.g. a NaN singular value or an empty
     *     {@code A})
     */
    public static INDArray pca_factor(INDArray A, double variance, boolean normalize) {
        // Reject unreachable targets (including NaN) before modifying A.
        if (!(variance <= 1.0)) {
            throw new IllegalArgumentException(
                    "No reduction possible for reqd. variance - use smaller variance");
        }
        int n = A.columns();
        int r = Math.min(A.rows(), n);
        INDArray s = Nd4j.create(r);
        INDArray VT = PCA.centerAndDecompose(A, s, normalize);

        // The covariance eigenvalues are sigma_i^2 / (rows - 1). The common 1 / (rows - 1) factor
        // cancels in the explained-variance ratio, so sigma_i^2 alone is used. This also avoids a
        // division by zero for a single-row A.
        double[] eig = new double[r];
        double total = 0.0;
        for (int i = 0; i < r; ++i) {
            double sigma = s.getDouble(i);
            eig[i] = sigma * sigma;
            total += eig[i];
        }

        // Keep the smallest prefix of components whose eigenvalues reach the target. The total is
        // accumulated in the same order as the running sum, so variance == 1.0 is always met by
        // the full set of components instead of failing on floating-point rounding.
        double target = total * variance;
        int k = -1;
        double running = 0.0;
        for (int i = 0; i < r; ++i) {
            running += eig[i];
            if (running >= target) {
                k = i + 1;
                break;
            }
        }
        if (k == -1) {
            throw new IllegalArgumentException(
                    "No reduction possible for reqd. variance - use smaller variance");
        }
        return PCA.leadingColumns(VT, n, k);
    }

    /**
     * Optionally centers {@code A} in place, then computes its singular values and right singular
     * vectors.
     *
     * @param A data matrix of shape {@code m x n}; only modified by the optional centering
     * @param s output vector of length {@code min(m, n)} that receives the singular values
     * @param normalize whether to subtract the column means from {@code A} first
     * @return the {@code n x n} column-major matrix {@code V^T}
     */
    private static INDArray centerAndDecompose(INDArray A, INDArray s, boolean normalize) {
        if (normalize) {
            INDArray mean = A.mean(0);
            A.subiRowVector(mean);
        }
        int n = A.columns();
        INDArray VT = Nd4j.create(n, n, 'f');
        // LAPACK ?gesvd destroys its input matrix when U/VT are not written over it, and pca()
        // multiplies A by the factor afterwards, so decompose a copy instead of A itself.
        // NOTE(review): the copy is column-major to match the 'f'-ordered VT; confirm the lapack()
        // wrapper interprets all operands in a common layout.
        Nd4j.getBlasWrapper().lapack().sgesvd(A.dup('f'), s, null, VT);
        return VT;
    }

    /**
     * Copies the first {@code k} rows of the {@code n x n} matrix {@code VT} into the columns of a
     * new column-major {@code n x k} matrix.
     */
    private static INDArray leadingColumns(INDArray VT, int n, int k) {
        INDArray V = VT.transpose();
        INDArray factor = Nd4j.create(n, k, 'f');
        for (int i = 0; i < k; ++i) {
            // Column i of V is row i of V^T, i.e. the i-th right singular vector.
            factor.putColumn(i, V.getColumn(i));
        }
        return factor;
    }
}

