/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborSet;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodeSimilarity;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.util.AtomicFixedBitSet;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import java.io.IOException;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.agrona.collections.IntArrayQueue;
import org.agrona.collections.IntHashSet;

/**
 * Builds an {@link OnHeapGraphIndex} over a set of vectors.
 *
 * <p>{@link #build()} inserts every vector through {@link #addGraphNode} from a parallel stream on
 * the SIMD executor, then runs {@link #cleanup()}. Callers can also insert nodes one at a time with
 * {@link #addGraphNode}, mark nodes with {@link #markNodeDeleted}, and finish with {@link #cleanup()}.
 *
 * @param <T> the vector representation: {@code float[]} for {@link VectorEncoding#FLOAT32} or
 *     {@code byte[]} for {@link VectorEncoding#BYTE} (see {@link #scoreBetween(Object, Object)})
 */
public class GraphIndexBuilder<T> {
    private final int beamWidth;
    private final ExplicitThreadLocal<NodeArray> naturalScratch;
    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
    private final VectorSimilarityFunction similarityFunction;
    private final float neighborOverflow;
    private final float alpha;
    private final VectorEncoding vectorEncoding;
    private final ExplicitThreadLocal<GraphSearcher<?>> graphSearcher;
    @VisibleForTesting
    final OnHeapGraphIndex<T> graph;
    /** Nodes whose insertion has started but not yet finished; see {@link #addGraphNode}. */
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();
    private final Supplier<RandomAccessVectorValues<T>> vectors;
    private final Supplier<RandomAccessVectorValues<T>> vectorsCopy;
    private final int dimension;
    private final NodeSimilarity similarity;
    private final ForkJoinPool simdExecutor;
    private final ForkJoinPool parallelExecutor;
    /** Countdown of insertions until the entry node is recomputed (see maybeUpdateEntryPoint). */
    private final AtomicInteger updateEntryNodeIn = new AtomicInteger(10000);

