/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.NodesUnsorted;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.ScoreTracker;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import io.github.jbellis.jvector.util.GrowableLongHeap;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import org.agrona.collections.Int2ObjectHashMap;
import org.agrona.collections.IntHashSet;

/**
 * Searches an {@link ImmutableGraphIndex} view for the nodes most similar to a query.
 *
 * <p>A search starts at the view's entry node and descends each upper layer, keeping a single best
 * node per layer. It then runs a best-first search on layer 0 that keeps up to {@code rerankK}
 * approximate results. If the score provider supplies a reranker, those approximate results are
 * rescored before the final top-K is returned.
 *
 * <p>Per-search state (candidate queues, visited set, counters) lives in mutable fields with no
 * synchronization. NOTE(review): presumably each thread should use its own searcher; confirm.
 * {@link #close()} closes the underlying view.
 */
public class GraphSearcher
implements Closeable {
    // View of the graph being searched; replaceable via setView().
    private ImmutableGraphIndex.View view;
    // Frontier of nodes still to expand; max-heap, so the best-scoring candidate is popped first.
    private final NodeQueue candidates;
    // Best approximate results so far; min-heap bounded to rerankK, so its top is the worst kept result.
    final NodeQueue approximateResults;
    // Reranked results; min-heap bounded to topK.
    private final NodeQueue rerankedResults;
    // Nodes already scored during the current search.
    private final IntHashSet visited;
    // Nodes pushed out of the result queues; re-seeded into the frontier by searchLayer0 / resume.
    private final NodesUnsorted evictedResults;
    // Caller's filter intersected with the view's live nodes.
    private Bits acceptOrds;
    private SearchScoreProvider scoreProvider;
    // Non-null only when the current score provider has a reranker.
    private CachingReranker cachingReranker;
    private boolean pruneSearch;
    private final ScoreTracker.ScoreTrackerFactory scoreTrackerFactory;
    private int visitedCount;
    private int expandedCount;
    private int expandedCountBaseLayer;

    /**
     * Creates a searcher over a new view obtained from {@code graph.getView()}.
     *
     * @param graph the graph to search
     */
    public GraphSearcher(ImmutableGraphIndex graph) {
        this(graph.getView());
    }

    /**
     * Creates a searcher over an existing view. The searcher takes over closing the view in
     * {@link #close()}.
     *
     * @param view the graph view to search
     */
    protected GraphSearcher(ImmutableGraphIndex.View view) {
        this.view = view;
        this.candidates = new NodeQueue(new GrowableLongHeap(100), NodeQueue.Order.MAX_HEAP);
        this.evictedResults = new NodesUnsorted(100);
        // Bounded heaps start at 100; searchOneLayer / searchLayer0 resize them per query.
        this.approximateResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP);
        this.rerankedResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP);
        this.visited = new IntHashSet();
        this.pruneSearch = true;
        this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory();
    }

    /** @return the number of newly visited (scored) neighbors in the current search */
    protected int getVisitedCount() {
        return this.visitedCount;
    }

    /** @return the number of nodes whose neighbor lists were expanded, across all layers */
    protected int getExpandedCount() {
        return this.expandedCount;
    }

    /** @return the number of nodes whose neighbor lists were expanded on layer 0 */
    protected int getExpandedCountBaseLayer() {
        return this.expandedCountBaseLayer;
    }

    /** Installs the score provider and builds a fresh caching reranker if the provider has one. */
    private void initializeScoreProvider(SearchScoreProvider scoreProvider) {
        this.scoreProvider = scoreProvider;
        this.cachingReranker = scoreProvider.reranker() == null ? null : new CachingReranker(scoreProvider);
    }

    /** @return the graph view this searcher currently reads from */
    public ImmutableGraphIndex.View getView() {
        return this.view;
    }

    /**
     * Enables or disables score-tracker pruning for subsequent searches.
     *
     * @param usage whether the score tracker may cut expansion short (passed to the tracker factory)
     */
    public void usePruning(boolean usage) {
        this.pruneSearch = usage;
    }

    /**
     * Convenience one-shot search using exact similarity scores against {@code vectors}.
     *
     * <p>Equivalent to the {@code rerankK} overload with {@code rerankK == topK}.
     *
     * @throws UncheckedIOException if closing the temporary searcher's view fails
     */
    public static SearchResult search(VectorFloat<?> queryVector, int topK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) {
        return search(queryVector, topK, topK, vectors, similarityFunction, graph, acceptOrds);
    }

    /**
     * Convenience one-shot search using exact similarity scores against {@code vectors}.
     *
     * <p>Opens a temporary searcher on {@code graph} and always closes it, even if the search throws.
     *
     * @param queryVector        the query
     * @param topK               number of results to return
     * @param rerankK            number of approximate candidates to keep; must be {@code >= topK}
     * @param vectors            vectors used for exact scoring
     * @param similarityFunction similarity used for exact scoring
     * @param graph              the graph to search
     * @param acceptOrds         filter of acceptable ordinals (use {@code Bits.ALL} for no filter)
     * @return the search result
     * @throws UncheckedIOException if closing the temporary searcher's view fails
     */
    public static SearchResult search(VectorFloat<?> queryVector, int topK, int rerankK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) {
        try (GraphSearcher searcher = new GraphSearcher(graph)) {
            DefaultSearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, vectors);
            return searcher.search(ssp, topK, rerankK, 0.0f, 0.0f, acceptOrds);
        } catch (IOException e) {
            // Only close() can throw IOException here; search errors propagate unchanged.
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Replaces the view used by subsequent searches. The previous view is not closed here.
     *
     * @param view the new view
     */
    public void setView(ImmutableGraphIndex.View view) {
        this.view = view;
    }

    /**
     * Runs a full search: descends the upper layers, searches layer 0, then reranks.
     *
     * @param scoreProvider supplies the approximate score function and an optional reranker
     * @param topK          number of results to return
     * @param rerankK       number of approximate results to keep for reranking; must be {@code >= topK}
     * @param threshold     minimum approximate score for a node to enter the results; a value
     *                      {@code > 0} also lets the score tracker stop the search early
     * @param rerankFloor   floor passed through to the reranking step
     * @param acceptOrds    filter of acceptable ordinals; must not be {@code null}
     * @return the search result, or an empty result if the graph has no entry node
     * @throws IllegalArgumentException if {@code acceptOrds} is null or {@code rerankK < topK}
     */
    @Experimental
    public SearchResult search(SearchScoreProvider scoreProvider, int topK, int rerankK, float threshold, float rerankFloor, Bits acceptOrds) {
        ImmutableGraphIndex.NodeAtLevel entry = this.view.entryNode();
        if (acceptOrds == null) {
            throw new IllegalArgumentException("Use MatchAllBits to indicate that all ordinals are accepted, instead of null");
        }
        if (rerankK < topK) {
            throw new IllegalArgumentException(String.format("rerankK %d must be >= topK %d", rerankK, topK));
        }
        if (entry == null) {
            return new SearchResult(new SearchResult.NodeScore[0], 0, 0, 0, 0, Float.POSITIVE_INFINITY);
        }
        this.internalSearch(scoreProvider, entry, topK, rerankK, threshold, acceptOrds);
        return this.reranking(topK, rerankK, rerankFloor);
    }

    /**
     * Resets state, walks the upper layers keeping one best node per layer (no filtering), then
     * searches layer 0 with the caller's filter.
     */
    protected void internalSearch(SearchScoreProvider scoreProvider, ImmutableGraphIndex.NodeAtLevel entry, int topK, int rerankK, float threshold, Bits acceptOrds) {
        this.initializeInternal(scoreProvider, entry, acceptOrds);
        for (int lvl = entry.level; lvl > 0; --lvl) {
            this.searchOneLayer(scoreProvider, 1, 0.0f, lvl, Bits.ALL);
            assert (this.approximateResults.size() == 1) : this.approximateResults.size();
            this.setEntryPointsFromPreviousLayer();
        }
        this.searchLayer0(topK, rerankK, threshold);
    }

    /** Searches with {@code rerankK == topK} and no rerank floor. */
    public SearchResult search(SearchScoreProvider scoreProvider, int topK, float threshold, Bits acceptOrds) {
        return this.search(scoreProvider, topK, topK, threshold, 0.0f, acceptOrds);
    }

    /** Searches with {@code rerankK == topK}, no threshold and no rerank floor. */
    public SearchResult search(SearchScoreProvider scoreProvider, int topK, Bits acceptOrds) {
        return this.search(scoreProvider, topK, 0.0f, acceptOrds);
    }

    /** Moves the previous layer's results (kept and evicted) into the frontier for the next layer. */
    void setEntryPointsFromPreviousLayer() {
        this.approximateResults.foreach(this.candidates::push);
        this.evictedResults.foreach(this.candidates::push);
        this.evictedResults.clear();
        this.approximateResults.clear();
    }

    /**
     * Clears all per-search state and seeds the frontier with the scored entry node.
     *
     * @param rawAcceptOrds the caller's filter; it is intersected with the view's live nodes
     */
    void initializeInternal(SearchScoreProvider scoreProvider, ImmutableGraphIndex.NodeAtLevel entry, Bits rawAcceptOrds) {
        this.initializeScoreProvider(scoreProvider);
        this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, this.view.liveNodes());
        this.approximateResults.clear();
        this.evictedResults.clear();
        this.candidates.clear();
        this.visited.clear();
        float score = scoreProvider.scoreFunction().similarityTo(entry.node);
        this.visited.add(entry.node);
        this.candidates.push(entry.node, score);
        this.visitedCount = 0;
        this.expandedCount = 0;
        this.expandedCountBaseLayer = 0;
    }

    /**
     * Decides whether the layer search should end.
     *
     * <p>It stops when the results are full and the best remaining candidate scores below the worst
     * kept result. When a positive threshold is set, it also stops once the score tracker says so.
     * Callers guarantee {@code localCandidates} is non-empty.
     */
    private boolean stopSearch(NodeQueue localCandidates, ScoreTracker scoreTracker, int rerankK, float threshold) {
        float topCandidateScore = localCandidates.topScore();
        if (this.approximateResults.size() >= rerankK && topCandidateScore < this.approximateResults.topScore()) {
            return true;
        }
        return threshold > 0.0f && scoreTracker.shouldStop();
    }

    /**
     * Runs a best-first search on one layer, filling {@code approximateResults} with up to
     * {@code rerankK} accepted nodes scoring at least {@code threshold}.
     *
     * <p>On any exception the approximate results are cleared before rethrowing.
     *
     * @param level               layer whose adjacency lists are followed
     * @param acceptOrdsThisLayer filter deciding which popped nodes may become results
     */
    void searchOneLayer(SearchScoreProvider scoreProvider, int rerankK, float threshold, int level, Bits acceptOrdsThisLayer) {
        try {
            assert (this.approximateResults.size() == 0);
            this.approximateResults.setMaxSize(rerankK);
            ScoreTracker scoreTracker = this.scoreTrackerFactory.getScoreTracker(this.pruneSearch, rerankK, threshold);
            while (this.candidates.size() > 0 && !this.stopSearch(this.candidates, scoreTracker, rerankK, threshold)) {
                float topCandidateScore = this.candidates.topScore();
                int topCandidateNode = this.candidates.pop();
                if (acceptOrdsThisLayer.get(topCandidateNode) && topCandidateScore >= threshold) {
                    this.addTopCandidate(topCandidateNode, topCandidateScore, rerankK);
                }
                // Once the tracker signals stop, skip expanding while the frontier still holds
                // enough candidates to fill the remaining result slots.
                if (scoreTracker.shouldStop() && this.candidates.size() >= rerankK - this.approximateResults.size()) {
                    continue;
                }
                if (level == 0) {
                    ++this.expandedCountBaseLayer;
                }
                ++this.expandedCount;
                ScoreFunction scoreFunction = scoreProvider.scoreFunction();
                boolean useEdgeLoading = scoreFunction.supportsEdgeLoadingSimilarity();
                VectorFloat<?> similarities = useEdgeLoading ? scoreFunction.edgeLoadingSimilarityTo(topCandidateNode) : null;
                NodesIterator it = this.view.getNeighborsIterator(level, topCandidateNode);
                // i is the neighbor's position in the adjacency list. It must advance for EVERY
                // neighbor, including already-visited ones, to stay aligned with `similarities`.
                // (Assumes edgeLoadingSimilarityTo scores are in neighbor-iteration order; confirm.)
                for (int i = 0; it.hasNext(); ++i) {
                    int friendOrd = it.nextInt();
                    if (!this.visited.add(friendOrd)) {
                        continue;
                    }
                    ++this.visitedCount;
                    float friendSimilarity = useEdgeLoading ? similarities.get(i) : scoreFunction.similarityTo(friendOrd);
                    scoreTracker.track(friendSimilarity);
                    this.candidates.push(friendOrd, friendSimilarity);
                }
            }
        }
        catch (Throwable t) {
            this.approximateResults.clear();
            throw t;
        }
    }

    /**
     * Prepares the reranked queue for {@code topK}, re-seeds the frontier with previously evicted
     * nodes, and searches layer 0 with the caller's filter.
     */
    private void searchLayer0(int topK, int rerankK, float threshold) {
        this.rerankedResults.clear();
        this.rerankedResults.setMaxSize(topK);
        this.evictedResults.foreach(this.candidates::push);
        this.evictedResults.clear();
        this.searchOneLayer(this.scoreProvider, rerankK, threshold, 0, this.acceptOrds);
    }

    /**
     * Turns the approximate results of the last layer-0 search into the final result.
     *
     * <p>Without a reranker, approximate results beyond {@code topK} are moved to the evicted set,
     * worst first, and the rest are returned with their approximate scores. With a reranker, the
     * approximate queue is rescored through the caching reranker into {@code rerankedResults}.
     *
     * @return the results ordered highest score first, plus search statistics
     */
    SearchResult reranking(int topK, int rerankK, float rerankFloor) {
        NodeQueue popFromQueue;
        float worstApproximateInTopK;
        int reranked;
        assert (this.approximateResults.size() <= rerankK);
        if (this.cachingReranker == null) {
            while (this.approximateResults.size() > topK) {
                float nScore = this.approximateResults.topScore();
                int n = this.approximateResults.pop();
                this.evictedResults.add(n, nScore);
            }
            reranked = 0;
            worstApproximateInTopK = Float.POSITIVE_INFINITY;
            popFromQueue = this.approximateResults;
        } else {
            int oldReranked = this.cachingReranker.getRerankCalls();
            worstApproximateInTopK = this.approximateResults.rerank(topK, this.cachingReranker, rerankFloor, this.rerankedResults, this.evictedResults);
            reranked = this.cachingReranker.getRerankCalls() - oldReranked;
            this.approximateResults.clear();
            popFromQueue = this.rerankedResults;
        }
        assert (popFromQueue.size() <= topK);
        // Min-heap pops worst first, so fill the array from the back to get best-first order.
        SearchResult.NodeScore[] nodes = new SearchResult.NodeScore[popFromQueue.size()];
        for (int i = nodes.length - 1; i >= 0; --i) {
            float nScore = popFromQueue.topScore();
            int n = popFromQueue.pop();
            nodes[i] = new SearchResult.NodeScore(n, nScore);
        }
        assert (popFromQueue.size() == 0);
        return new SearchResult(nodes, this.visitedCount, this.expandedCount, this.expandedCountBaseLayer, reranked, worstApproximateInTopK);
    }

    /** Continues the previous search on layer 0, keeping the visited set and frontier. */
    SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) {
        this.searchLayer0(topK, rerankK, threshold);
        return this.reranking(topK, rerankK, rerankFloor);
    }

    /**
     * Offers a popped node to the approximate results. When the queue is full, a better node
     * displaces the current worst, which is kept in the evicted set.
     */
    private void addTopCandidate(int topCandidateNode, float topCandidateScore, int rerankK) {
        if (this.approximateResults.size() < rerankK) {
            this.approximateResults.push(topCandidateNode, topCandidateScore);
        } else if (topCandidateScore > this.approximateResults.topScore()) {
            int evictedNode = this.approximateResults.topNode();
            float evictedScore = this.approximateResults.topScore();
            this.evictedResults.add(evictedNode, evictedScore);
            this.approximateResults.push(topCandidateNode, topCandidateScore);
        }
    }

    /**
     * Fetches more results by continuing the previous search. Statistics counters are reset, so the
     * returned result reports only the work done by this call.
     *
     * @param additionalK number of further results to return
     * @param rerankK     number of approximate results to keep for reranking
     */
    @Experimental
    public SearchResult resume(int additionalK, int rerankK) {
        this.visitedCount = 0;
        this.expandedCount = 0;
        this.expandedCountBaseLayer = 0;
        return this.resume(additionalK, rerankK, 0.0f, 0.0f);
    }

    /** Closes the underlying view. */
    @Override
    public void close() throws IOException {
        this.view.close();
    }

    /**
     * Exact score function that delegates to the provider's reranker and memoizes scores per node,
     * counting how many real (uncached) rerank calls were made.
     */
    private static class CachingReranker
    implements ScoreFunction.ExactScoreFunction {
        private final Int2ObjectHashMap<Float> cachedScores;
        private final SearchScoreProvider scoreProvider;
        private int rerankCalls;

        public CachingReranker(SearchScoreProvider scoreProvider) {
            this.scoreProvider = scoreProvider;
            this.cachedScores = new Int2ObjectHashMap<>();
            this.rerankCalls = 0;
        }

        @Override
        public float similarityTo(int node2) {
            // null values are never stored, so a null lookup means "not cached yet".
            Float cached = this.cachedScores.get(node2);
            if (cached != null) {
                return cached;
            }
            ++this.rerankCalls;
            float score = this.scoreProvider.reranker().similarityTo(node2);
            // Explicit boxing selects put(int, V) unambiguously over put(Integer, V).
            this.cachedScores.put(node2, Float.valueOf(score));
            return score;
        }

        /** @return the number of reranker invocations that missed the cache */
        public int getRerankCalls() {
            return this.rerankCalls;
        }
    }

    /**
     * Builds a {@link GraphSearcher} over an existing view.
     *
     * @deprecated Retained for backward compatibility; {@link #withConcurrentUpdates()} is a no-op.
     */
    @Deprecated
    public static class Builder {
        private final ImmutableGraphIndex.View view;

        public Builder(ImmutableGraphIndex.View view) {
            this.view = view;
        }

        /** No-op; returns this builder unchanged. */
        public Builder withConcurrentUpdates() {
            return this;
        }

        /** @return a new searcher over the builder's view */
        public GraphSearcher build() {
            return new GraphSearcher(this.view);
        }
    }
}

