/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NodeArray;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.ExceptionUtils;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.agrona.collections.IntArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds an {@link OnHeapGraphIndex} by inserting nodes one at a time.
 *
 * <p>{@link #build(RandomAccessVectorValues)} calls {@link #addGraphNode(int, VectorFloat)} from many threads in
 * parallel on the SIMD executor. Each insertion:
 * <ul>
 *   <li>searches the current graph for neighbor candidates;</li>
 *   <li>also considers the insertions that were still in flight when it started;</li>
 *   <li>links the new node into every layer from 0 up to its randomly chosen level.</li>
 * </ul>
 *
 * <p>{@link #cleanup()} then:
 * <ul>
 *   <li>physically removes nodes marked deleted;</li>
 *   <li>re-runs neighbor selection for the nodes of layer 1 and above;</li>
 *   <li>calls {@code enforceDegree} for every node on every layer.</li>
 * </ul>
 */
public class GraphIndexBuilder implements Closeable {
    private static final int BUILD_BATCH_SIZE = 50;
    /**
     * First int of a V4 serialized graph (0x75EC4012), followed by a version int. A V3 stream has no header and
     * starts directly with its node count.
     */
    private static final int V4_MAGIC = 1978417170;
    private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class);

    private final int beamWidth;
    private final ExplicitThreadLocal<NodeArray> naturalScratch;
    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
    private final int dimension;
    private final float neighborOverflow;
    private final float alpha;
    private final boolean addHierarchy;
    @VisibleForTesting
    final OnHeapGraphIndex graph;
    /** Insertions currently running in {@link #addGraphNode(int, VectorFloat)}. */
    private final ConcurrentSkipListSet<GraphIndex.NodeAtLevel> insertionsInProgress = new ConcurrentSkipListSet<>();
    private final BuildScoreProvider scoreProvider;
    private final ForkJoinPool simdExecutor;
    private final ForkJoinPool parallelExecutor;
    private final ExplicitThreadLocal<GraphSearcher> searchers;
    /** Source of random levels; seeded with 0. */
    private final Random rng;

    /**
     * Creates a builder over {@code vectorValues} with one degree {@code M} for every layer.
     * Uses the default executors.
     */
    public GraphIndexBuilder(RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy) {
        this(BuildScoreProvider.randomAccessScoreProvider(vectorValues, similarityFunction), vectorValues.dimension(), M, beamWidth, neighborOverflow, alpha, addHierarchy);
    }

    /**
     * Creates a builder with one degree {@code M} for every layer.
     * Uses {@link PhysicalCoreExecutor#pool()} for SIMD work and the common pool for other parallel work.
     */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy) {
        this(scoreProvider, dimension, M, beamWidth, neighborOverflow, alpha, addHierarchy, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /** Creates a builder with one degree {@code M} for every layer, using the given executors. */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        this(scoreProvider, dimension, List.of(M), beamWidth, neighborOverflow, alpha, addHierarchy, simdExecutor, parallelExecutor);
    }

    /** Creates a builder with per-layer degrees, using the default executors. */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, List<Integer> maxDegrees, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy) {
        this(scoreProvider, dimension, maxDegrees, beamWidth, neighborOverflow, alpha, addHierarchy, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Creates a builder.
     *
     * @param scoreProvider    scores used for searching and for selecting neighbors
     * @param dimension        vector dimension
     * @param maxDegrees       maximum degree of each layer, starting at layer 0; must be non-empty and all positive.
     *                         More than one entry requires {@code addHierarchy}.
     * @param beamWidth        search width used when looking for neighbor candidates; must be positive
     * @param neighborOverflow overflow factor passed to the graph and to backlinking; must be {@code >= 1.0}
     * @param alpha            diversity parameter passed to the graph; must be positive
     * @param addHierarchy     whether nodes are assigned random levels above 0
     * @param simdExecutor     pool used for insertion and scoring work
     * @param parallelExecutor pool used for other parallel passes
     * @throws IllegalArgumentException if any parameter is out of range
     */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, List<Integer> maxDegrees, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        if (maxDegrees.isEmpty()) {
            // later code reads graph.getDegree(0) and graph.maxDegree(), which need at least one layer
            throw new IllegalArgumentException("at least one layer degree must be specified");
        }
        if (maxDegrees.stream().anyMatch(i -> i <= 0)) {
            throw new IllegalArgumentException("layer degrees must be positive");
        }
        if (maxDegrees.size() > 1 && !addHierarchy) {
            throw new IllegalArgumentException("Cannot specify multiple max degrees with addHierarchy=False");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        if (neighborOverflow < 1.0f) {
            throw new IllegalArgumentException("neighborOverflow must be >= 1.0");
        }
        if (alpha <= 0.0f) {
            throw new IllegalArgumentException("alpha must be positive");
        }
        this.scoreProvider = scoreProvider;
        this.dimension = dimension;
        this.neighborOverflow = neighborOverflow;
        this.alpha = alpha;
        this.addHierarchy = addHierarchy;
        this.beamWidth = beamWidth;
        this.simdExecutor = simdExecutor;
        this.parallelExecutor = parallelExecutor;
        this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, scoreProvider, alpha, BUILD_BATCH_SIZE);
        this.searchers = ExplicitThreadLocal.withInitial(() -> {
            GraphSearcher gs = new GraphSearcher(this.graph);
            gs.usePruning(false);
            return gs;
        });
        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, this.graph.maxDegree() + 1)));
        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, this.graph.maxDegree() + 1)));
        this.rng = new Random(0L);
    }

    /**
     * Creates a builder with the same parameters and executors as {@code other}, but scoring with
     * {@code newProvider}.
     *
     * <p>Every layer of {@code other}'s graph is copied, with each edge re-scored using {@code newProvider}.
     * The entry node is copied as well.
     */
    public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
        GraphIndexBuilder newBuilder = new GraphIndexBuilder(newProvider, other.dimension, (List<Integer>) other.graph.maxDegrees, other.beamWidth, other.neighborOverflow, other.alpha, other.addHierarchy, other.simdExecutor, other.parallelExecutor);
        other.parallelExecutor.submit(() -> IntStream.range(0, other.graph.getIdUpperBound()).parallel().forEach(node -> {
            // Highest layer containing this node; layers are contiguous from 0 upward.
            int topLayer = -1;
            for (int lvl = 0; lvl < other.graph.layers.size(); lvl++) {
                if (other.graph.getNeighbors(lvl, node) == null) {
                    break;
                }
                topLayer = lvl;
            }
            if (topLayer < 0) {
                return; // id is not present in the graph
            }
            ScoreFunction sf = newProvider.searchProviderFor(node).scoreFunction();
            for (int lvl = 0; lvl <= topLayer; lvl++) {
                ConcurrentNeighborMap.Neighbors oldNeighbors = other.graph.getNeighbors(lvl, node);
                NodeArray newNeighbors = new NodeArray(oldNeighbors.size());
                NodesIterator it = oldNeighbors.iterator();
                while (it.hasNext()) {
                    int neighbor = it.nextInt();
                    // scores change with the new provider, so the old order cannot be reused
                    newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor));
                }
                newBuilder.graph.addNode(lvl, node, newNeighbors);
            }
        })).join();
        newBuilder.graph.updateEntryNode(other.graph.entry());
        return newBuilder;
    }

    /**
     * Inserts every vector of {@code ravv} (ordinals {@code 0..size-1}) in parallel on the SIMD executor, then
     * calls {@link #cleanup()}.
     *
     * @return the built graph
     */
    public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
        Supplier<RandomAccessVectorValues> vv = ravv.threadLocalSupplier();
        int size = ravv.size();
        simdExecutor.submit(() -> IntStream.range(0, size).parallel()
                .forEach(node -> addGraphNode(node, vv.get().getVector(node)))).join();
        cleanup();
        return graph;
    }

    /**
     * Finishes the graph after insertions.
     *
     * <ol>
     *   <li>Validates the entry node and removes deleted nodes.</li>
     *   <li>If the graph has upper layers, re-runs neighbor selection for the nodes of {@code nodeStream(1)}.</li>
     *   <li>Calls {@code enforceDegree} for every id on every layer.</li>
     * </ol>
     * Returns immediately if layer 0 is empty.
     */
    public void cleanup() {
        if (graph.size(0) == 0) {
            return;
        }
        graph.validateEntryNode();
        removeDeletedNodes();
        if (graph.size(0) == 0) {
            return;
        }
        if (graph.getMaxLevel() > 0) {
            parallelExecutor.submit(() -> graph.nodeStream(1).parallel().forEach(this::improveConnections)).join();
        }
        parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> {
            for (int layer = 0; layer < graph.layers.size(); layer++) {
                graph.layers.get(layer).enforceDegree(id);
            }
        })).join();
    }

    /**
     * Searches the graph from the entry point on behalf of {@code node}, excluding the node itself.
     * On every layer that contains {@code node}, it re-selects neighbors from the results and backlinks them.
     */
    private void improveConnections(int node) {
        SearchScoreProvider ssp = scoreProvider.searchProviderFor(node);
        ExcludingBits bits = new ExcludingBits(node);
        try (GraphSearcher gs = searchers.get()) {
            GraphIndex.NodeAtLevel entry = graph.entry();
            gs.initializeInternal(ssp, entry, bits);
            Bits acceptedBits = Bits.intersectionOf(bits, gs.getView().liveNodes());
            for (int lvl = entry.level; lvl >= 0; lvl--) {
                // A fresh provider is requested for every layer, exactly as before.
                // NOTE(review): presumably the provider carries state that interacts with insertDiverse;
                // do not hoist this out of the loop without confirming.
                ssp = scoreProvider.searchProviderFor(node);
                if (graph.layers.get(lvl).get(node) != null) {
                    gs.searchOneLayer(ssp, beamWidth, 0.0f, lvl, acceptedBits);
                    NodeArray candidates = new NodeArray(gs.approximateResults.size());
                    gs.approximateResults.foreach(candidates::insertSorted);
                    ConcurrentNeighborMap.Neighbors newNeighbors = graph.layers.get(lvl).insertDiverse(node, candidates);
                    graph.layers.get(lvl).backlink(newNeighbors, node, neighborOverflow);
                } else {
                    // node is not on this layer: only descend toward it
                    gs.searchOneLayer(ssp, 1, 0.0f, lvl, acceptedBits);
                }
                gs.setEntryPointsFromPreviousLayer();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Returns the graph being built (the same instance, not a copy). */
    public OnHeapGraphIndex getGraph() {
        return graph;
    }

    /** Returns the number of {@link #addGraphNode(int, VectorFloat)} calls currently running. */
    public int insertsInProgress() {
        return insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} using its vector from {@code ravv}.
     *
     * @deprecated use {@link #addGraphNode(int, VectorFloat)} with the vector directly
     */
    @Deprecated
    public long addGraphNode(int node, RandomAccessVectorValues ravv) {
        return addGraphNode(node, ravv.getVector(node));
    }

    /**
     * Draws the top level for a new node.
     *
     * <p>Returns 0 when {@code addHierarchy} is false. Otherwise it returns {@code floor(-ln(U) * mL)}, where
     * {@code U} is uniform in (0, 1) and {@code mL = 1 / ln(degree of layer 0)}; {@code mL} is 1.0 when that
     * degree is 1.
     */
    private int getRandomGraphLevel() {
        if (!addHierarchy) {
            return 0;
        }
        int baseDegree = graph.getDegree(0);
        // ln(1) == 0 would make the factor infinite, so degree 1 falls back to 1.0
        double ml = baseDegree == 1 ? 1.0 : 1.0 / Math.log(baseDegree);
        double u;
        do {
            u = rng.nextDouble();
        } while (u == 0.0); // nextDouble() can return exactly 0.0, and -ln(0) is infinite
        return (int) (-Math.log(u) * ml);
    }

    /**
     * Inserts {@code node} with the given vector into the graph.
     *
     * <p>Safe to call from several threads concurrently; {@link #build} does so. The node gets a random top
     * level. Its neighbors on each layer come from two sources:
     * <ul>
     *   <li>a search of the current graph view;</li>
     *   <li>the insertions that were already in progress when this call began.</li>
     * </ul>
     *
     * @return the sum of {@code graph.ramBytesUsedOneNode(layer)} over every layer the node was added to
     *         (layers 0 through its level inclusive)
     * @throws RuntimeException wrapping any exception raised during the insertion
     */
    public long addGraphNode(int node, VectorFloat<?> vector) {
        GraphIndex.NodeAtLevel nodeLevel = new GraphIndex.NodeAtLevel(getRandomGraphLevel(), node);
        // The node is added to the graph before it is published as in-progress.
        // NOTE(review): presumably this ensures concurrent inserters never see an in-progress node without
        // neighbor sets.
        graph.addNode(nodeLevel);
        insertionsInProgress.add(nodeLevel);
        ConcurrentSkipListSet<GraphIndex.NodeAtLevel> inProgressBefore = insertionsInProgress.clone();
        try (GraphSearcher gs = searchers.get()) {
            gs.setView(graph.getView());
            NodeArray naturalScratchPooled = naturalScratch.get();
            NodeArray concurrentScratchPooled = concurrentScratch.get();
            ExcludingBits bits = new ExcludingBits(nodeLevel.node);
            SearchScoreProvider ssp = scoreProvider.searchProviderFor(vector);
            GraphIndex.NodeAtLevel entry = graph.entry();
            SearchResult result;
            if (entry == null) {
                // first node: nothing to search
                result = new SearchResult(new SearchResult.NodeScore[0], 0, 0, 0, 0, 0.0f);
            } else {
                gs.initializeInternal(ssp, entry, bits);
                // Descend from the entry level down to layer 1. Above the new node's level, only track the single
                // closest node; at or below it, search with the full beam and link the node in.
                for (int lvl = entry.level; lvl > 0; lvl--) {
                    if (lvl > nodeLevel.level) {
                        gs.searchOneLayer(ssp, 1, 0.0f, lvl, gs.getView().liveNodes());
                    } else {
                        gs.searchOneLayer(ssp, beamWidth, 0.0f, lvl, gs.getView().liveNodes());
                        SearchResult.NodeScore[] neighbors = new SearchResult.NodeScore[gs.approximateResults.size()];
                        AtomicInteger index = new AtomicInteger();
                        gs.approximateResults.foreach((neighbor, score) ->
                                neighbors[index.getAndIncrement()] = new SearchResult.NodeScore(neighbor, score));
                        Arrays.sort(neighbors);
                        updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, naturalScratchPooled, inProgressBefore, concurrentScratchPooled, ssp);
                    }
                    gs.setEntryPointsFromPreviousLayer();
                }
                // layer 0: full search
                result = gs.resume(beamWidth, beamWidth, 0.0f, 0.0f);
            }
            updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), naturalScratchPooled, inProgressBefore, concurrentScratchPooled, ssp);
            graph.markComplete(nodeLevel);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            insertionsInProgress.remove(nodeLevel);
        }
        // The node occupies layers 0..level inclusive; this matches the per-layer accounting in removeDeletedNodes.
        return IntStream.rangeClosed(0, nodeLevel.level).mapToLong(graph::ramBytesUsedOneNode).sum();
    }

    /**
     * Merges the search results ("natural" candidates) with the in-progress insertions ("concurrent" candidates)
     * for one layer, then updates {@code node}'s neighbors on that layer.
     */
    private void updateNeighborsOneLayer(int layer, int node, SearchResult.NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet<GraphIndex.NodeAtLevel> inProgressBefore, NodeArray concurrentScratchPooled, SearchScoreProvider ssp) {
        NodeArray natural = toScratchCandidates(neighbors, naturalScratchPooled);
        NodeArray concurrent = getConcurrentCandidates(layer, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction());
        updateNeighbors(layer, node, natural, concurrent);
    }

    /** Overrides the graph's entry node. */
    @VisibleForTesting
    public void setEntryPoint(int level, int node) {
        graph.updateEntryNode(new GraphIndex.NodeAtLevel(level, node));
    }

    /** Marks {@code node} as deleted; it is physically removed by {@link #removeDeletedNodes()}. */
    public void markNodeDeleted(int node) {
        graph.markDeleted(node);
    }

    /**
     * Physically removes every node currently marked deleted.
     *
     * <ol>
     *   <li>On each layer, live nodes that pointed at a deleted node are given replacement neighbors, taken from
     *       the deleted node's live neighbors. If that yields nothing, random live nodes on the layer are used.</li>
     *   <li>If the entry node was deleted, a new entry is chosen from the highest layer that has a live node.</li>
     *   <li>The deleted nodes are removed.</li>
     * </ol>
     *
     * @return the sum of {@code graph.ramBytesUsedOneNode(layer)} over every layer each node was removed from
     */
    public synchronized long removeDeletedNodes() {
        ThreadSafeGrowableBitSet toDelete = graph.getDeletedNodes().copy();
        int nRemoved = toDelete.cardinality();
        if (nRemoved == 0) {
            return 0L;
        }
        IntArrayList liveNodes = new IntArrayList();
        for (int id = 0; id < graph.getIdUpperBound(); id++) {
            if (graph.containsNode(id) && !toDelete.get(id)) {
                liveNodes.addInt(id);
            }
        }
        for (int level = 0; level < graph.layers.size(); level++) {
            repairLayer(level, toDelete, liveNodes);
        }
        GraphIndex.NodeAtLevel entry = graph.entry();
        if (entry != null && toDelete.get(entry.node)) {
            graph.updateEntryNode(findReplacementEntry(toDelete));
        }
        long memorySize = 0L;
        assert toDelete.cardinality() == nRemoved : "cardinality changed";
        // nextSetBit returns Integer.MAX_VALUE once no further bits are set
        for (int id = toDelete.nextSetBit(0); id != Integer.MAX_VALUE; id = toDelete.nextSetBit(id + 1)) {
            int nDeletions = graph.removeNode(id);
            for (int iLayer = 0; iLayer < nDeletions; iLayer++) {
                memorySize += graph.ramBytesUsedOneNode(iLayer);
            }
        }
        return memorySize;
    }

    /** Replaces edges to deleted nodes on one layer; see {@link #removeDeletedNodes()}. */
    private void repairLayer(int level, ThreadSafeGrowableBitSet toDelete, IntArrayList liveNodes) {
        // Phase 1: for each live node i with a deleted neighbor j, collect j's live neighbors (other than i).
        ConcurrentHashMap<Integer, Set<Integer>> newEdges = new ConcurrentHashMap<>();
        parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
            ConcurrentNeighborMap.Neighbors neighbors = graph.getNeighbors(level, i);
            if (neighbors == null || toDelete.get(i)) {
                return;
            }
            NodesIterator it = neighbors.iterator();
            while (it.hasNext()) {
                int j = it.nextInt();
                if (!toDelete.get(j)) {
                    continue;
                }
                Set<Integer> candidatesForI = newEdges.computeIfAbsent(i, unused -> ConcurrentHashMap.newKeySet());
                NodesIterator jt = graph.getNeighbors(level, j).iterator();
                while (jt.hasNext()) {
                    int k = jt.nextInt();
                    if (k != i && !toDelete.get(k)) {
                        candidatesForI.add(k);
                    }
                }
            }
        })).join();

        // Phase 2: score the candidates and swap them in for the deleted neighbors.
        simdExecutor.submit(() -> newEdges.entrySet().parallelStream().forEach(e -> {
            int node = e.getKey();
            ScoreFunction sf = scoreProvider.searchProviderFor(node).scoreFunction();
            int degree = graph.getDegree(level);
            NodeArray candidates = new NodeArray(degree);
            for (int k : e.getValue()) {
                candidates.insertSorted(k, sf.similarityTo(k));
            }
            if (candidates.size() == 0 && !liveNodes.isEmpty()) {
                // Fallback: up to 2*degree random draws of live nodes on this layer, stopping at degree candidates.
                ThreadLocalRandom rnd = ThreadLocalRandom.current();
                for (int attempt = 0; attempt < 2 * degree; attempt++) {
                    int randomNode = liveNodes.getInt(rnd.nextInt(liveNodes.size()));
                    if (randomNode != node && !candidates.contains(randomNode) && graph.layers.get(level).contains(randomNode)) {
                        candidates.insertSorted(randomNode, sf.similarityTo(randomNode));
                    }
                    if (candidates.size() == degree) {
                        break;
                    }
                }
            }
            graph.layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates);
        })).join();
    }

    /**
     * Returns the first node in {@code toDelete}'s complement found scanning from the top layer downward, at that
     * layer's level. Returns {@code null} if every node is deleted.
     */
    private GraphIndex.NodeAtLevel findReplacementEntry(ThreadSafeGrowableBitSet toDelete) {
        for (int level = graph.getMaxLevel(); level >= 0; level--) {
            NodesIterator it = graph.getNodes(level);
            while (it.hasNext()) {
                int candidate = it.nextInt();
                if (!toDelete.get(candidate)) {
                    return new GraphIndex.NodeAtLevel(level, candidate);
                }
            }
        }
        return null;
    }

    /**
     * Merges {@code natural} and {@code concurrent}, skipping the merge when either is empty.
     * The merged set goes through {@code insertDiverse} on {@code layer}, and the resulting neighbors are
     * backlinked.
     */
    private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) {
        NodeArray toMerge;
        if (concurrent.size() == 0) {
            toMerge = natural;
        } else if (natural.size() == 0) {
            toMerge = concurrent;
        } else {
            toMerge = NodeArray.merge(natural, concurrent);
        }
        ConcurrentNeighborMap.Neighbors neighbors = graph.layers.get(layer).insertDiverse(nodeId, toMerge);
        graph.layers.get(layer).backlink(neighbors, nodeId, neighborOverflow);
    }

    /**
     * Clears {@code scratch} and appends {@code candidates} in the given order.
     * The caller passes them already sorted.
     */
    private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) {
        scratch.clear();
        for (SearchResult.NodeScore candidate : candidates) {
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /**
     * Fills {@code scratch} with the in-progress insertions whose level reaches {@code layer}, scored against the
     * new node. The new node itself is excluded.
     */
    private NodeArray getConcurrentCandidates(int layer, int newNode, Set<GraphIndex.NodeAtLevel> inProgress, NodeArray scratch, ScoreFunction scoreFunction) {
        scratch.clear();
        for (GraphIndex.NodeAtLevel n : inProgress) {
            if (n.node != newNode && n.level >= layer) {
                scratch.insertSorted(n.node, scoreFunction.similarityTo(n.node));
            }
        }
        return scratch;
    }

    /** Closes the per-thread searchers; any failure is rethrown via {@link ExceptionUtils#throwIoException}. */
    @Override
    public void close() throws IOException {
        try {
            searchers.close();
        } catch (Exception e) {
            ExceptionUtils.throwIoException(e);
        }
    }

    /**
     * Loads a serialized graph into this (empty) builder.
     *
     * <ul>
     *   <li>V4: begins with {@link #V4_MAGIC} and a version int.</li>
     *   <li>V3: begins directly with the node count and holds a single layer.</li>
     * </ul>
     *
     * @throws IllegalStateException if the graph already contains nodes
     */
    public void load(RandomAccessReader in) throws IOException {
        if (graph.size(0) != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int firstWord = in.readInt();
        if (firstWord == V4_MAGIC) {
            // NOTE(review): the version is consumed to advance the stream but is not currently validated.
            in.readInt();
            loadV4(in);
        } else {
            // V3 has no header: the first int is the node count.
            loadV3(in, firstWord);
        }
    }

    /**
     * Reads the V4 body that follows the magic and version.
     *
     * <p>The body is laid out as follows:
     * <ul>
     *   <li>the layer count, then the entry node id;</li>
     *   <li>then, per layer: its size and degree, followed by each node's id, neighbor count, and neighbor ids.</li>
     * </ul>
     */
    private void loadV4(RandomAccessReader in) throws IOException {
        if (graph.size(0) != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int layerCount = in.readInt();
        int entryNode = in.readInt();
        List<Integer> layerDegrees = new ArrayList<>(layerCount);
        // Layers are read bottom-up, so the last level recorded for a node is its top level.
        HashMap<Integer, Integer> nodeLevelMap = new HashMap<>();
        for (int level = 0; level < layerCount; level++) {
            int layerSize = in.readInt();
            layerDegrees.add(in.readInt());
            for (int i = 0; i < layerSize; i++) {
                int nodeId = in.readInt();
                int nNeighbors = in.readInt();
                SearchScoreProvider searchProvider = scoreProvider.searchProviderFor(nodeId);
                // layer-0 edges use the exact score function when the provider has a reranker
                ScoreFunction sf = level > 0 || searchProvider.reranker() == null
                        ? searchProvider.scoreFunction()
                        : searchProvider.exactScoreFunction();
                NodeArray ca = new NodeArray(nNeighbors);
                for (int j = 0; j < nNeighbors; j++) {
                    int neighbor = in.readInt();
                    // NOTE(review): addInOrder assumes neighbors were serialized in score order — confirm
                    ca.addInOrder(neighbor, sf.similarityTo(neighbor));
                }
                graph.addNode(level, nodeId, ca);
                nodeLevelMap.put(nodeId, level);
            }
        }
        for (Integer nodeId : nodeLevelMap.keySet()) {
            graph.markComplete(new GraphIndex.NodeAtLevel(nodeLevelMap.get(nodeId), nodeId));
        }
        graph.setDegrees(layerDegrees);
        graph.updateEntryNode(new GraphIndex.NodeAtLevel(graph.getMaxLevel(), entryNode));
    }

    /**
     * Reads the single-layer V3 body that follows the node count.
     *
     * <p>The body is laid out as follows:
     * <ul>
     *   <li>the entry node id, then the max degree;</li>
     *   <li>then, per node: its id, neighbor count, and neighbor ids.</li>
     * </ul>
     */
    private void loadV3(RandomAccessReader in, int size) throws IOException {
        if (graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int entryNode = in.readInt();
        int maxDegree = in.readInt();
        for (int i = 0; i < size; i++) {
            int nodeId = in.readInt();
            int nNeighbors = in.readInt();
            SearchScoreProvider searchProvider = scoreProvider.searchProviderFor(nodeId);
            ScoreFunction sf = searchProvider.reranker() == null
                    ? searchProvider.scoreFunction()
                    : searchProvider.exactScoreFunction();
            NodeArray ca = new NodeArray(nNeighbors);
            for (int j = 0; j < nNeighbors; j++) {
                int neighbor = in.readInt();
                ca.addInOrder(neighbor, sf.similarityTo(neighbor));
            }
            graph.addNode(0, nodeId, ca);
            graph.markComplete(new GraphIndex.NodeAtLevel(0, nodeId));
        }
        graph.updateEntryNode(new GraphIndex.NodeAtLevel(0, entryNode));
        graph.setDegrees(List.of(maxDegree));
    }

    /**
     * Accepts every ordinal except {@code excluded}.
     * Used so that a search run on behalf of a node does not accept the node itself.
     */
    private static class ExcludingBits implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != excluded;
        }
    }
}