    /**
     * Creates a builder that uses {@link PhysicalCoreExecutor#pool()} for SIMD-heavy work and
     * {@link ForkJoinPool#commonPool()} for other parallel work.
     *
     * @see #GraphIndexBuilder(RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, int, int, float, float, ForkJoinPool, ForkJoinPool)
     */
    public GraphIndexBuilder(RandomAccessVectorValues<T> vectorValues, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha) {
        this(vectorValues, vectorEncoding, similarityFunction, M, beamWidth, neighborOverflow, alpha, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Creates a builder.
     *
     * @param vectorValues the vectors to index; two independent thread-local suppliers are taken
     *     from it so that two vectors can be held at once while scoring
     * @param vectorEncoding how vectors are represented (must be non-null)
     * @param similarityFunction the similarity used to score vector pairs (must be non-null)
     * @param M the degree parameter passed to {@link OnHeapGraphIndex} (max connections); must be positive
     * @param beamWidth the search beam width used during construction; must be positive
     * @param neighborOverflow passed to {@code ConcurrentNeighborSet.backlink}
     * @param alpha passed to every {@link ConcurrentNeighborSet} this builder creates
     * @param simdExecutor pool used for insertion and other scoring-heavy parallel work
     * @param parallelExecutor pool used for graph traversal / degree enforcement during cleanup
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     */
    public GraphIndexBuilder(RandomAccessVectorValues<T> vectorValues, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        this.vectors = vectorValues.threadLocalSupplier();
        this.vectorsCopy = vectorValues.threadLocalSupplier();
        this.dimension = vectorValues.dimension();
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        this.neighborOverflow = neighborOverflow;
        this.alpha = alpha;
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = beamWidth;
        this.simdExecutor = simdExecutor;
        this.parallelExecutor = parallelExecutor;
        // node1's vector comes from `vectors`, node2's from `vectorsCopy`, so the two lookups
        // use separate accessors.
        this.similarity = node1 -> {
            RandomAccessVectorValues<T> v = this.vectors.get();
            RandomAccessVectorValues<T> vc = this.vectorsCopy.get();
            T v1 = v.vectorValue(node1);
            return node2 -> this.scoreBetween(v1, vc.vectorValue(node2));
        };
        this.graph = new OnHeapGraphIndex<>(M, (node, m) -> new ConcurrentNeighborSet(node, m, this.similarity, alpha));
        this.graphSearcher = ExplicitThreadLocal.withInitial(() -> new GraphSearcher.Builder<T>(this.graph.getView()).withConcurrentUpdates().build());
        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
    }

    /**
     * Inserts every vector (ordinals {@code 0..size-1}) in parallel on the SIMD executor, then runs
     * {@link #cleanup()}.
     *
     * @return the built graph
     */
    public OnHeapGraphIndex<T> build() {
        int size = vectors.get().size();
        simdExecutor.submit(() -> IntStream.range(0, size).parallel().forEach(i -> addGraphNode(i, vectors.get()))).join();
        cleanup();
        return graph;
    }

    /**
     * Finalizes the graph.
     *
     * <p>The steps are:
     * <ol>
     *   <li>validate the entry node and remove nodes marked deleted;</li>
     *   <li>enforce the degree limit on every neighbor set;</li>
     *   <li>reconnect nodes unreachable from the entry node;</li>
     *   <li>move the entry node to an approximate medoid.</li>
     * </ol>
     *
     * <p>This is a no-op on an empty graph.
     */
    public void cleanup() {
        if (graph.size() == 0) {
            return;
        }
        graph.validateEntryNode();
        removeDeletedNodes();
        if (graph.size() == 0) {
            // every node was deleted; nothing left to repair
            return;
        }
        parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
            ConcurrentNeighborSet neighbors = graph.getNeighbors(i);
            if (neighbors != null) {
                neighbors.enforceDegree();
            }
        })).join();
        reconnectOrphanedNodes();
        graph.updateEntryNode(approximateMedioid());
        updateEntryNodeIn.set(graph.size());
    }

    /**
     * Makes every live node reachable from the entry node, making at most 3 passes.
     *
     * <p>Each pass does two things:
     * <ul>
     *   <li>It marks all nodes reachable from the entry node.</li>
     *   <li>It attaches each unreachable node as a neighbor of some node, using each target at most
     *       once per pass.</li>
     * </ul>
     *
     * <p>Passes stop early once a pass finds no unreachable nodes.
     */
    private void reconnectOrphanedNodes() {
        // search results are cached across passes so an orphan is searched for at most once
        ConcurrentHashMap<Integer, NodeArray> searchPathNeighbors = new ConcurrentHashMap<>();
        for (int pass = 0; pass < 3; ++pass) {
            AtomicFixedBitSet connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
            connectedNodes.set(graph.entry());
            NodeArray entryNeighbors = graph.getNeighbors(graph.entry()).getCurrent();
            parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size).parallel()
                    .forEach(i -> findConnected(connectedNodes, entryNeighbors.node[i]))).join();

            AtomicInteger nReconnected = new AtomicInteger();
            Set<Integer> connectionTargets = ConcurrentHashMap.newKeySet();
            simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> {
                if (connectedNodes.get(node) || !graph.containsNode(node)) {
                    return;
                }
                nReconnected.incrementAndGet();
                reconnectNode(node, connectionTargets, searchPathNeighbors);
            })).join(); // must complete before nReconnected is read and before the next pass/cleanup step
            if (nReconnected.get() == 0) {
                break;
            }
        }
    }

    /**
     * Attaches one unreachable node to the graph.
     *
     * <p>It first tries the node's own current neighbors. If none is available, it falls back to
     * the (cached) results of a search for the node's vector.
     */
    private void reconnectNode(int node, Set<Integer> connectionTargets, ConcurrentHashMap<Integer, NodeArray> searchPathNeighbors) {
        if (connectToClosestNeighbor(node, graph.getNeighbors(node).getCurrent(), connectionTargets)) {
            return;
        }
        NodeArray candidates = searchPathNeighbors.get(node);
        if (candidates == null) {
            T value = vectors.get().vectorValue(node);
            SearchResult result = searchFor(value, createNotSelfBits(node));
            candidates = toScratchCandidates(result.getNodes(), new NodeArray(result.getNodes().length));
            searchPathNeighbors.put(node, candidates);
        }
        connectToClosestNeighbor(node, candidates, connectionTargets);
    }

    /**
     * Adds {@code node} to the neighbor set of the first candidate, in array order, that no other
     * node has claimed in {@code connectionTargets}.
     *
     * @return true if a connection was made, false if every candidate was already claimed
     */
    private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set<Integer> connectionTargets) {
        for (int i = 0; i < neighbors.size; ++i) {
            int neighborNode = neighbors.node[i];
            float neighborScore = neighbors.score[i];
            if (connectionTargets.add(neighborNode)) {
                graph.getNeighbors(neighborNode).insertNotDiverse(node, neighborScore, true);
                return true;
            }
        }
        return false;
    }

    /** Breadth-first traversal from {@code start} that sets a bit for every node it reaches. */
    private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
        IntArrayQueue queue = new IntArrayQueue();
        queue.addInt(start);
        try (GraphIndex.View<T> view = graph.getView()) {
            while (!queue.isEmpty()) {
                int next = queue.pollInt();
                if (connectedNodes.getAndSet(next)) {
                    continue; // already visited
                }
                for (NodesIterator it = view.getNeighborsIterator(next); it.hasNext(); ) {
                    queue.addInt(it.nextInt());
                }
            }
        } catch (RuntimeException e) {
            // propagate unchecked failures unchanged instead of double-wrapping them
            throw e;
        } catch (Exception e) {
            // View.close() may throw a checked exception; this runs inside a stream lambda
            throw new RuntimeException(e);
        }
    }

    /** @return the graph being built (the same instance returned by {@link #build()}) */
    public OnHeapGraphIndex<T> getGraph() {
        return graph;
    }

    /** @return the number of {@link #addGraphNode} calls currently in flight */
    public int insertsInProgress() {
        return insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} into the graph and wires up its neighbors.
     *
     * <p>Candidates come from two sources:
     * <ul>
     *   <li>a beam search from the current entry node;</li>
     *   <li>every other insertion in progress when this one started.</li>
     * </ul>
     *
     * @param node the ordinal to insert
     * @param vectors accessor used to read {@code node}'s own vector
     * @return {@code graph.ramBytesUsedOneNode()} (presumably the per-node RAM estimate)
     */
    public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {
        T value = vectors.vectorValue(node);
        // The neighbor set is created before the node is registered as in-progress, presumably so
        // concurrent inserters that pick it up from insertionsInProgress always find a neighbor set.
        ConcurrentNeighborSet newNodeNeighbors = graph.addNode(node);
        insertionsInProgress.add(node);
        ConcurrentSkipListSet<Integer> inProgressBefore = insertionsInProgress.clone();
        try {
            NodeArray naturalScratchPooled = naturalScratch.get();
            NodeArray concurrentScratchPooled = concurrentScratch.get();
            SearchResult result = searchFor(value, new ExcludingBits(node));
            NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
            NodeArray concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled, vectors, vectorsCopy.get());
            updateNeighbors(newNodeNeighbors, natural, concurrent);
            maybeUpdateEntryPoint(node);
            maybeImproveOlderNode();
        } finally {
            insertionsInProgress.remove(node);
        }
        return graph.ramBytesUsedOneNode();
    }

    /**
     * Re-runs {@link #improveConnections} on one existing node, picked at random.
     *
     * <p>This only happens for graphs of dimension <= 3 with more than 20000 nodes. Up to 3 random
     * ordinals are tried.
     */
    private void maybeImproveOlderNode() {
        if (dimension > 3 || graph.size() <= 20000) {
            return;
        }
        for (int attempt = 0; attempt < 3; ++attempt) {
            int olderNode = ThreadLocalRandom.current().nextInt(graph.size());
            if (graph.containsNode(olderNode)) {
                improveConnections(olderNode);
                return;
            }
        }
    }

    /**
     * Sets the initial entry node if needed.
     *
     * <p>Every {@code updateEntryNodeIn} insertions it also moves the entry node to the approximate
     * medoid. The countdown is then reset to the current graph size.
     */
    private void maybeUpdateEntryPoint(int node) {
        graph.maybeSetInitialEntryNode(node);
        if (updateEntryNodeIn.decrementAndGet() == 0) {
            int newEntryNode = approximateMedioid();
            graph.updateEntryNode(newEntryNode);
            improveConnections(newEntryNode);
            updateEntryNodeIn.addAndGet(graph.size());
        }
    }

    /**
     * Searches for {@code node}'s nearest neighbors and merges them into its neighbor set.
     *
     * <p>The node itself is excluded from the search. No concurrent-insertion candidates are used.
     */
    public void improveConnections(int node) {
        T value = vectors.get().vectorValue(node);
        SearchResult result = searchFor(value, new ExcludingBits(node));
        NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratch.get());
        updateNeighbors(graph.getNeighbors(node), natural, NodeArray.EMPTY);
    }

    /** Marks {@code node} deleted; it is physically removed on the next {@link #cleanup()}. */
    public void markNodeDeleted(int node) {
        graph.markDeleted(node);
    }

    /**
     * Physically removes every node marked deleted and repairs the neighbors they leave behind.
     *
     * <p>The repair steps are:
     * <ul>
     *   <li>Live nodes that lose neighbors and fall below {@code 1 + maxDegree/2} connections are
     *       padded with random live nodes.</li>
     *   <li>Every affected live node then gets a fresh search-based neighbor update.</li>
     * </ul>
     *
     * @return {@code nRemoved * graph.ramBytesUsedOneNode()}
     */
    private long removeDeletedNodes() {
        BitSet deletedNodes = graph.getDeletedNodes();
        int nRemoved = deletedNodes.cardinality();
        if (nRemoved == 0) {
            return 0L;
        }
        // Integer.MAX_VALUE is treated as nextSetBit's "no more set bits" sentinel
        for (int i = deletedNodes.nextSetBit(0); i != Integer.MAX_VALUE; i = deletedNodes.nextSetBit(i + 1)) {
            boolean success = graph.removeNode(i);
            assert success : String.format("Node %d marked deleted but not present", i);
        }

        int[] liveNodes = graph.rawNodes();
        IntHashSet affectedLiveNodes = new IntHashSet();
        Random random = new Random();
        RandomAccessVectorValues<T> v1 = vectors.get();
        RandomAccessVectorValues<T> v2 = vectorsCopy.get();
        for (int node : liveNodes) {
            assert !deletedNodes.get(node);
            ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
            if (!neighbors.removeDeletedNeighbors(deletedNodes)) {
                continue;
            }
            affectedLiveNodes.add(node);
            int minConnections = 1 + graph.maxDegree() / 2;
            if (neighbors.size() < minConnections) {
                padWithRandomConnections(node, neighbors, liveNodes, random, v1, v2);
            }
        }

        if (deletedNodes.get(graph.entry())) {
            graph.updateEntryNode(graph.size() > 0 ? graph.getNodes().nextInt() : -1);
        }
        simdExecutor.submit(() -> affectedLiveNodes.stream().parallel().forEach(this::addNNDescentConnections)).join();
        deletedNodes.clear();
        return (long) nRemoved * graph.ramBytesUsedOneNode();
    }

    /**
     * Fills {@code neighbors} with up to {@code maxDegree - neighbors.size()} distinct random live
     * nodes (never {@code node} itself).
     *
     * <p>At most {@code 2 * maxDegree} random draws are made.
     */
    private void padWithRandomConnections(int node, ConcurrentNeighborSet neighbors, int[] liveNodes, Random random,
                                          RandomAccessVectorValues<T> v1, RandomAccessVectorValues<T> v2) {
        NodeArray randomConnections = new NodeArray(graph.maxDegree() - neighbors.size());
        // v1 is not touched again inside the loop, so node's vector can be fetched once
        T nodeValue = v1.vectorValue(node);
        for (int attempt = 0; attempt < 2 * graph.maxDegree(); ++attempt) {
            int randomNode = liveNodes[random.nextInt(liveNodes.length)];
            if (randomNode != node && !randomConnections.contains(randomNode)) {
                randomConnections.insertSorted(randomNode, scoreBetween(nodeValue, v2.vectorValue(randomNode)));
            }
            if (randomConnections.size == randomConnections.node.length) {
                break;
            }
        }
        neighbors.padWithRandom(randomConnections);
    }

    /** Re-searches for {@code node}'s neighbors (excluding itself) and merges them in. */
    private void addNNDescentConnections(int node) {
        T value = vectors.get().vectorValue(node);
        SearchResult result = searchFor(value, createNotSelfBits(node));
        NodeArray candidates = toScratchCandidates(result.getNodes(), naturalScratch.get());
        updateNeighbors(graph.getNeighbors(node), candidates, NodeArray.EMPTY);
    }

    /** @return a filter accepting every ordinal except {@code node} */
    private static Bits createNotSelfBits(final int node) {
        return new ExcludingBits(node);
    }

    /**
     * Runs a construction-time beam search from the current entry node.
     *
     * <p>Candidates are read through {@code vectorsCopy} and scored against {@code value}.
     *
     * @param value the query vector
     * @param acceptBits filter of ordinals allowed in the result
     */
    private SearchResult searchFor(T value, Bits acceptBits) {
        GraphSearcher<?> gs = graphSearcher.get();
        RandomAccessVectorValues<T> vc = vectorsCopy.get();
        NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.vectorValue(i), value);
        return gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, graph.entry(), acceptBits);
    }

    /**
     * Returns the first node of a search for the centroid of all live vectors.
     *
     * <p>For non-FLOAT32 encodings the current entry node is returned unchanged, since the centroid
     * computation below only handles {@code float[]}.
     */
    private int approximateMedioid() {
        assert graph.size() > 0;
        if (vectorEncoding != VectorEncoding.FLOAT32) {
            return graph.entry();
        }
        RandomAccessVectorValues<T> vc = vectorsCopy.get();
        float[] centroid = new float[dimension];
        for (NodesIterator it = graph.getNodes(); it.hasNext(); ) {
            VectorUtil.addInPlace(centroid, (float[]) vc.vectorValue(it.nextInt()));
        }
        VectorUtil.scale(centroid, 1.0f / graph.size());
        @SuppressWarnings("unchecked") // safe: encoding is FLOAT32, so T is float[]
        T target = (T) centroid;
        SearchResult result = searchFor(target, Bits.ALL);
        return result.getNodes()[0].node;
    }

    /**
     * Merges {@code natural} and {@code concurrent} candidates into {@code neighbors}.
     *
     * <p>It then calls {@code backlink} with {@link #neighborOverflow}.
     */
    private void updateNeighbors(ConcurrentNeighborSet neighbors, NodeArray natural, NodeArray concurrent) {
        neighbors.insertDiverse(natural, concurrent);
        neighbors.backlink(graph::getNeighbors, neighborOverflow);
    }

    /**
     * Clears {@code scratch} and appends {@code candidates} in their given order.
     *
     * @return {@code scratch}
     */
    private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) {
        scratch.clear();
        for (SearchResult.NodeScore candidate : candidates) {
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /**
     * Scores {@code newNode} against every other node in {@code inProgress}.
     *
     * <p>Results are collected into {@code scratch} via {@code insertSorted}.
     *
     * @return {@code scratch}
     */
    private NodeArray getConcurrentCandidates(int newNode, Set<Integer> inProgress, NodeArray scratch, RandomAccessVectorValues<T> values, RandomAccessVectorValues<T> valuesCopy) {
        scratch.clear();
        // only valuesCopy is used inside the loop, so newNode's vector can be fetched once
        T newValue = values.vectorValue(newNode);
        for (int n : inProgress) {
            if (n != newNode) {
                scratch.insertSorted(n, scoreBetween(newValue, valuesCopy.vectorValue(n)));
            }
        }
        return scratch;
    }

    /** Scores two vectors with this builder's encoding and similarity function. */
    protected float scoreBetween(T v1, T v2) {
        return scoreBetween(vectorEncoding, similarityFunction, v1, v2);
    }

    /**
     * Scores two vectors.
     *
     * <p>They are cast to {@code byte[]} for {@link VectorEncoding#BYTE} or to {@code float[]} for
     * {@link VectorEncoding#FLOAT32}.
     *
     * @throws IllegalArgumentException for any other encoding
     */
    static <T> float scoreBetween(VectorEncoding encoding, VectorSimilarityFunction similarityFunction, T v1, T v2) {
        switch (encoding) {
            case BYTE:
                return similarityFunction.compare((byte[]) v1, (byte[]) v2);
            case FLOAT32:
                return similarityFunction.compare((float[]) v1, (float[]) v2);
            default:
                throw new IllegalArgumentException("Unsupported vector encoding: " + encoding);
        }
    }

    /**
     * Loads a serialized graph into this (empty) builder.
     *
     * <p>Expected layout, all ints:
     * <ul>
     *   <li>{@code size}, {@code entryNode}, {@code maxDegree};</li>
     *   <li>then, for each of {@code size} nodes: the node id, its neighbor count, and that many
     *       neighbor ids.</li>
     * </ul>
     *
     * <p>Neighbor scores are recomputed from the vectors. Neighbors are appended with
     * {@code addInOrder}, so presumably they must be stored in the order {@code addInOrder} expects.
     *
     * @throws IllegalStateException if the graph already contains nodes
     * @throws IOException on read failure
     */
    public void load(RandomAccessReader in) throws IOException {
        if (graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int size = in.readInt();
        int entryNode = in.readInt();
        int maxDegree = in.readInt();
        for (int i = 0; i < size; ++i) {
            int node = in.readInt();
            int nNeighbors = in.readInt();
            NodeArray ca = new NodeArray(maxDegree);
            for (int j = 0; j < nNeighbors; ++j) {
                int neighbor = in.readInt();
                ca.addInOrder(neighbor, similarity.score(node, neighbor));
            }
            graph.addNode(node, new ConcurrentNeighborSet(node, maxDegree, similarity, alpha, ca));
        }
        graph.updateEntryNode(entryNode);
    }

    /** A {@link Bits} filter accepting every ordinal except one; {@link #length()} is unsupported. */
    private static final class ExcludingBits implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != excluded;
        }

        @Override
        public int length() {
            throw new UnsupportedOperationException();
        }
    }
}

