/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.NodeSimilarity;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.ScoreTracker;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import io.github.jbellis.jvector.util.GrowableBitSet;
import io.github.jbellis.jvector.util.GrowableLongHeap;
import io.github.jbellis.jvector.util.SparseFixedBitSet;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.util.Arrays;
import java.util.Comparator;

/**
 * Best-first search over a {@link GraphIndex}.
 * <p>
 * Each search keeps a max-heap of candidates still to expand and a bounded min-heap holding the
 * best {@code topK} accepted results. It expands neighbors until no remaining candidate can beat
 * the worst retained result, or until the {@link ScoreTracker} asks it to stop.
 * <p>
 * Instances reuse mutable scratch state ({@code candidates}, {@code visited}) on every search. A
 * single instance should therefore not run more than one search at a time. Obtain instances
 * through {@link Builder}.
 *
 * @param <T> the vector type stored by the graph (e.g. {@code float[]} or {@code byte[]})
 */
public class GraphSearcher<T> {
    private final GraphIndex.View<T> view;
    /** Nodes discovered but not yet expanded, best score on top. */
    private final NodeQueue candidates;
    /** Ordinals already scored during the current search; cleared at the start of each search. */
    private final BitSet visited;

    GraphSearcher(GraphIndex.View<T> view, BitSet visited) {
        this.view = view;
        this.candidates = new NodeQueue(new GrowableLongHeap(100), NodeQueue.Order.MAX_HEAP);
        this.visited = visited;
    }

    /**
     * Convenience method. It searches {@code graph} for the {@code topK} nodes most similar to
     * {@code targetVector}, scoring every visited node exactly against {@code vectors}.
     * <p>
     * A fresh searcher built {@link Builder#withConcurrentUpdates() withConcurrentUpdates()} is
     * used, so the graph may grow while it is being searched.
     *
     * @param targetVector       query vector: a {@code byte[]} for {@link VectorEncoding#BYTE},
     *                           or a {@code float[]} for {@link VectorEncoding#FLOAT32}
     * @param topK               maximum number of results to return
     * @param vectors            source of the vector for each node ordinal
     * @param vectorEncoding     how {@code targetVector} and the values in {@code vectors} are encoded
     * @param similarityFunction function comparing the query with a node's vector
     * @param graph              the graph to search
     * @param acceptOrds         ordinals eligible to appear in the results; must not be {@code null}
     * @return the search result, best score first
     * @throws IllegalArgumentException if {@code acceptOrds} is {@code null}
     * @throws RuntimeException         if {@code vectorEncoding} is neither BYTE nor FLOAT32. It is
     *                                  raised lazily, when the first node is scored.
     */
    public static <T> SearchResult search(T targetVector, int topK, RandomAccessVectorValues<T> vectors, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, GraphIndex<T> graph, Bits acceptOrds) {
        GraphSearcher<T> searcher = new Builder<>(graph.getView()).withConcurrentUpdates().build();
        NodeSimilarity.ExactScoreFunction scoreFunction = i -> {
            switch (vectorEncoding) {
                case BYTE:
                    return similarityFunction.compare((byte[]) targetVector, (byte[]) vectors.vectorValue(i));
                case FLOAT32:
                    return similarityFunction.compare((float[]) targetVector, (float[]) vectors.vectorValue(i));
                default:
                    throw new RuntimeException("Unsupported vector encoding: " + vectorEncoding);
            }
        };
        return searcher.search(scoreFunction, null, topK, acceptOrds);
    }

