/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.plot.Tsne;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BarnesHutTsne
implements Model {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(BarnesHutTsne.class);
    // Workspace identifiers. They are not referenced by the methods visible in this part of the file.
    public static final String workspaceCache = "LOOP_CACHE";
    public static final String workspaceExternal = "LOOP_EXTERNAL";
    /** Number of optimisation iterations run by {@code fit()}. */
    protected int maxIter = 1000;
    /** Forwarded to {@code Tsne} when {@code theta == 0}; defaults to {@code Nd4j.EPS_THRESHOLD}. */
    protected double realMin = Nd4j.EPS_THRESHOLD;
    /** Starting momentum; also the default value of {@link #momentum}. */
    protected double initialMomentum = 0.5;
    /** Momentum that {@code fit()} switches to at {@link #switchMomentumIteration}. */
    protected double finalMomentum = 0.8;
    /** Lower clamp applied to the per-element gains in {@code update(INDArray, String)}. */
    protected double minGain = 0.01;
    /** Momentum currently applied to {@code yIncs}. */
    protected double momentum = this.initialMomentum;
    /** Iteration index at which {@code fit()} replaces {@link #momentum} with {@link #finalMomentum}. */
    protected int switchMomentumIteration = 250;
    // normalize / usePca are only forwarded to Tsne in the theta == 0 path of fit().
    protected boolean normalize = true;
    protected boolean usePca = false;
    /** Iteration index at which {@code fit()} divides {@code vals} by 12 again (undoing the initial x12 scaling). */
    protected int stopLyingIteration = 250;
    /** Accepted absolute difference between kernel entropy and log(perplexity) in the beta search. */
    protected double tolerance = 1.0E-5;
    /** Step size: multiplies the gradient change directly, or is handed to AdaGrad when {@link #useAdaGrad} is set. */
    protected double learningRate = 500.0;
    /** Lazily created in {@code update(INDArray, String)} when {@link #useAdaGrad} is true. */
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad = true;
    /** Target perplexity; the neighbour count used is {@code (int)(3 * perplexity)}. */
    protected double perplexity = 30.0;
    /** The embedding being optimised (returned by {@code getData()} and written by {@code saveAsFile}). */
    protected INDArray Y;
    /** Number of input rows; set by {@code computeGaussianPerplexity} or {@code setN}. */
    private int N;
    /** Passed to {@code SpTree.computeNonEdgeForces}; a value of 0 makes {@code fit()} delegate to {@code Tsne}. */
    private double theta;
    // Sparse similarity structure: rows holds per-row start offsets (length N + 1),
    // cols the neighbour index of each entry, vals the value of each entry.
    private INDArray rows;
    private INDArray cols;
    private INDArray vals;
    /** Similarity function name handed to {@code VPTree}. The identifier is misspelled, but the public accessors share the spelling. */
    private String simiarlityFunction = "cosinesimilarity";
    /** Passed through to the {@code VPTree} constructor. */
    private boolean invert = true;
    /** Input data set by the {@code fit(...)} overloads / {@code setX}. */
    private INDArray x;
    /** Output dimensionality (size of the score buffer; forwarded to {@code Tsne.calculate}). */
    private int numDimensions = 0;
    /** Key under which {@code gradient()} stores the gradient with respect to Y. */
    public static final String Y_GRAD = "yIncs";
    /** Space-partitioning tree rebuilt over Y on every {@code gradient()} call. */
    private SpTree tree;
    /** Per-element gains, initialised to 1 and adapted in {@code update(INDArray, String)}. */
    private INDArray gains;
    /** Momentum-accumulated increments that are added to Y on each update. */
    private INDArray yIncs;
    /** Worker count passed to the {@code VPTree} constructor. */
    private int vpTreeWorkers;
    /** Optional listener notified after every iteration of {@code fit()}; may be null. */
    protected transient TrainingListener trainingListener;
    /** Never null after construction (the constructor substitutes {@code WorkspaceMode.NONE}). */
    protected WorkspaceMode workspaceMode;
    /** Supplies the initial Y in {@code fit()} when none has been set. */
    private Initializer initializer;
    // Workspace configurations. NOTE(review): none of these are used by the methods visible here - confirm usage elsewhere.
    protected static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.3).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyLearning(LearningPolicy.OVER_TIME).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    public static final WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder().overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();

    /**
     * Convenience constructor: delegates to the full constructor with
     * {@code WorkspaceMode.NONE} and no static input ({@code null}).
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter, double realMin, double initialMomentum, double finalMomentum, double momentum, int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad, double perplexity, TrainingListener TrainingListener2, double minGain, int vpTreeWorkers) {
        this(numDimensions, simiarlityFunction, theta, invert, maxIter, realMin, initialMomentum, finalMomentum, momentum, switchMomentumIteration, normalize, stopLyingIteration, tolerance, learningRate, useAdaGrad, perplexity, TrainingListener2, minGain, vpTreeWorkers, WorkspaceMode.NONE, null);
    }

    /**
     * Full constructor. Every argument is copied into the field of the same
     * name; nothing is validated.
     *
     * @param workspaceMode may be {@code null}, in which case {@code WorkspaceMode.NONE} is stored
     * @param staticInput   when non-null it is handed to the {@code Initializer};
     *                      otherwise a default {@code Initializer} is created
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter, double realMin, double initialMomentum, double finalMomentum, double momentum, int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad, double perplexity, TrainingListener TrainingListener2, double minGain, int vpTreeWorkers, WorkspaceMode workspaceMode, INDArray staticInput) {
        // Embedding shape and neighbour-search settings.
        this.numDimensions = numDimensions;
        this.simiarlityFunction = simiarlityFunction;
        this.invert = invert;
        this.vpTreeWorkers = vpTreeWorkers;
        this.theta = theta;
        this.perplexity = perplexity;
        this.tolerance = tolerance;
        this.normalize = normalize;
        this.realMin = realMin;

        // Optimisation schedule.
        this.maxIter = maxIter;
        this.stopLyingIteration = stopLyingIteration;
        this.switchMomentumIteration = switchMomentumIteration;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.minGain = minGain;

        // Collaborators.
        this.trainingListener = TrainingListener2;
        this.workspaceMode = workspaceMode == null ? WorkspaceMode.NONE : workspaceMode;
        if (staticInput == null) {
            this.initializer = new Initializer();
        } else {
            this.initializer = new Initializer(staticInput);
        }
    }

    /** Returns the similarity function name that is handed to the VP-tree. */
    public String getSimiarlityFunction() {
        return this.simiarlityFunction;
    }

    /** Sets the similarity function name that is handed to the VP-tree. */
    public void setSimiarlityFunction(String simiarlityFunction) {
        this.simiarlityFunction = simiarlityFunction;
    }

    /** Returns the {@code invert} flag that is passed to the VP-tree constructor. */
    public boolean isInvert() {
        return this.invert;
    }

    /** Sets the {@code invert} flag that is passed to the VP-tree constructor. */
    public void setInvert(boolean invert) {
        this.invert = invert;
    }

    /** Returns theta; {@code fit()} uses the {@code Tsne} implementation when it is exactly 0. */
    public double getTheta() {
        return this.theta;
    }

    /** Returns the configured perplexity. */
    public double getPerplexity() {
        return this.perplexity;
    }

    /** Returns the configured number of output dimensions. */
    public int getNumDimensions() {
        return this.numDimensions;
    }

    /** Sets the number of output dimensions. */
    public void setNumDimensions(int numDimensions) {
        this.numDimensions = numDimensions;
    }

    /**
     * Builds the sparse, row-normalised similarity structure for {@code d}.
     *
     * <p>For every row the {@code k = (int)(3 * perplexity)} nearest neighbours
     * are looked up in a VP-tree ({@code k + 1} results are requested and the
     * first one is skipped). A bisection-style search over {@code beta}
     * (at most 200 adjustments) then looks for a kernel whose entropy is within
     * {@code tolerance} of {@code log(perplexity)}.
     *
     * <p>Side effects: overwrites {@code N}, {@code rows} (row start offsets),
     * {@code cols} (neighbour indices) and {@code vals} (normalised kernel values).
     *
     * @param d          input data, one sample per row
     * @param perplexity target perplexity
     * @return the freshly populated {@code vals} array
     * @throws IllegalStateException if {@code N - 1 < 3 * perplexity}, or if the
     *                               neighbour search returns nothing for a row
     */
    public INDArray computeGaussianPerplexity(INDArray d, double perplexity) {
        this.N = d.rows();
        int k = (int)(3.0 * perplexity);
        if ((double)(this.N - 1) < 3.0 * perplexity) {
            // The original message lacked the space before "is".
            throw new IllegalStateException("Perplexity " + perplexity + " is too large for number of samples " + this.N);
        }
        this.rows = Nd4j.zeros((DataType)DataType.INT, (long[])new long[]{1L, this.N + 1});
        this.cols = Nd4j.zeros((DataType)DataType.INT, (long[])new long[]{1L, this.N * k});
        this.vals = Nd4j.zeros((DataType)d.dataType(), (long[])new long[]{this.N * k});
        // Every row owns exactly k entries, so the offsets are multiples of k.
        for (int n = 0; n < this.N; ++n) {
            this.rows.putScalar((long)(n + 1), this.rows.getDouble((long)n) + (double)k);
        }
        double enthropy = Math.log(perplexity);
        VPTree tree = new VPTree(d, this.simiarlityFunction, this.vpTreeWorkers, this.invert);
        log.info("Calculating probabilities of data similarities...");
        for (int i = 0; i < this.N; ++i) {
            if (i % 500 == 0) {
                log.info("Handled {} records", i);
            }
            List<DataPoint> results = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            tree.search(d.getRow((long)i), k + 1, results, distances, false, true);
            if (results.isEmpty()) {
                throw new IllegalStateException("Search returned no values for vector " + i + " - similarity \"" + this.simiarlityFunction + "\" may not be defined (for example, vector is all zeros with cosine similarity)");
            }
            Double[] dists = distances.toArray(new Double[distances.size()]);
            INDArray cArr = Nd4j.createFromArray((Double[])dists).castTo(d.dataType());

            // Search for a beta whose kernel entropy matches log(perplexity).
            double betaMin = -Double.MAX_VALUE;
            double betaMax = Double.MAX_VALUE;
            double betas = 1.0;
            INDArray currP = null;
            int tries = 0;
            boolean found = false;
            while (!found && tries < 200) {
                Pair<INDArray, Double> kernel = this.computeGaussianKernel(cArr, betas, k);
                currP = kernel.getFirst();
                double hDiff = kernel.getSecond() - enthropy;
                if (hDiff < this.tolerance && -hDiff < this.tolerance) {
                    found = true;
                    continue;
                }
                if (hDiff > 0.0) {
                    // Entropy too high: raise beta (double it until an upper bound is known).
                    betaMin = betas;
                    if (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) {
                        betas *= 2.0;
                    } else {
                        betas = (betas + betaMax) / 2.0;
                    }
                } else {
                    // Entropy too low: lower beta (halve it until a lower bound is known).
                    betaMax = betas;
                    if (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) {
                        betas /= 2.0;
                    } else {
                        betas = (betas + betaMin) / 2.0;
                    }
                }
                ++tries;
            }
            // The loop body always runs at least once, so currP is set here.
            currP.divi((Number)(currP.sumNumber().doubleValue() + Double.MIN_VALUE));

            int rowStart = this.rows.getInt(new int[]{i});
            for (int l = 0; l < k; ++l) {
                // Result 0 is skipped. Indices are read directly from the search results
                // instead of via a temporary floating-point array, which would lose
                // precision for large indices. Missing results still map to 0 as before.
                int neighbour = l + 1 < results.size() ? results.get(l + 1).getIndex() : 0;
                this.cols.putScalar((long)(rowStart + l), (double)neighbour);
                this.vals.putScalar((long)(rowStart + l), currP.getDouble((long)l));
            }
        }
        return this.vals;
    }

    /** Returns the input data {@code x} (may be null before a {@code fit(...)} call or {@code setX}). */
    public INDArray input() {
        return this.x;
    }

    /** Always returns {@code null}. */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /** Always returns {@code null}, regardless of {@code param}. */
    public INDArray getParam(String param) {
        return null;
    }

    /** No-op: the supplied listeners are ignored. */
    public void addListeners(TrainingListener ... listener) {
    }

    /** Always returns {@code null}. */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /** Always returns {@code null}; the flag is ignored. */
    public Map<String, INDArray> paramTable(boolean backprapParamsOnly) {
        return null;
    }

    /** No-op. */
    public void setParamTable(Map<String, INDArray> paramTable) {
    }

    /** No-op. */
    public void setParam(String key, INDArray val) {
    }

    /** No-op. */
    public void clear() {
    }

    /** No-op. */
    public void applyConstraints(int iteration, int epoch) {
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrises a sparse matrix given as row offsets ({@code rowP}), column
     * indices ({@code colP}) and values ({@code valP}).
     *
     * <p>Entry (n, c) and its mirror (c, n) both end up in the result. When both
     * directions exist in the input their values are summed, otherwise the single
     * value is used; finally every value is divided by 2.
     *
     * <p>Uses the field {@code N} as the number of rows, so it must have been set
     * (by {@code computeGaussianPerplexity} or {@code setN}) beforehand.
     *
     * @return the new row offsets, column indices and values wrapped in a {@code SymResult}
     */
    public SymResult symmetrized(INDArray rowP, INDArray colP, INDArray valP) {
        int n;
        // Pass 1: count how many entries every row of the symmetric matrix will hold.
        INDArray rowCounts = Nd4j.create((DataType)DataType.INT, (long[])new long[]{this.N});
        for (int n2 = 0; n2 < this.N; ++n2) {
            int begin = rowP.getInt(new int[]{n2});
            int end = rowP.getInt(new int[]{n2 + 1});
            for (int i = begin; i < end; ++i) {
                // present == the mirrored entry (colP[i], n2) already exists in the input.
                boolean present = false;
                for (int m = rowP.getInt(new int[]{colP.getInt(new int[]{i})}); m < rowP.getInt(new int[]{colP.getInt(new int[]{i}) + 1}); ++m) {
                    if (colP.getInt(new int[]{m}) != n2) continue;
                    present = true;
                }
                if (present) {
                    rowCounts.putScalar((long)n2, rowCounts.getInt(new int[]{n2}) + 1);
                    continue;
                }
                // No mirror in the input: the mirrored row gains an entry as well.
                rowCounts.putScalar((long)n2, rowCounts.getInt(new int[]{n2}) + 1);
                rowCounts.putScalar((long)colP.getInt(new int[]{i}), rowCounts.getInt(new int[]{colP.getInt(new int[]{i})}) + 1);
            }
        }
        int numElements = rowCounts.sumNumber().intValue();
        // offset[r] = number of entries already written into row r of the result.
        INDArray offset = Nd4j.create((DataType)DataType.INT, (long[])new long[]{this.N});
        INDArray symRowP = Nd4j.zeros((DataType)DataType.INT, (long[])new long[]{this.N + 1});
        INDArray symColP = Nd4j.create((DataType)DataType.INT, (long[])new long[]{numElements});
        INDArray symValP = Nd4j.create((DataType)valP.dataType(), (long[])new long[]{numElements});
        // Prefix sum of the counts gives the row start offsets.
        for (n = 0; n < this.N; ++n) {
            symRowP.putScalar((long)(n + 1), symRowP.getInt(new int[]{n}) + rowCounts.getInt(new int[]{n}));
        }
        // Pass 2: write the column indices and values.
        for (n = 0; n < this.N; ++n) {
            for (int i = rowP.getInt(new int[]{n}); i < rowP.getInt(new int[]{n + 1}); ++i) {
                int colPI;
                boolean present = false;
                for (int m = rowP.getInt(new int[]{colP.getInt(new int[]{i})}); m < rowP.getInt(new int[]{colP.getInt(new int[]{i}) + 1}); ++m) {
                    if (colP.getInt(new int[]{m}) != n) continue;
                    present = true;
                    // A mutual pair is written only from the side with the smaller (or equal) row index.
                    if (n > colP.getInt(new int[]{i})) continue;
                    symColP.putScalar((long)(symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n})), colP.getInt(new int[]{i}));
                    symColP.putScalar((long)(symRowP.getInt(new int[]{colP.getInt(new int[]{i})}) + offset.getInt(new int[]{colP.getInt(new int[]{i})})), n);
                    symValP.putScalar((long)(symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n})), valP.getDouble((long)i) + valP.getDouble((long)m));
                    symValP.putScalar((long)(symRowP.getInt(new int[]{colP.getInt(new int[]{i})}) + offset.getInt(new int[]{colP.getInt(new int[]{i})})), valP.getDouble((long)i) + valP.getDouble((long)m));
                }
                if (!present) {
                    // One-directional entry: store the single value on both sides.
                    colPI = colP.getInt(new int[]{i});
                    symColP.putScalar((long)(symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n})), colPI);
                    symColP.putScalar((long)(symRowP.getInt(new int[]{colP.getInt(new int[]{i})}) + offset.getInt(new int[]{colPI})), n);
                    symValP.putScalar((long)(symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n})), valP.getDouble((long)i));
                    symValP.putScalar((long)(symRowP.getInt(new int[]{colPI}) + offset.getInt(new int[]{colPI})), valP.getDouble((long)i));
                }
                // Equivalent to "present && n > colP[i]": that pair was (or will be) written from the other side,
                // so the write positions must not advance here.
                if (present && (!present || n > colP.getInt(new int[]{i}))) continue;
                offset.putScalar((long)n, offset.getInt(new int[]{n}) + 1);
                colPI = colP.getInt(new int[]{i});
                // A diagonal entry occupies a single slot, so do not advance the same row twice.
                if (colPI == n) continue;
                offset.putScalar((long)colPI, offset.getInt(new int[]{colPI}) + 1);
            }
        }
        // Sums (or single values) are halved in place.
        symValP.divi((Number)2.0);
        return new SymResult(symRowP, symColP, symValP);
    }

    /**
     * Evaluates the un-normalised kernel {@code exp(-beta * distance)} for the
     * {@code k} distances at positions {@code 1..k} of {@code distances}
     * (position 0 is skipped) together with its entropy term.
     *
     * @param distances distances with at least {@code k + 1} entries
     * @param beta      kernel precision
     * @param k         number of kernel values to produce
     * @return pair of (kernel values of length {@code k}, entropy
     *         {@code sum(beta * d * p) / sum(p) + log(sum(p))})
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray distances, double beta, int k) {
        INDArray kernel = Nd4j.create((DataType)distances.dataType(), (long[])new long[]{k});
        double weightedDistance = 0.0;
        for (int idx = 0; idx < k; ++idx) {
            double dist = distances.getDouble((long)(idx + 1));
            kernel.putScalar((long)idx, Math.exp(-beta * dist));
            // Read the value back from the array so that any rounding to the
            // array's data type is reflected in the entropy, as in a two-pass version.
            weightedDistance += beta * (dist * kernel.getDouble((long)idx));
        }
        // Double.MIN_VALUE keeps the division and the logarithm defined for an all-zero kernel.
        double total = kernel.sumNumber().doubleValue() + Double.MIN_VALUE;
        double entropy = weightedDistance / total + Math.log(total);
        return new Pair<INDArray, Double>(kernel, Double.valueOf(entropy));
    }

    /** No-op. */
    public void init() {
    }

    /** No-op: the supplied listeners are ignored (see {@code setTrainingListener} for the listener that {@code fit()} notifies). */
    public void setListeners(Collection<TrainingListener> listeners) {
    }

    /** No-op: the supplied listeners are ignored. */
    public void setListeners(TrainingListener ... listeners) {
    }

    /**
     * Counts the entries the symmetrised version of the current sparse structure
     * ({@code rows} / {@code cols}) would contain: every stored entry counts once
     * for its own row, plus once for the mirrored row when the mirrored entry is
     * not stored.
     *
     * @return total number of entries over all rows
     */
    private int calculateOutputLength() {
        INDArray counts = Nd4j.create((int)this.N);
        for (int row = 0; row < this.N; ++row) {
            int first = this.rows.getInt(new int[]{row});
            int last = this.rows.getInt(new int[]{row + 1});
            for (int entry = first; entry < last; ++entry) {
                int col = this.cols.getInt(new int[]{entry});

                // Is the mirrored entry (col, row) stored as well?
                boolean mirrored = false;
                int mirrorEnd = this.rows.getInt(new int[]{col + 1});
                for (int m = this.rows.getInt(new int[]{col}); m < mirrorEnd; ++m) {
                    if (this.cols.getInt(new int[]{m}) == row) {
                        mirrored = true;
                    }
                }

                counts.putScalar((long)row, counts.getDouble((long)row) + 1.0);
                if (!mirrored) {
                    counts.putScalar((long)col, counts.getDouble((long)col) + 1.0);
                }
            }
        }
        return counts.sum(new int[]{Integer.MAX_VALUE}).getInt(new int[]{0});
    }

    /**
     * Subtracts, in place, the mean taken along dimension 0 from every row of
     * {@code input}.
     *
     * @param input matrix to centre; modified in place
     */
    public static void zeroMean(INDArray input) {
        input.subiRowVector(input.mean(new int[]{0}));
    }

    /**
     * Runs the embedding on the current input {@code x}.
     *
     * <p>With {@code theta == 0} the work is delegated to {@code Tsne}. Otherwise:
     * Y is initialised if missing, {@code x} is divided in place by its maximum,
     * the sparse similarities are computed, symmetrised and normalised to sum 1,
     * then scaled by 12 until iteration {@code stopLyingIteration}. Each of the
     * {@code maxIter} iterations performs one {@code step}, re-centres Y and
     * notifies the training listener if one is set.
     */
    public void fit() {
        if (this.theta == 0.0) {
            log.debug("theta == 0, using decomposed version, might be slow");
            Tsne exact = new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.minGain, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity);
            this.Y = exact.calculate(this.x, this.numDimensions, this.perplexity);
            return;
        }

        if (this.Y == null) {
            this.Y = this.initializer.initData();
        }
        // In-place scaling: the caller's array is modified.
        this.x.divi(this.x.maxNumber());
        this.computeGaussianPerplexity(this.x, this.perplexity);

        SymResult symmetric = this.symmetrized(this.rows, this.cols, this.vals);
        this.vals = symmetric.vals.divi((Number)symmetric.vals.sumNumber().doubleValue());
        this.rows = symmetric.rows;
        this.cols = symmetric.cols;

        // Similarities are inflated by 12 for the first stopLyingIteration iterations.
        this.vals.muli((Number)12);
        for (int iteration = 0; iteration < this.maxIter; ++iteration) {
            this.step(this.vals, iteration);
            BarnesHutTsne.zeroMean(this.Y);
            if (iteration == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (iteration == this.stopLyingIteration) {
                this.vals.divi((Number)12);
            }
            if (this.trainingListener != null) {
                this.trainingListener.iterationDone((Model)this, iteration, 0);
            }
        }
    }

    /** No-op: the gradient object is ignored. The actual update is {@code update(INDArray, String)}. */
    public void update(Gradient gradient) {
    }

    /**
     * Performs one optimisation step: computes the current gradient and applies
     * it through {@code update(INDArray, String)}.
     *
     * @param p unused
     * @param i unused
     */
    public void step(INDArray p, int i) {
        Gradient current = this.gradient();
        INDArray yGradient = current.getGradientFor(Y_GRAD);
        this.update(yGradient, Y_GRAD);
    }

    /**
     * Sign function used for the gain update.
     *
     * @return {@code 0.0} when {@code x == 0} (including {@code -0.0}),
     *         {@code -1.0} when {@code x < 0}, and {@code 1.0} in every other
     *         case (which includes NaN)
     */
    static double sign_tsne(double x) {
        if (x == 0.0) {
            return 0.0;
        }
        if (x < 0.0) {
            return -1.0;
        }
        return 1.0;
    }

    /**
     * Applies one gradient step to the embedding.
     *
     * <p>Order of operations:
     * <ol>
     *   <li>per-element gains are adapted: multiplied by 0.8 where the gradient
     *       and the previous increment have the same sign, increased by 0.2
     *       otherwise, then clamped from below to {@code minGain};</li>
     *   <li>the <em>previous</em> increments {@code yIncs} are added to Y;</li>
     *   <li>the new increments become {@code momentum * yIncs - change}, where
     *       {@code change} is {@code gains * gradient} passed through AdaGrad or
     *       multiplied by {@code learningRate}.</li>
     * </ol>
     *
     * <p>{@code yIncs} is created lazily here, in the same way as {@code gains},
     * so calling this method before {@code gradient()} no longer fails with a
     * {@code NullPointerException}.
     *
     * @param gradient  gradient with respect to Y, same shape as Y
     * @param paramType unused
     */
    public void update(INDArray gradient, String paramType) {
        INDArray yGrads = gradient;
        if (this.yIncs == null) {
            // Same initialisation as gradient() performs.
            this.yIncs = this.Y.like();
        }
        if (this.gains == null) {
            this.gains = this.Y.ulike().assign((Number)1.0);
        }
        for (int i = 0; i < yGrads.rows(); ++i) {
            for (int j = 0; j < yGrads.columns(); ++j) {
                double gain = this.gains.getDouble((long)i, (long)j);
                boolean sameSign = BarnesHutTsne.sign_tsne(yGrads.getDouble((long)i, (long)j)) == BarnesHutTsne.sign_tsne(this.yIncs.getDouble((long)i, (long)j));
                this.gains.putScalar(new int[]{i, j}, sameSign ? gain * 0.8 : gain + 0.2);
            }
        }
        BooleanIndexing.replaceWhere((INDArray)this.gains, (Number)this.minGain, (Condition)Conditions.lessThan((Number)this.minGain));
        // Y moves by the increments computed on the previous call.
        this.Y.addi(this.yIncs);
        INDArray gradChange = this.gains.mul(yGrads);
        if (this.useAdaGrad) {
            if (this.adaGrad == null) {
                this.adaGrad = new AdaGrad(gradient.shape(), this.learningRate);
                this.adaGrad.setStateViewArray(Nd4j.zeros((long[])gradient.shape()).reshape(1L, gradChange.length()), gradChange.shape(), gradient.ordering(), true);
            }
            gradChange = this.adaGrad.getGradient(gradChange, 0);
        } else {
            gradChange.muli((Number)this.learningRate);
        }
        this.yIncs.muli((Number)this.momentum).subi(gradChange);
    }

    /**
     * Writes the embedding as CSV, one row of Y per line, with the row's label
     * appended as the last column: {@code v0,v1,...,label}.
     *
     * <p>Only the first {@code min(Y.rows(), labels.size())} rows are considered
     * and rows whose label is {@code null} are skipped. The file is written as
     * UTF-8, so labels outside ASCII come out the same on every platform
     * instead of depending on the JVM's default charset.
     *
     * @param labels labels, index-aligned with the rows of Y
     * @param path   destination file; created or overwritten
     * @throws IOException if the file cannot be opened or written
     */
    public void saveAsFile(List<String> labels, String path) throws IOException {
        try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(new File(path)), StandardCharsets.UTF_8))) {
            for (int i = 0; i < this.Y.rows() && i < labels.size(); ++i) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = this.Y.getRow((long)i);
                long width = wordVector.length();
                for (long j = 0; j < width; ++j) {
                    sb.append(wordVector.getDouble(j));
                    if (j < width - 1L) {
                        sb.append(",");
                    }
                }
                sb.append(",");
                sb.append(word);
                sb.append("\n");
                write.write(sb.toString());
            }
            write.flush();
        }
    }

    /**
     * Writes the embedding as CSV: one line per row of Y, values separated by
     * commas, each line terminated by {@code "\n"}.
     *
     * @param path destination file; created or overwritten
     * @throws IOException if the file cannot be opened or written
     */
    public void saveAsFile(String path) throws IOException {
        try (BufferedWriter out = new BufferedWriter(new FileWriter(new File(path)))) {
            for (int row = 0; row < this.Y.rows(); ++row) {
                INDArray vector = this.Y.getRow((long)row);
                StringBuilder line = new StringBuilder();
                for (int col = 0; (long)col < vector.length(); ++col) {
                    line.append(vector.getDouble((long)col));
                    boolean lastColumn = (long)col >= vector.length() - 1L;
                    if (!lastColumn) {
                        line.append(",");
                    }
                }
                line.append("\n");
                out.write(line.toString());
            }
            out.flush();
        }
    }

    /**
     * Fits the model on {@code matrix} with {@code nDims} output dimensions and
     * then writes the labelled embedding to {@code path}.
     *
     * @deprecated the method carries {@code @Deprecated}; calling
     *             {@code fit(...)} followed by {@code saveAsFile(labels, path)}
     *             does the same thing.
     * @throws IOException if writing the file fails
     */
    @Deprecated
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        this.fit(matrix, nDims);
        this.saveAsFile(labels, path);
    }

    /**
     * Computes the cost of the current embedding as the Kullback-Leibler
     * divergence {@code sum_i P_i * log((P_i + eps) / (Q_i + eps))} over the
     * stored sparse similarities.
     *
     * <p>{@code Q_i} is {@code 1 / (1 + ||y_a - y_b||^2)} divided by the
     * normalisation term accumulated from {@code tree.computeNonEdgeForces}.
     * The previous code evaluated {@code P * log(P + eps) / (Q + eps)}, i.e. the
     * division sat outside the logarithm, which is not the divergence.
     *
     * <p>Requires {@code tree}, {@code rows}, {@code cols}, {@code vals} and
     * {@code Y} to be populated (as they are after {@code gradient()} has run
     * during {@code fit()}).
     *
     * @return the divergence value
     */
    public double score() {
        INDArray buff = Nd4j.create((int)this.numDimensions);
        AtomicDouble sum_Q = new AtomicDouble(0.0);
        // Only the accumulated normalisation term is needed; the forces written to buff are discarded.
        for (int n = 0; n < this.N; ++n) {
            this.tree.computeNonEdgeForces(n, this.theta, buff, sum_Q);
        }
        double normaliser = sum_Q.doubleValue();
        double eps = Nd4j.EPS_THRESHOLD;
        double C = 0.0;
        INDArray linear = this.Y;
        for (int n = 0; n < this.N; ++n) {
            int begin = this.rows.getInt(new int[]{n});
            int end = this.rows.getInt(new int[]{n + 1});
            for (int i = begin; i < end; ++i) {
                int neighbour = this.cols.getInt(new int[]{i});
                // buff = y_n - y_neighbour
                linear.slice((long)n).subi(linear.slice((long)neighbour), buff);
                double Q = Transforms.pow((INDArray)buff, (Number)2).sumNumber().doubleValue();
                Q = 1.0 / (1.0 + Q) / normaliser;
                double P = this.vals.getDouble((long)i);
                C += P * Math.log((P + eps) / (Q + eps));
            }
        }
        return C;
    }

    /** No-op. */
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
    }

    /** Always returns {@code null}. */
    public INDArray params() {
        return null;
    }

    /** Always returns 0. */
    public long numParams() {
        return 0L;
    }

    /** Always returns 0; the flag is ignored. */
    public long numParams(boolean backwards) {
        return 0L;
    }

    /** No-op. */
    public void setParams(INDArray params) {
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void setParamsViewArray(INDArray params) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray getGradientsViewArray() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void setBackpropGradientsViewArray(INDArray gradients) {
        throw new UnsupportedOperationException();
    }

    /**
     * Stores {@code data} as the input and runs {@code fit()}.
     * The array is kept by reference; {@code fit()} rescales it in place when {@code theta != 0}.
     */
    public void fit(INDArray data) {
        this.x = data;
        this.fit();
    }

    /** Same as {@code fit(data)}; the workspace manager is ignored. */
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        this.fit(data);
    }

    /**
     * Stores {@code data} as the input, sets the number of output dimensions to
     * {@code nDims} and runs {@code fit()}.
     *
     * @deprecated the method carries {@code @Deprecated}; {@code setNumDimensions}
     *             plus {@code fit(INDArray)} has the same effect.
     */
    @Deprecated
    public void fit(INDArray data, int nDims) {
        this.x = data;
        this.numDimensions = nDims;
        this.fit();
    }

    /**
     * Computes the gradient with respect to Y.
     *
     * <p>Lazily creates {@code yIncs} and {@code gains}, rebuilds the
     * {@code SpTree} over the current Y, then combines the edge forces with the
     * non-edge forces divided by their accumulated normalisation term:
     * {@code edgeForces - nonEdgeForces / sumQ}.
     *
     * @return a gradient holding the result under the key {@link #Y_GRAD}
     */
    public Gradient gradient() {
        if (this.yIncs == null) {
            this.yIncs = this.Y.like();
        }
        if (this.gains == null) {
            this.gains = this.Y.ulike().assign((Number)1.0);
        }

        AtomicDouble normaliser = new AtomicDouble(0.0);
        INDArray edgeForces = this.Y.like();
        INDArray nonEdgeForces = this.Y.like();

        this.tree = new SpTree(this.Y);
        this.tree.computeEdgeForces(this.rows, this.cols, this.vals, this.N, edgeForces);
        for (int row = 0; row < this.N; ++row) {
            // Each row's force is written into the matching slice of nonEdgeForces.
            this.tree.computeNonEdgeForces(row, this.theta, nonEdgeForces.slice((long)row), normaliser);
        }

        INDArray scaledNonEdge = nonEdgeForces.divi((Number)normaliser);
        INDArray yGradient = edgeForces.subi(scaledNonEdge);

        DefaultGradient result = new DefaultGradient();
        result.gradientForVariable().put(Y_GRAD, yGradient);
        return result;
    }

    /**
     * Computes the gradient first and the score second, and returns both.
     *
     * @return pair of ({@code gradient()}, {@code score()})
     */
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = this.gradient();
        Double currentScore = Double.valueOf(this.score());
        return new Pair<Gradient, Double>(currentGradient, currentScore);
    }

    /** Always returns 0. */
    public int batchSize() {
        return 0;
    }

    /** Always returns {@code null}. */
    public NeuralNetConfiguration conf() {
        return null;
    }

    /** No-op. */
    public void setConf(NeuralNetConfiguration conf) {
    }

    /** Returns the embedding Y (same array as {@code getY()}). */
    public INDArray getData() {
        return this.Y;
    }

    /** Replaces the embedding Y (same effect as {@code setY}). */
    public void setData(INDArray data) {
        this.Y = data;
    }

    /** Sets the number of rows N used by {@code symmetrized}, {@code score} and {@code gradient}. */
    public void setN(int N) {
        this.N = N;
    }

    /** No-op. */
    public void close() {
    }

    /** Returns the number of iterations {@code fit()} runs. */
    public int getMaxIter() {
        return this.maxIter;
    }

    /** Returns {@code realMin}. */
    public double getRealMin() {
        return this.realMin;
    }

    /** Returns the initial momentum. */
    public double getInitialMomentum() {
        return this.initialMomentum;
    }

    /** Returns the momentum used after the switch iteration. */
    public double getFinalMomentum() {
        return this.finalMomentum;
    }

    /** Returns the lower bound applied to the gains. */
    public double getMinGain() {
        return this.minGain;
    }

    /** Returns the momentum currently in effect. */
    public double getMomentum() {
        return this.momentum;
    }

    /** Returns the iteration at which the momentum is switched to the final momentum. */
    public int getSwitchMomentumIteration() {
        return this.switchMomentumIteration;
    }

    /** Returns the {@code normalize} flag. */
    public boolean isNormalize() {
        return this.normalize;
    }

    /** Returns the {@code usePca} flag. */
    public boolean isUsePca() {
        return this.usePca;
    }

    /** Returns the iteration at which the x12 scaling of the similarities is removed. */
    public int getStopLyingIteration() {
        return this.stopLyingIteration;
    }

    /** Returns the entropy tolerance of the beta search. */
    public double getTolerance() {
        return this.tolerance;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** Returns the AdaGrad instance; {@code null} until it has been created or set. */
    public AdaGrad getAdaGrad() {
        return this.adaGrad;
    }

    /** Returns whether AdaGrad is used for the updates. */
    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    /** Returns the embedding Y. */
    public INDArray getY() {
        return this.Y;
    }

    /** Returns the number of rows N. */
    public int getN() {
        return this.N;
    }

    /** Returns the row-offset array of the sparse similarities. */
    public INDArray getRows() {
        return this.rows;
    }

    /** Returns the column-index array of the sparse similarities. */
    public INDArray getCols() {
        return this.cols;
    }

    /** Returns the value array of the sparse similarities. */
    public INDArray getVals() {
        return this.vals;
    }

    /** Returns the input data {@code x}. */
    public INDArray getX() {
        return this.x;
    }

    /** Returns the current SpTree; {@code null} until {@code gradient()} has run or one has been set. */
    public SpTree getTree() {
        return this.tree;
    }

    /** Returns the per-element gains. */
    public INDArray getGains() {
        return this.gains;
    }

    /** Returns the accumulated increments {@code yIncs}. */
    public INDArray getYIncs() {
        return this.yIncs;
    }

    /** Returns the worker count passed to the VP-tree. */
    public int getVpTreeWorkers() {
        return this.vpTreeWorkers;
    }

    /** Returns the listener notified by {@code fit()}; may be {@code null}. */
    public TrainingListener getTrainingListener() {
        return this.trainingListener;
    }

    /** Returns the workspace mode. */
    public WorkspaceMode getWorkspaceMode() {
        return this.workspaceMode;
    }

    /** Returns the initializer that supplies the initial Y. */
    public Initializer getInitializer() {
        return this.initializer;
    }

    /** Returns the feed-forward workspace configuration. */
    public WorkspaceConfiguration getWorkspaceConfigurationFeedForward() {
        return this.workspaceConfigurationFeedForward;
    }

    /** Sets the number of iterations {@code fit()} runs. */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /** Sets {@code realMin}. */
    public void setRealMin(double realMin) {
        this.realMin = realMin;
    }

    /** Sets the initial momentum (does not change the momentum currently in effect). */
    public void setInitialMomentum(double initialMomentum) {
        this.initialMomentum = initialMomentum;
    }

    /** Sets the momentum used after the switch iteration. */
    public void setFinalMomentum(double finalMomentum) {
        this.finalMomentum = finalMomentum;
    }

    /** Sets the lower bound applied to the gains. */
    public void setMinGain(double minGain) {
        this.minGain = minGain;
    }

    /** Sets the momentum currently in effect. */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** Sets the iteration at which the momentum is switched to the final momentum. */
    public void setSwitchMomentumIteration(int switchMomentumIteration) {
        this.switchMomentumIteration = switchMomentumIteration;
    }

    /** Sets the {@code normalize} flag. */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Sets the {@code usePca} flag. */
    public void setUsePca(boolean usePca) {
        this.usePca = usePca;
    }

    /** Sets the iteration at which the x12 scaling of the similarities is removed. */
    public void setStopLyingIteration(int stopLyingIteration) {
        this.stopLyingIteration = stopLyingIteration;
    }

    /** Sets the entropy tolerance of the beta search. */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /** Sets the learning rate. An AdaGrad instance that already exists is not updated. */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Replaces the AdaGrad instance. */
    public void setAdaGrad(AdaGrad adaGrad) {
        this.adaGrad = adaGrad;
    }

    /** Sets whether AdaGrad is used for the updates. */
    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    /** Sets the perplexity. */
    public void setPerplexity(double perplexity) {
        this.perplexity = perplexity;
    }

    /** Replaces the embedding Y. */
    public void setY(INDArray Y) {
        this.Y = Y;
    }

    /** Sets theta; a value of exactly 0 makes {@code fit()} use the {@code Tsne} implementation. */
    public void setTheta(double theta) {
        this.theta = theta;
    }

    /** Replaces the row-offset array of the sparse similarities. */
    public void setRows(INDArray rows) {
        this.rows = rows;
    }

    /** Replaces the column-index array of the sparse similarities. */
    public void setCols(INDArray cols) {
        this.cols = cols;
    }

    /** Replaces the value array of the sparse similarities. */
    public void setVals(INDArray vals) {
        this.vals = vals;
    }

    /** Replaces the input data {@code x}. */
    public void setX(INDArray x) {
        this.x = x;
    }

    /** Replaces the SpTree ({@code gradient()} overwrites it on its next call). */
    public void setTree(SpTree tree) {
        this.tree = tree;
    }

    /** Replaces the per-element gains. */
    public void setGains(INDArray gains) {
        this.gains = gains;
    }

    /** Sets the worker count passed to the VP-tree. */
    public void setVpTreeWorkers(int vpTreeWorkers) {
        this.vpTreeWorkers = vpTreeWorkers;
    }

    /** Sets the listener notified by {@code fit()}; {@code null} disables notification. */
    public void setTrainingListener(TrainingListener trainingListener) {
        this.trainingListener = trainingListener;
    }

    /** Sets the workspace mode. Unlike the constructor, no {@code null} substitution is performed. */
    public void setWorkspaceMode(WorkspaceMode workspaceMode) {
        this.workspaceMode = workspaceMode;
    }

    /** Replaces the initializer that supplies the initial Y. */
    public void setInitializer(Initializer initializer) {
        this.initializer = initializer;
    }

    /** Replaces the feed-forward workspace configuration. */
    public void setWorkspaceConfigurationFeedForward(WorkspaceConfiguration workspaceConfigurationFeedForward) {
        this.workspaceConfigurationFeedForward = workspaceConfigurationFeedForward;
    }

    /**
     * Compares this instance with another object property by property.
     *
     * <p>Two instances are equal when {@code o} is a {@code BarnesHutTsne} that accepts this
     * object through {@link #canEqual(Object)} and every compared getter yields an equal value:
     * {@code int}/{@code boolean} properties via {@code ==}, {@code double} properties via
     * {@link Double#compare(double, double)}, and reference properties via a null-safe
     * {@code equals} call. The training listener is not part of the comparison.
     *
     * <p>The checks run in a fixed order and stop at the first mismatch.
     *
     * @param o the object to compare with; may be {@code null}
     * @return {@code true} when all compared properties match, {@code false} otherwise
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof BarnesHutTsne)) {
            return false;
        }
        BarnesHutTsne that = (BarnesHutTsne) o;
        // One short-circuit chain: evaluation stops at the first property that differs.
        return that.canEqual(this)
                && this.getMaxIter() == that.getMaxIter()
                && Double.compare(this.getRealMin(), that.getRealMin()) == 0
                && Double.compare(this.getInitialMomentum(), that.getInitialMomentum()) == 0
                && Double.compare(this.getFinalMomentum(), that.getFinalMomentum()) == 0
                && Double.compare(this.getMinGain(), that.getMinGain()) == 0
                && Double.compare(this.getMomentum(), that.getMomentum()) == 0
                && this.getSwitchMomentumIteration() == that.getSwitchMomentumIteration()
                && this.isNormalize() == that.isNormalize()
                && this.isUsePca() == that.isUsePca()
                && this.getStopLyingIteration() == that.getStopLyingIteration()
                && Double.compare(this.getTolerance(), that.getTolerance()) == 0
                && Double.compare(this.getLearningRate(), that.getLearningRate()) == 0
                && sameOrBothNull(this.getAdaGrad(), that.getAdaGrad())
                && this.isUseAdaGrad() == that.isUseAdaGrad()
                && Double.compare(this.getPerplexity(), that.getPerplexity()) == 0
                && sameOrBothNull(this.getY(), that.getY())
                && this.getN() == that.getN()
                && Double.compare(this.getTheta(), that.getTheta()) == 0
                && sameOrBothNull(this.getRows(), that.getRows())
                && sameOrBothNull(this.getCols(), that.getCols())
                && sameOrBothNull(this.getVals(), that.getVals())
                && sameOrBothNull(this.getSimiarlityFunction(), that.getSimiarlityFunction())
                && this.isInvert() == that.isInvert()
                && sameOrBothNull(this.getX(), that.getX())
                && this.getNumDimensions() == that.getNumDimensions()
                && sameOrBothNull(this.getTree(), that.getTree())
                && sameOrBothNull(this.getGains(), that.getGains())
                && sameOrBothNull(this.getYIncs(), that.getYIncs())
                && this.getVpTreeWorkers() == that.getVpTreeWorkers()
                && sameOrBothNull(this.getWorkspaceMode(), that.getWorkspaceMode())
                && sameOrBothNull(this.getInitializer(), that.getInitializer())
                && sameOrBothNull(this.getWorkspaceConfigurationFeedForward(), that.getWorkspaceConfigurationFeedForward());
    }

    /**
     * Null-safe equality used by {@link #equals(Object)}.
     *
     * <p>Deliberately has no {@code mine == theirs} shortcut: when {@code mine} is non-null its
     * {@code equals} method is always invoked, even if both arguments are the same reference.
     *
     * @param mine   value taken from this instance; may be {@code null}
     * @param theirs value taken from the other instance; may be {@code null}
     * @return {@code true} if both are {@code null}, or {@code mine.equals(theirs)} holds
     */
    private static boolean sameOrBothNull(Object mine, Object theirs) {
        if (mine == null) {
            return theirs == null;
        }
        return mine.equals(theirs);
    }

    /**
     * Type check used by {@code equals(Object)}: the other side of a comparison calls this
     * method on its counterpart before comparing any properties.
     *
     * @param other the object asking whether it may be compared with this one
     * @return {@code true} if {@code other} is an instance of {@code BarnesHutTsne}
     */
    protected boolean canEqual(Object other) {
        return other instanceof BarnesHutTsne;
    }

    /**
     * Computes a hash code from the same properties that {@code equals(Object)} compares,
     * visited in the same order.
     *
     * <p>The running value starts at 1 and is multiplied by 59 before each property is added.
     * {@code int} properties are added directly, {@code boolean} properties contribute 79
     * ({@code true}) or 97 ({@code false}), {@code double} properties contribute
     * {@link Double#hashCode(double)}, and reference properties contribute their own
     * {@code hashCode()} or 43 when {@code null}. The training listener is not included.
     *
     * @return the combined hash code
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        final int trueHash = 79;
        final int falseHash = 97;

        int hash = 1;
        hash = prime * hash + this.getMaxIter();
        // Double.hashCode folds the IEEE-754 bit pattern: (int) (bits ^ (bits >>> 32)).
        hash = prime * hash + Double.hashCode(this.getRealMin());
        hash = prime * hash + Double.hashCode(this.getInitialMomentum());
        hash = prime * hash + Double.hashCode(this.getFinalMomentum());
        hash = prime * hash + Double.hashCode(this.getMinGain());
        hash = prime * hash + Double.hashCode(this.getMomentum());
        hash = prime * hash + this.getSwitchMomentumIteration();
        hash = prime * hash + (this.isNormalize() ? trueHash : falseHash);
        hash = prime * hash + (this.isUsePca() ? trueHash : falseHash);
        hash = prime * hash + this.getStopLyingIteration();
        hash = prime * hash + Double.hashCode(this.getTolerance());
        hash = prime * hash + Double.hashCode(this.getLearningRate());

        AdaGrad adaGradValue = this.getAdaGrad();
        hash = prime * hash + (adaGradValue == null ? nullHash : adaGradValue.hashCode());
        hash = prime * hash + (this.isUseAdaGrad() ? trueHash : falseHash);
        hash = prime * hash + Double.hashCode(this.getPerplexity());

        INDArray yValue = this.getY();
        hash = prime * hash + (yValue == null ? nullHash : yValue.hashCode());
        hash = prime * hash + this.getN();
        hash = prime * hash + Double.hashCode(this.getTheta());

        INDArray rowsValue = this.getRows();
        hash = prime * hash + (rowsValue == null ? nullHash : rowsValue.hashCode());
        INDArray colsValue = this.getCols();
        hash = prime * hash + (colsValue == null ? nullHash : colsValue.hashCode());
        INDArray valsValue = this.getVals();
        hash = prime * hash + (valsValue == null ? nullHash : valsValue.hashCode());

        String similarityValue = this.getSimiarlityFunction();
        hash = prime * hash + (similarityValue == null ? nullHash : similarityValue.hashCode());
        hash = prime * hash + (this.isInvert() ? trueHash : falseHash);

        INDArray xValue = this.getX();
        hash = prime * hash + (xValue == null ? nullHash : xValue.hashCode());
        hash = prime * hash + this.getNumDimensions();

        SpTree treeValue = this.getTree();
        hash = prime * hash + (treeValue == null ? nullHash : treeValue.hashCode());
        INDArray gainsValue = this.getGains();
        hash = prime * hash + (gainsValue == null ? nullHash : gainsValue.hashCode());
        INDArray yIncsValue = this.getYIncs();
        hash = prime * hash + (yIncsValue == null ? nullHash : yIncsValue.hashCode());
        hash = prime * hash + this.getVpTreeWorkers();

        WorkspaceMode modeValue = this.getWorkspaceMode();
        hash = prime * hash + (modeValue == null ? nullHash : modeValue.hashCode());
        Initializer initializerValue = this.getInitializer();
        hash = prime * hash + (initializerValue == null ? nullHash : initializerValue.hashCode());
        WorkspaceConfiguration feedForwardValue = this.getWorkspaceConfigurationFeedForward();
        hash = prime * hash + (feedForwardValue == null ? nullHash : feedForwardValue.hashCode());
        return hash;
    }

    /**
     * Renders this instance as {@code BarnesHutTsne(name=value, ...)}, listing each property
     * obtained through its getter. Unlike {@code equals(Object)} and {@code hashCode()}, the
     * output also includes the training listener.
     *
     * @return the textual form of this instance
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("BarnesHutTsne(maxIter=");
        text.append(this.getMaxIter());
        text.append(", realMin=").append(this.getRealMin());
        text.append(", initialMomentum=").append(this.getInitialMomentum());
        text.append(", finalMomentum=").append(this.getFinalMomentum());
        text.append(", minGain=").append(this.getMinGain());
        text.append(", momentum=").append(this.getMomentum());
        text.append(", switchMomentumIteration=").append(this.getSwitchMomentumIteration());
        text.append(", normalize=").append(this.isNormalize());
        text.append(", usePca=").append(this.isUsePca());
        text.append(", stopLyingIteration=").append(this.getStopLyingIteration());
        text.append(", tolerance=").append(this.getTolerance());
        text.append(", learningRate=").append(this.getLearningRate());
        text.append(", adaGrad=").append(this.getAdaGrad());
        text.append(", useAdaGrad=").append(this.isUseAdaGrad());
        text.append(", perplexity=").append(this.getPerplexity());
        text.append(", Y=").append(this.getY());
        text.append(", N=").append(this.getN());
        text.append(", theta=").append(this.getTheta());
        text.append(", rows=").append(this.getRows());
        text.append(", cols=").append(this.getCols());
        text.append(", vals=").append(this.getVals());
        text.append(", simiarlityFunction=").append(this.getSimiarlityFunction());
        text.append(", invert=").append(this.isInvert());
        text.append(", x=").append(this.getX());
        text.append(", numDimensions=").append(this.getNumDimensions());
        text.append(", tree=").append(this.getTree());
        text.append(", gains=").append(this.getGains());
        text.append(", yIncs=").append(this.getYIncs());
        text.append(", vpTreeWorkers=").append(this.getVpTreeWorkers());
        text.append(", trainingListener=").append(this.getTrainingListener());
        text.append(", workspaceMode=").append(this.getWorkspaceMode());
        text.append(", initializer=").append(this.getInitializer());
        text.append(", workspaceConfigurationFeedForward=").append(this.getWorkspaceConfigurationFeedForward());
        text.append(")");
        return text.toString();
    }

    /**
     * Replaces the {@code yIncs} array stored on this instance.
     *
     * @param yIncs the new array; the reference is stored directly (no copy), {@code null} is accepted
     */
    public void setYIncs(INDArray yIncs) {
        this.yIncs = yIncs;
    }

    /**
     * Fluent builder for {@link BarnesHutTsne}.
     *
     * <p>Every option has a default (see the field initializers below). Each configuration
     * method stores its argument without validation and returns this builder so calls can be
     * chained; {@link #build()} forwards the collected values to the {@code BarnesHutTsne}
     * constructor.
     */
    public static class Builder {
        // Iteration control.
        private int maxIter = 1000;
        private int switchMomentumIteration = 100;
        private int stopLyingIteration = 100;

        // Numeric settings. Several defaults are float literals widened to double,
        // so the stored value is the float approximation rather than the exact decimal.
        private double realMin = 1.0E-12f;
        private double initialMomentum = 0.5;
        private double finalMomentum = 0.8f;
        private double momentum = 0.5;
        private double tolerance = 1.0E-5f;
        private double learningRate = 0.1f;
        private double perplexity = 30.0;
        private double minGain = 0.01f;
        private double theta = 0.5;

        // Switches.
        private boolean normalize = true;
        private boolean useAdaGrad;
        private boolean invert = true;

        // Remaining constructor arguments.
        private int numDim = 2;
        private String similarityFunction = Distance.EUCLIDEAN.toString();
        private int vpTreeWorkers = 1;
        protected WorkspaceMode workspaceMode = WorkspaceMode.NONE;
        private INDArray staticInput;

        /**
         * @param vpTreeWorkers value to pass on as {@code vpTreeWorkers} (default 1)
         * @return this builder
         */
        public Builder vpTreeWorkers(int vpTreeWorkers) {
            this.vpTreeWorkers = vpTreeWorkers;
            return this;
        }

        /**
         * @param staticInput array to pass on as the static input (default {@code null});
         *        the reference is stored without copying
         * @return this builder
         */
        public Builder staticInit(INDArray staticInput) {
            this.staticInput = staticInput;
            return this;
        }

        /**
         * @param minGain value to pass on as {@code minGain} (default {@code 0.01f})
         * @return this builder
         */
        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        /**
         * @param perplexity value to pass on as {@code perplexity} (default 30.0)
         * @return this builder
         */
        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        /**
         * @param useAdaGrad value to pass on as {@code useAdaGrad} (default {@code false})
         * @return this builder
         */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * @param learningRate value to pass on as {@code learningRate} (default {@code 0.1f})
         * @return this builder
         */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /**
         * @param tolerance value to pass on as {@code tolerance} (default {@code 1.0E-5f})
         * @return this builder
         */
        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        /**
         * @param stopLyingIteration value to pass on as {@code stopLyingIteration} (default 100)
         * @return this builder
         */
        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        /**
         * @param normalize value to pass on as {@code normalize} (default {@code true})
         * @return this builder
         */
        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        /**
         * @param maxIter value to pass on as {@code maxIter} (default 1000)
         * @return this builder
         */
        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        /**
         * @param realMin value to pass on as {@code realMin} (default {@code 1.0E-12f})
         * @return this builder
         */
        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        /**
         * @param initialMomentum value to pass on as {@code initialMomentum} (default 0.5)
         * @return this builder
         */
        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        /**
         * @param finalMomentum value to pass on as {@code finalMomentum} (default {@code 0.8f})
         * @return this builder
         */
        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        /**
         * @param momentum value to pass on as {@code momentum} (default 0.5)
         * @return this builder
         */
        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /**
         * @param switchMomentumIteration value to pass on as {@code switchMomentumIteration}
         *        (default 100)
         * @return this builder
         */
        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /**
         * @param similarityFunction name to pass on as the similarity function
         *        (default {@code Distance.EUCLIDEAN.toString()})
         * @return this builder
         */
        public Builder similarityFunction(String similarityFunction) {
            this.similarityFunction = similarityFunction;
            return this;
        }

        /**
         * @param invert value to pass on as {@code invert} (default {@code true})
         * @return this builder
         */
        public Builder invertDistanceMetric(boolean invert) {
            this.invert = invert;
            return this;
        }

        /**
         * @param theta value to pass on as {@code theta} (default 0.5)
         * @return this builder
         */
        public Builder theta(double theta) {
            this.theta = theta;
            return this;
        }

        /**
         * @param numDim value to pass on as the number of dimensions (default 2)
         * @return this builder
         */
        public Builder numDimension(int numDim) {
            this.numDim = numDim;
            return this;
        }

        /**
         * @param workspaceMode value to pass on as {@code workspaceMode}
         *        (default {@code WorkspaceMode.NONE})
         * @return this builder
         */
        public Builder workspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        /**
         * Creates a new {@link BarnesHutTsne} from the values collected so far. The builder is
         * left unchanged, so it can be reused.
         *
         * @return a newly constructed {@code BarnesHutTsne}
         */
        public BarnesHutTsne build() {
            return new BarnesHutTsne(
                    numDim,
                    similarityFunction,
                    theta,
                    invert,
                    maxIter,
                    realMin,
                    initialMomentum,
                    finalMomentum,
                    momentum,
                    switchMomentumIteration,
                    normalize,
                    stopLyingIteration,
                    tolerance,
                    learningRate,
                    useAdaGrad,
                    perplexity,
                    null, // the builder has no option for this argument
                    minGain,
                    vpTreeWorkers,
                    workspaceMode,
                    staticInput);
        }
    }

    /**
     * Supplies the starting array returned by {@link #initData()}.
     *
     * <p>This is a non-static inner class because, when no fixed array was given, it reads
     * {@code x} and {@code numDimensions} from the enclosing {@code BarnesHutTsne} instance.
     */
    public class Initializer {
        /** Optional fixed starting array; {@code null} means "generate random data". */
        private INDArray staticData;

        /** Creates an initializer without a fixed array. */
        public Initializer() {
        }

        /**
         * Creates an initializer backed by a fixed array.
         *
         * @param input the array to hand out (as a duplicate) from {@link #initData()};
         *        may be {@code null}, which behaves like the no-arg constructor
         */
        public Initializer(INDArray input) {
            this.staticData = input;
        }

        /**
         * Produces the starting array.
         *
         * @return a {@code dup()} of the fixed array when one was supplied; otherwise the result
         *         of {@code Nd4j.randn} with the data type of the enclosing instance's {@code x}
         *         and shape {@code [x.rows(), numDimensions]}, passed through
         *         {@code muli(0.001f)}
         */
        public INDArray initData() {
            if (this.staticData == null) {
                DataType dataType = BarnesHutTsne.this.x.dataType();
                long[] shape = {BarnesHutTsne.this.x.rows(), BarnesHutTsne.this.numDimensions};
                INDArray noise = Nd4j.randn(dataType, shape);
                // Typed as Number so the same muli(Number) overload is selected.
                Number scale = Float.valueOf(0.001f);
                return noise.muli(scale);
            }
            return this.staticData.dup();
        }
    }

    /**
     * Mutable holder for three arrays named {@code rows}, {@code cols} and {@code vals}, with
     * value-based {@code equals}, {@code hashCode} and {@code toString}.
     *
     * <p>NOTE(review): presumably the three arrays together describe a sparse matrix produced
     * by a symmetrization step — confirm against the code that constructs this class.
     */
    static class SymResult {
        INDArray rows;
        INDArray cols;
        INDArray vals;

        /**
         * Stores the three arrays by reference (no copies are made).
         *
         * @param rows the {@code rows} array; may be {@code null}
         * @param cols the {@code cols} array; may be {@code null}
         * @param vals the {@code vals} array; may be {@code null}
         */
        public SymResult(INDArray rows, INDArray cols, INDArray vals) {
            this.rows = rows;
            this.cols = cols;
            this.vals = vals;
        }

        /** @return the stored {@code rows} array, possibly {@code null} */
        public INDArray getRows() {
            return this.rows;
        }

        /** @param rows replacement for the stored {@code rows} array */
        public void setRows(INDArray rows) {
            this.rows = rows;
        }

        /** @return the stored {@code cols} array, possibly {@code null} */
        public INDArray getCols() {
            return this.cols;
        }

        /** @param cols replacement for the stored {@code cols} array */
        public void setCols(INDArray cols) {
            this.cols = cols;
        }

        /** @return the stored {@code vals} array, possibly {@code null} */
        public INDArray getVals() {
            return this.vals;
        }

        /** @param vals replacement for the stored {@code vals} array */
        public void setVals(INDArray vals) {
            this.vals = vals;
        }

        /**
         * Two results are equal when the other object is a {@code SymResult} that accepts this
         * one via {@link #canEqual(Object)} and {@code rows}, {@code cols} and {@code vals}
         * each match (both {@code null}, or equal according to {@code equals}).
         *
         * @param o the object to compare with; may be {@code null}
         * @return {@code true} when all three arrays match
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof SymResult)) {
                return false;
            }
            SymResult that = (SymResult) o;
            return that.canEqual(this)
                    && matches(this.getRows(), that.getRows())
                    && matches(this.getCols(), that.getCols())
                    && matches(this.getVals(), that.getVals());
        }

        /**
         * Type check invoked by the other side of an {@code equals} comparison.
         *
         * @param other the object asking whether it may be compared with this one
         * @return {@code true} if {@code other} is a {@code SymResult}
         */
        protected boolean canEqual(Object other) {
            return other instanceof SymResult;
        }

        /**
         * Combines the hash codes of {@code rows}, {@code cols} and {@code vals}, starting from
         * 1 and multiplying by 59 before each addition; a {@code null} array contributes 43.
         *
         * @return the combined hash code
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int hash = 1;
            INDArray rowsValue = this.getRows();
            hash = prime * hash + (rowsValue == null ? nullHash : rowsValue.hashCode());
            INDArray colsValue = this.getCols();
            hash = prime * hash + (colsValue == null ? nullHash : colsValue.hashCode());
            INDArray valsValue = this.getVals();
            hash = prime * hash + (valsValue == null ? nullHash : valsValue.hashCode());
            return hash;
        }

        /** @return {@code BarnesHutTsne.SymResult(rows=..., cols=..., vals=...)} */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("BarnesHutTsne.SymResult(rows=");
            text.append(this.getRows());
            text.append(", cols=").append(this.getCols());
            text.append(", vals=").append(this.getVals());
            text.append(")");
            return text.toString();
        }

        /**
         * Null-safe equality without an identity shortcut: when {@code mine} is non-null its
         * {@code equals} method is always invoked.
         */
        private static boolean matches(Object mine, Object theirs) {
            if (mine == null) {
                return theirs == null;
            }
            return mine.equals(theirs);
        }
    }
}

