/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;

/**
 * t-distributed Stochastic Neighbor Embedding (t-SNE).
 * <p>
 * Embeds n points into a d-dimensional space by minimizing the divergence
 * between two affinity matrices:
 * <ul>
 *   <li>P, built in the input space from a Gaussian kernel whose width is
 *       calibrated per point to a target perplexity;</li>
 *   <li>Q, built in the embedding from the kernel 1 / (1 + |y_i - y_j|^2).</li>
 * </ul>
 * The optimizer is gradient descent with momentum and per-coordinate adaptive
 * gains. For the first {@code momentumSwitchIter} iterations, P is multiplied by
 * an early-exaggeration factor and the momentum is 0.5. After that, the
 * exaggeration is removed and the momentum is raised to 0.8.
 * <p>
 * P and Q are dense n-by-n matrices, so memory use and the work per iteration
 * are quadratic in n.
 */
public class TSNE implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** Multiplier applied to P during the early-exaggeration phase. */
    private static final double EARLY_EXAGGERATION = 12.0;
    /** Maximum number of bisection steps when calibrating a row's kernel width. */
    private static final int MAX_BETA_SEARCH_ITER = 50;
    /** Floor applied to symmetrized affinities and to normalized Q values in the cost. */
    private static final double MIN_PROBABILITY = 1E-16;

    /** The embedding coordinates: one row of length d per input point. */
    public final double[][] coordinates;
    /** The learning rate. */
    private final double eta;
    /** Number of gradient steps taken so far, across all calls to update(). */
    private int totalIter = 0;
    /** Current momentum; replaced by finalMomentum once totalIter reaches momentumSwitchIter. */
    private double momentum = 0.5;
    private final double finalMomentum = 0.8;
    private final int momentumSwitchIter = 250;
    /** Lower bound for the adaptive gains. */
    private final double minGain = 0.01;
    /** Per-coordinate adaptive gains. */
    private final double[][] gains;
    /** Symmetrized input affinities (exaggerated until momentumSwitchIter). */
    private final double[][] P;
    /** Unnormalized embedding kernel values; the diagonal is kept at zero. */
    private final double[][] Q;
    /** Sum of the off-diagonal entries of Q, i.e. the normalizer of q_ij. */
    private double Qsum;
    /** The most recently computed cost (see {@link #cost()}). */
    private double cost;

    /**
     * Constructor with perplexity 20, learning rate 200 and 1000 iterations.
     *
     * @param X the input data, or a square matrix of pairwise squared distances.
     * @param d the dimension of the embedding space.
     */
    public TSNE(double[][] X, int d) {
        this(X, d, 20.0, 200.0, 1000);
    }

    /**
     * Constructor. Builds P, initializes the embedding with small Gaussian noise
     * and then runs {@code iterations} optimization steps.
     *
     * @param X the input data. If X is square ({@code X.length == X[0].length}),
     *          it is used directly as the pairwise distance matrix. It should
     *          then hold squared distances, to match the matrix computed with
     *          {@code MathEx.squaredDistance} otherwise.
     *          NOTE: a data set with as many features as samples is therefore
     *          interpreted as a distance matrix.
     * @param d the dimension of the embedding space.
     * @param perplexity the target perplexity of each point's conditional distribution.
     * @param eta the learning rate.
     * @param iterations the number of optimization steps run by the constructor.
     * @throws IllegalArgumentException if X is empty, {@code d < 1}, or
     *         {@code perplexity} or {@code eta} is not positive.
     */
    public TSNE(double[][] X, int d, double perplexity, double eta, int iterations) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }
        if (d < 1) {
            throw new IllegalArgumentException("Invalid dimension of embedding space: " + d);
        }
        if (!(perplexity > 0.0)) {
            throw new IllegalArgumentException("Invalid perplexity: " + perplexity);
        }
        if (!(eta > 0.0)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }

        this.eta = eta;
        int n = X.length;

        double[][] D;
        if (n == X[0].length) {
            D = X;
        } else {
            D = new double[n][n];
            MathEx.pdist(X, D, MathEx::squaredDistance);
        }

        this.coordinates = new double[n][d];
        this.gains = new double[n][d];
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1E-4);
        for (int i = 0; i < n; i++) {
            Arrays.fill(gains[i], 1.0);
            double[] Yi = coordinates[i];
            for (int j = 0; j < d; j++) {
                Yi[j] = gaussian.rand();
            }
        }

        this.P = expd(D, perplexity, 1E-3);
        this.Q = new double[n][n];

        // Symmetrize P. Each row of the conditional matrix sums to 1, so the
        // symmetric matrix P + P^T sums to 2n. The exaggeration factor is
        // applied here and removed again in update() at momentumSwitchIter.
        double Psum = 2.0 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = EARLY_EXAGGERATION * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < MIN_PROBABILITY) {
                    p = MIN_PROBABILITY;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }

        update(iterations);
    }

    /**
     * Returns the cost most recently computed by {@link #update(int)}:
     * 2 * sum over i &gt; j of p_ij * log2(p_ij / q_ij). While early exaggeration
     * is active, P is scaled, so the value is not exactly the KL divergence.
     *
     * @return the most recently computed cost.
     */
    public double cost() {
        return cost;
    }

    /**
     * Continues the optimization for the given number of iterations.
     * <p>
     * The velocity starts at zero on every call. The gains, the momentum and the
     * total iteration counter carry over between calls. After the last step,
     * the coordinates are shifted to zero column mean.
     *
     * @param iterations the number of gradient steps to run.
     */
    public void update(int iterations) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;
        double[][] dY = new double[n][d];
        double[][] dC = new double[n][d];

        for (int iter = 1; iter <= iterations; iter++, totalIter++) {
            Qsum = computeQ(Y, Q);
            IntStream.range(0, n).parallel().forEach(i -> sne(i, dY[i], dC[i]));
            IntStream.range(0, n).parallel().forEach(i -> {
                double[] Yi = Y[i];
                double[] dYi = dY[i];
                double[] dCi = dC[i];
                double[] g = gains[i];
                for (int k = 0; k < d; k++) {
                    dYi[k] = momentum * dYi[k] - eta * g[k] * dCi[k];
                    Yi[k] += dYi[k];
                }
            });

            // End of the early phase: raise momentum and undo the exaggeration of P.
            if (totalIter == momentumSwitchIter) {
                momentum = finalMomentum;
                for (double[] Pi : P) {
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= EARLY_EXAGGERATION;
                    }
                }
            }

            if (iter % 100 == 0) {
                cost = computeCost(P, Q);
                logger.info("Error after {} iterations: {}", iter, cost);
            }
        }

        // Translation does not change the cost; re-center the embedding.
        double[] colMeans = MathEx.colMeans(Y);
        IntStream.range(0, n).parallel().forEach(i -> {
            double[] Yi = Y[i];
            for (int j = 0; j < d; j++) {
                Yi[j] -= colMeans[j];
            }
        });

        // Skip when no step ran: Q/Qsum may never have been computed.
        if (iterations > 0 && iterations % 100 != 0) {
            cost = computeCost(P, Q);
            logger.info("Error after {} iterations: {}", iterations, cost);
        }
    }

    /**
     * Computes the gradient of the cost with respect to point i into dC and
     * updates the gains of point i.
     * <p>
     * Writes only to dC and gains[i]. It reads the shared coordinates, P and Q,
     * so calls for different i can run concurrently.
     *
     * @param i the point index.
     * @param dY the previous update step of point i (read only).
     * @param dC output buffer for the gradient of point i.
     */
    private void sne(int i, double[] dY, double[] dC) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;
        double[] Yi = Y[i];
        double[] Pi = P[i];
        double[] Qi = Q[i];
        double[] g = gains[i];

        Arrays.fill(dC, 0.0);
        for (int j = 0; j < n; j++) {
            if (j == i) {
                continue;
            }
            double[] Yj = Y[j];
            double q = Qi[j];
            double z = (Pi[j] - q / Qsum) * q;
            for (int k = 0; k < d; k++) {
                dC[k] += 4.0 * (Yi[k] - Yj[k]) * z;
            }
        }

        // Grow a gain when the gradient's sign disagrees with the previous step's
        // sign (i.e. the step continues in the same direction); otherwise shrink it.
        for (int k = 0; k < d; k++) {
            g[k] = Math.signum(dC[k]) != Math.signum(dY[k]) ? g[k] + 0.2 : g[k] * 0.8;
            if (g[k] < minGain) {
                g[k] = minGain;
            }
        }
    }

    /**
     * Computes the conditional input affinities. For each row i it bisects on
     * beta so that the entropy of exp(-beta * D[i][j]) / Z over j != i matches
     * log(perplexity).
     *
     * @param D the pairwise (squared) distance matrix.
     * @param perplexity the target perplexity.
     * @param tol the tolerance on the entropy difference, in nats.
     * @return the row-normalized conditional probabilities, with a zero diagonal.
     */
    private double[][] expd(double[][] D, double perplexity, double tol) {
        int n = D.length;
        double[][] P = new double[n][n];
        double[] DiSum = MathEx.rowSums(D);
        // Target entropy in nats. The entropy below is also in nats; mixing log
        // bases would calibrate each row to the wrong perplexity.
        double logU = Math.log(perplexity);

        IntStream.range(0, n).parallel().forEach(i -> {
            double[] Pi = P[i];
            double[] Di = D[i];

            // Initialize beta as sqrt(1 / mean distance). Fall back to 1 when
            // that is undefined, e.g. when every distance in the row is zero.
            double beta = Math.sqrt((n - 1) / DiSum[i]);
            if (!(beta > 0.0) || Double.isInfinite(beta)) {
                beta = 1.0;
            }
            double betamin = 0.0;
            double betamax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Pisum = 0.0;
            for (int iter = 0; iter < MAX_BETA_SEARCH_ITER; iter++) {
                // Kernel row for the current beta. Skipping j == i keeps P[i][i] = 0
                // and avoids subtracting exp(0) from a possibly tiny sum.
                Pisum = 0.0;
                double weighted = 0.0;
                for (int j = 0; j < n; j++) {
                    if (j == i) {
                        continue;
                    }
                    double dist = beta * Di[j];
                    double p = Math.exp(-dist);
                    Pi[j] = p;
                    Pisum += p;
                    weighted += p * dist;
                }

                // Entropy (nats) of the normalized row. If every kernel value
                // underflowed, beta is far too large, so treat the entropy as -infinity.
                double H = Pisum > 0.0 ? Math.log(Pisum) + weighted / Pisum : Double.NEGATIVE_INFINITY;
                double Hdiff = H - logU;
                boolean converged = Math.abs(Hdiff) <= tol;
                if (!converged) {
                    if (Hdiff > 0.0) {
                        // Entropy too high: narrow the kernel.
                        betamin = beta;
                        beta = Double.isInfinite(betamax) ? beta * 2.0 : (beta + betamax) / 2.0;
                    } else {
                        // Entropy too low: widen the kernel.
                        betamax = beta;
                        beta = (beta + betamin) / 2.0;
                    }
                }

                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", Hdiff, i, beta, H, logU);
                if (converged) {
                    break;
                }
            }

            // Normalize with the sum that matches the values currently in Pi,
            // whether the search converged or hit the iteration limit.
            if (Pisum > 0.0) {
                for (int j = 0; j < n; j++) {
                    Pi[j] /= Pisum;
                }
            }
        });
        return P;
    }

    /**
     * Fills Q with the unnormalized embedding kernel 1 / (1 + |Y_i - Y_j|^2)
     * and returns the sum of its off-diagonal entries. The diagonal is set to 0
     * because t-SNE's normalizer sums over i != j only.
     *
     * @param Y the embedding coordinates.
     * @param Q the output kernel matrix.
     * @return the sum of the off-diagonal entries of Q.
     */
    private double computeQ(double[][] Y, double[][] Q) {
        int n = Y.length;
        // Row sums are collected in row order and added with MathEx.sum
        // instead of a parallel DoubleStream.sum, whose summation order may vary.
        double[] rowSum = IntStream.range(0, n).parallel().mapToDouble(i -> {
            double[] Yi = Y[i];
            double[] Qi = Q[i];
            double sum = 0.0;
            for (int j = 0; j < n; j++) {
                if (j == i) {
                    Qi[j] = 0.0;
                    continue;
                }
                double q = 1.0 / (1.0 + MathEx.squaredDistance(Yi, Y[j]));
                Qi[j] = q;
                sum += q;
            }
            return sum;
        }).toArray();
        return MathEx.sum(rowSum);
    }

    /**
     * Computes 2 * sum over i &gt; j of p_ij * log2(p_ij / q_ij), where
     * q_ij = Q[i][j] / Qsum is floored at {@code MIN_PROBABILITY}.
     *
     * @param P the symmetrized input affinities.
     * @param Q the unnormalized embedding kernel matrix.
     * @return the cost in bits.
     */
    private double computeCost(double[][] P, double[][] Q) {
        // Same summation approach as computeQ: collect per-row terms, then MathEx.sum.
        double[] rowCost = IntStream.range(0, Q.length).parallel().mapToDouble(i -> {
            double[] Pi = P[i];
            double[] Qi = Q[i];
            double C = 0.0;
            for (int j = 0; j < i; j++) {
                double p = Pi[j];
                double q = Qi[j] / Qsum;
                if (Double.isNaN(q) || q < MIN_PROBABILITY) {
                    q = MIN_PROBABILITY;
                }
                C += p * MathEx.log2(p / q);
            }
            return C;
        }).toArray();
        return 2.0 * MathEx.sum(rowCost);
    }
}