    /**
     * Searches the graph, starting from the view's entry node.
     *
     * @param scoreFunction scores a node ordinal against the query. If it is not exact,
     *                      {@code reRanker} must be supplied.
     * @param reRanker      re-scores the final results when {@code scoreFunction} is approximate;
     *                      it may be {@code null} when {@code scoreFunction} is exact
     * @param topK          maximum number of results to return
     * @param threshold     when positive, only nodes scoring at least {@code threshold} are
     *                      accepted, and a {@link ScoreTracker.NormalDistributionTracker} decides
     *                      when to stop early. When 0, the no-op tracker is used.
     * @param acceptOrds    ordinals eligible to appear in the results; must not be {@code null}
     * @return the search result, best score first
     * @throws IllegalArgumentException if {@code acceptOrds} is {@code null}, or if
     *                                  {@code scoreFunction} is approximate and {@code reRanker}
     *                                  is {@code null}
     */
    @Experimental
    public SearchResult search(NodeSimilarity.ScoreFunction scoreFunction, NodeSimilarity.ReRanker reRanker, int topK, float threshold, Bits acceptOrds) {
        return searchInternal(scoreFunction, reRanker, topK, threshold, view.entryNode(), acceptOrds);
    }

    /**
     * Searches the graph from the view's entry node, with a threshold of {@code 0.0f}.
     * <p>
     * Note that the zero threshold still applies: nodes scoring below {@code 0.0f} are never
     * accepted into the results.
     *
     * @see #search(NodeSimilarity.ScoreFunction, NodeSimilarity.ReRanker, int, float, Bits)
     */
    public SearchResult search(NodeSimilarity.ScoreFunction scoreFunction, NodeSimilarity.ReRanker reRanker, int topK, Bits acceptOrds) {
        return search(scoreFunction, reRanker, topK, 0.0f, acceptOrds);
    }

