/*
 * Decompiled with CFR 0.152.
 */
package tagbio.umap;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import tagbio.umap.FlatTree;
import tagbio.umap.Heap;
import tagbio.umap.MathUtils;
import tagbio.umap.Matrix;
import tagbio.umap.NearestNeighborDescent;
import tagbio.umap.UmapProgress;
import tagbio.umap.Utils;
import tagbio.umap.metric.Metric;

/**
 * Multi-threaded variant of {@link NearestNeighborDescent}.
 *
 * <p>Each phase of the descent (random seeding, optional random-projection-forest seeding, and
 * every refinement iteration) is split into contiguous index chunks that are submitted to a
 * fixed-size thread pool. A phase is fully awaited before the next one starts.
 *
 * <p>NOTE(review): worker tasks call {@code Heap.push} on a single shared {@code Heap} and draw
 * from a single shared {@link Random}. This assumes {@code Heap.push} is safe to call
 * concurrently; confirm against the {@code Heap} implementation. The order in which tasks draw
 * from the shared {@code Random} depends on thread scheduling, so results are likely not
 * reproducible from a seed.
 */
class ParallelNearestNeighborDescent
extends NearestNeighborDescent {

    /** Early-termination fraction used by the seven-argument {@code descent} overload. */
    private static final float DEFAULT_DELTA = 0.001f;

    /** Candidate rejection probability used by the seven-argument {@code descent} overload. */
    private static final float DEFAULT_RHO = 0.5f;

    /** Size of the fixed thread pool created for each call to {@code descent}. */
    private final int mThreads;

    /**
     * Creates a parallel nearest-neighbour descent.
     *
     * @param metric  distance metric handed to the superclass and used for all distance evaluations
     * @param threads number of worker threads; must be at least 1
     * @throws IllegalArgumentException if {@code threads < 1}
     */
    ParallelNearestNeighborDescent(final Metric metric, final int threads) {
        super(metric);
        if (threads < 1) {
            throw new IllegalArgumentException("threads must be at least 1: " + threads);
        }
        this.mThreads = threads;
    }

    /**
     * Runs the descent with the default {@code delta} ({@value #DEFAULT_DELTA}) and
     * {@code rho} ({@value #DEFAULT_RHO}).
     */
    @Override
    Heap descent(final Matrix data, final int nNeighbors, final Random random, final int maxCandidates,
                 final boolean rpTreeInit, final int nIters, final List<FlatTree> forest) {
        return this.descent(data, nNeighbors, random, maxCandidates, rpTreeInit, nIters, forest,
                DEFAULT_DELTA, DEFAULT_RHO);
    }

    /**
     * Builds an approximate k-nearest-neighbour graph over the rows of {@code data}.
     *
     * @param data          input rows; each row index is treated as a vertex
     * @param nNeighbors    heap width per vertex and number of random samples used for seeding
     * @param random        source of randomness, shared by all worker tasks
     * @param maxCandidates number of candidate slots examined per vertex in each iteration
     * @param rpTreeInit    if {@code true}, also seed the graph with all pairs inside each leaf of
     *                      {@code forest}
     * @param nIters        maximum number of refinement iterations
     * @param forest        trees whose leaves (from {@code getIndices()}) are arrays of row indices;
     *                      only read when {@code rpTreeInit} is {@code true}
     * @param delta         refinement stops once an iteration makes at most
     *                      {@code delta * nNeighbors * data.rows()} successful heap pushes
     * @param rho           probability that each candidate is flagged as rejected in an iteration;
     *                      pairs whose two candidates are both flagged are skipped
     * @return the neighbour graph after {@code Heap.deheapSort()} (presumably sorted by distance;
     *         confirm in {@code Heap})
     * @throws RuntimeException wrapping an {@link InterruptedException} (the thread's interrupt
     *                          flag is restored) or an {@link ExecutionException} from a worker task
     */
    @Override
    Heap descent(final Matrix data, final int nNeighbors, final Random random, final int maxCandidates,
                 final boolean rpTreeInit, final int nIters, final List<FlatTree> forest,
                 final float delta, final float rho) {
        final ExecutorService executor = Executors.newFixedThreadPool(this.mThreads);
        try {
            UmapProgress.incTotal(nIters);
            final int nVertices = data.rows();
            final Heap currentGraph = new Heap(nVertices, nNeighbors);
            // Submit more chunks than threads (factor 1 + log2(threads)), presumably so that
            // uneven chunk costs are spread across the pool.
            final int jobs = (int) ((double) this.mThreads * (1.0 + MathUtils.log2(this.mThreads)));

            this.seedRandomly(executor, data, nNeighbors, random, currentGraph, jobs);
            if (rpTreeInit) {
                this.seedFromForest(executor, data, forest, currentGraph, jobs);
            }

            for (int n = 0; n < nIters; ++n) {
                if (this.mVerbose) {
                    Utils.message("NearestNeighborDescent: " + (n + 1) + " / " + nIters);
                }
                final Heap candidateNeighbors =
                        currentGraph.buildCandidates(nVertices, nNeighbors, maxCandidates, random);
                final int updates = this.refine(executor, data, currentGraph, candidateNeighbors,
                        random, maxCandidates, rho, jobs);
                if ((float) updates <= delta * (float) nNeighbors * (float) nVertices) {
                    // Converged early: account for the iterations that will not run.
                    UmapProgress.update(nIters - n);
                    break;
                }
                UmapProgress.update();
            }
            return currentGraph.deheapSort();
        }
        catch (final InterruptedException ex) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(ex);
        }
        catch (final ExecutionException ex) {
            throw new RuntimeException(ex);
        }
        finally {
            executor.shutdown();
        }
    }

    /**
     * For every vertex, pushes the distances to {@code nNeighbors} vertices drawn by
     * {@code Utils.rejectionSample} into the graph, in both directions.
     */
    private void seedRandomly(final ExecutorService executor, final Matrix data, final int nNeighbors,
                              final Random random, final Heap graph, final int jobs)
            throws InterruptedException, ExecutionException {
        runInChunks(executor, jobs, data.rows(), (lo, hi) -> {
            for (int i = lo; i < hi; ++i) {
                final float[] iRow = data.row(i);
                for (final int index : Utils.rejectionSample(nNeighbors, data.rows(), random)) {
                    final float d = this.mMetric.distance(iRow, data.row(index));
                    graph.push(i, d, index, true);
                    graph.push(index, d, i, true);
                }
            }
            return 0;
        });
    }

    /**
     * For every leaf of every tree in {@code forest}, pushes the distance of each pair of indices
     * in that leaf into the graph, in both directions. Trees are partitioned across tasks.
     */
    private void seedFromForest(final ExecutorService executor, final Matrix data,
                                final List<FlatTree> forest, final Heap graph, final int jobs)
            throws InterruptedException, ExecutionException {
        runInChunks(executor, jobs, forest.size(), (lo, hi) -> {
            for (int l = lo; l < hi; ++l) {
                for (final int[] leaf : forest.get(l).getIndices()) {
                    for (int i = 0; i < leaf.length; ++i) {
                        final float[] iRow = data.row(leaf[i]);
                        for (int j = i + 1; j < leaf.length; ++j) {
                            final float d = this.mMetric.distance(iRow, data.row(leaf[j]));
                            graph.push(leaf[i], d, leaf[j], true);
                            graph.push(leaf[j], d, leaf[i], true);
                        }
                    }
                }
            }
            return 0;
        });
    }

    /**
     * Performs one refinement pass. For each vertex, every pair of its candidates (j, k with
     * k &lt;= j) is compared, unless both candidates are flagged rejected or neither is new.
     *
     * @return total number of {@code Heap.push} calls that reported success across all vertices
     */
    private int refine(final ExecutorService executor, final Matrix data, final Heap graph,
                       final Heap candidates, final Random random, final int maxCandidates,
                       final float rho, final int jobs)
            throws InterruptedException, ExecutionException {
        return runInChunks(executor, jobs, data.rows(), (lo, hi) -> {
            // One scratch array per task, reused for every vertex in the chunk.
            final boolean[] rejectStatus = new boolean[maxCandidates];
            int updates = 0;
            for (int i = lo; i < hi; ++i) {
                for (int j = 0; j < maxCandidates; ++j) {
                    rejectStatus[j] = random.nextFloat() < rho;
                }
                for (int j = 0; j < maxCandidates; ++j) {
                    final int p = candidates.index(i, j);
                    if (p < 0) {
                        continue;
                    }
                    for (int k = 0; k <= j; ++k) {
                        final int q = candidates.index(i, k);
                        if (q < 0
                                || (rejectStatus[j] && rejectStatus[k])
                                || (!candidates.isNew(i, j) && !candidates.isNew(i, k))) {
                            continue;
                        }
                        final float d = this.mMetric.distance(data.row(p), data.row(q));
                        if (graph.push(p, d, q, true)) {
                            ++updates;
                        }
                        if (graph.push(q, d, p, true)) {
                            ++updates;
                        }
                    }
                }
            }
            return updates;
        });
    }

    /** Work over the half-open index range {@code [lo, hi)}, returning a count to be summed. */
    @FunctionalInterface
    private interface ChunkTask {
        int run(int lo, int hi);
    }

    /**
     * Splits {@code [0, total)} into at most {@code jobs} contiguous chunks, runs {@code task} on
     * each in {@code executor}, waits for all of them, and returns the sum of their results.
     */
    private static int runInChunks(final ExecutorService executor, final int jobs, final int total,
                                   final ChunkTask task)
            throws InterruptedException, ExecutionException {
        final List<Future<Integer>> futures = new ArrayList<>(jobs);
        final int chunkSize = (total + jobs - 1) / jobs;
        for (int t = 0; t < jobs; ++t) {
            final int lo = t * chunkSize;
            final int hi = Math.min(lo + chunkSize, total);
            if (lo >= hi) {
                // All remaining chunks are empty as well.
                break;
            }
            futures.add(executor.submit(() -> {
                return task.run(lo, hi);
            }));
        }
        return waitForFutures(futures);
    }

    /**
     * Waits for every future and returns the sum of their results. On failure, all futures are
     * cancelled so that queued chunks do not keep running after the phase has already failed.
     * The list is cleared in every case.
     */
    private static int waitForFutures(final List<Future<Integer>> futures)
            throws InterruptedException, ExecutionException {
        int sum = 0;
        try {
            for (final Future<Integer> future : futures) {
                sum += future.get().intValue();
            }
        }
        catch (final InterruptedException | ExecutionException ex) {
            for (final Future<Integer> future : futures) {
                future.cancel(true);
            }
            throw ex;
        }
        finally {
            futures.clear();
        }
        return sum;
    }
}