    /**
     * Best-first graph search starting from the entry point {@code ep}.
     * <p>
     * The returned {@link SearchResult} is handed this searcher's own {@code visited} set, not a
     * copy. That set is cleared at the start of the next search on this instance.
     *
     * @param ep entry-point ordinal. A negative value yields an empty result without scoring
     *           anything.
     */
    SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
                                NodeSimilarity.ReRanker reRanker,
                                int topK,
                                float threshold,
                                int ep,
                                Bits acceptOrds) {
        if (!scoreFunction.isExact() && reRanker == null) {
            throw new IllegalArgumentException("Either scoreFunction must be exact, or reRanker must not be null");
        }
        if (acceptOrds == null) {
            throw new IllegalArgumentException("Use MatchAllBits to indicate that all ordinals are accepted, instead of null");
        }

        prepareScratchState(view.size());
        ScoreTracker scoreTracker = threshold > 0.0f
                ? new ScoreTracker.NormalDistributionTracker(threshold)
                : ScoreTracker.NO_OP;
        if (ep < 0) {
            // No entry point to start from (presumably an empty graph).
            return new SearchResult(new SearchResult.NodeScore[0], visited, 0);
        }

        // A node is only eligible as a result if the caller accepts it AND the view reports it live.
        acceptOrds = Bits.intersectionOf(acceptOrds, view.liveNodes());
        NodeQueue resultsQueue = new NodeQueue(new BoundedLongHeap(Math.min(1024, topK), topK), NodeQueue.Order.MIN_HEAP);

        int numVisited = 0;
        float entryScore = scoreFunction.similarityTo(ep);
        visited.set(ep);
        numVisited++;
        candidates.push(ep, entryScore);

        // Once topK results are held, the worst retained score becomes the bar. No candidate
        // scoring below it can enter the results, so it is neither expanded nor enqueued.
        float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
        // NOTE(review): resultsQueue.incomplete() is never set within this method; it is presumably
        // a hook of the NodeQueue implementation.
        while (candidates.size() > 0 && !resultsQueue.incomplete()) {
            float topCandidateScore = candidates.topScore();
            if (topCandidateScore < minAcceptedSimilarity) {
                break;
            }
            if (scoreTracker.shouldStop(numVisited)) {
                break;
            }

            int topCandidateNode = candidates.pop();
            if (acceptOrds.get(topCandidateNode) && topCandidateScore >= threshold) {
                // push() presumably reports whether the bounded heap actually kept the node.
                if (resultsQueue.push(topCandidateNode, topCandidateScore) && resultsQueue.size() >= topK) {
                    minAcceptedSimilarity = resultsQueue.topScore();
                }
            }

            NodesIterator it = view.getNeighborsIterator(topCandidateNode);
            while (it.hasNext()) {
                int friendOrd = it.nextInt();
                if (visited.getAndSet(friendOrd)) {
                    continue; // already scored in this search
                }
                numVisited++;
                float friendSimilarity = scoreFunction.similarityTo(friendOrd);
                scoreTracker.track(friendSimilarity);
                if (friendSimilarity >= minAcceptedSimilarity) {
                    candidates.push(friendOrd, friendSimilarity);
                }
            }
        }
        assert resultsQueue.size() <= topK;

        SearchResult.NodeScore[] nodes = extractScores(scoreFunction, reRanker, resultsQueue);
        return new SearchResult(nodes, visited, numVisited);
    }

    /**
     * Drains {@code resultsQueue} into an array ordered from best to worst score.
     * <p>
     * With an exact score function, the scores gathered during the search are kept as they are.
     * Otherwise every result is re-scored with {@code reRanker} and the array is sorted
     * descending by those new scores.
     */
    private static SearchResult.NodeScore[] extractScores(NodeSimilarity.ScoreFunction sf, NodeSimilarity.ReRanker reRanker, NodeQueue resultsQueue) {
        SearchResult.NodeScore[] nodes;
        if (sf.isExact()) {
            nodes = new SearchResult.NodeScore[resultsQueue.size()];
            // Popping the min-heap yields the worst remaining result first, so fill back-to-front.
            for (int i = nodes.length - 1; i >= 0; --i) {
                float nScore = resultsQueue.topScore();
                int n = resultsQueue.pop();
                nodes[i] = new SearchResult.NodeScore(n, nScore);
            }
        } else {
            nodes = resultsQueue.nodesCopy(reRanker::similarityTo);
            // The explicit parameter type is required. With .reversed() chained on, no target type
            // flows into the lambda, and it would otherwise be inferred as Object.
            Arrays.sort(nodes, 0, resultsQueue.size(),
                    Comparator.comparingDouble((SearchResult.NodeScore nodeScore) -> nodeScore.score).reversed());
        }
        return nodes;
    }

    /**
     * Resets per-search scratch state: empties the candidate queue and clears {@code visited}.
     *
     * @param capacity number of ordinals the upcoming search may touch
     * @throws IllegalArgumentException if {@code visited} is smaller than {@code capacity} and
     *                                  cannot grow (i.e. it is not a {@link GrowableBitSet})
     */
    private void prepareScratchState(int capacity) {
        candidates.clear();
        // NOTE(review): Builder sizes visited with view.getIdUpperBound(), but callers pass
        // view.size() here. Confirm which bound this check is meant to enforce.
        if (visited.length() < capacity && !(visited instanceof GrowableBitSet)) {
            throw new IllegalArgumentException(String.format("Unexpected visited type: %s. Encountering this means that the graph changed while being searched, and the Searcher was not built withConcurrentUpdates()", visited.getClass().getName()));
        }
        visited.clear();
    }

    /**
     * Builds a {@link GraphSearcher} for a given graph view.
     * <p>
     * The visited set is sized to {@code view.getIdUpperBound()} at {@link #build()} time. It is a
     * {@link GrowableBitSet} when {@link #withConcurrentUpdates()} was called, and a
     * {@link SparseFixedBitSet} otherwise.
     */
    public static class Builder<T> {
        private final GraphIndex.View<T> view;
        private boolean concurrent;

        public Builder(GraphIndex.View<T> view) {
            this.view = view;
        }

        /** Uses a growable visited set so the graph may gain nodes while it is being searched. */
        public Builder<T> withConcurrentUpdates() {
            this.concurrent = true;
            return this;
        }

        public GraphSearcher<T> build() {
            int size = view.getIdUpperBound();
            BitSet bits = concurrent ? new GrowableBitSet(size) : new SparseFixedBitSet(size);
            return new GraphSearcher<>(view, bits);
        }
    }
}

