/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai.disk.v1.vector;

import io.github.jbellis.jvector.disk.CachingGraphIndex;
import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.PrimitiveIterator;
import java.util.stream.IntStream;
import org.apache.cassandra.index.sai.disk.format.IndexComponent;
import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles;
import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.OnDiskOrdinalsMap;
import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph;
import org.apache.cassandra.index.sai.disk.v1.vector.RandomAccessReaderAdapter;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.tracing.Tracing;

/**
 * Read-side handle for one on-disk vector index segment.
 *
 * <p>It combines three persisted components:
 * <ul>
 *   <li>the graph stored in the {@code TERMS_DATA} component, wrapped in a {@link CachingGraphIndex};</li>
 *   <li>optional compressed vectors from the {@code COMPRESSED_VECTORS} component;</li>
 *   <li>the ordinal-to-segment-row-id map from the {@code POSTING_LISTS} component.</li>
 * </ul>
 *
 * <p>Instances own the graph file handle, the graph and the ordinals map, and release them in {@link #close()}.
 */
public class DiskAnn implements AutoCloseable {
    private final FileHandle graphHandle;
    private final OnDiskOrdinalsMap ordinalsMap;
    private final CachingGraphIndex graph;
    private final VectorSimilarityFunction similarityFunction;
    /** Compressed vectors for approximate scoring, or {@code null} when the segment was written without them. */
    private final CompressedVectors compressedVectors;

    /**
     * Opens the graph, compressed vectors and ordinals map of a segment.
     *
     * <p>If any step fails, the graph (if already built) and the graph file handle are closed before the
     * failure is rethrown. Secondary failures during that cleanup are attached as suppressed exceptions.
     *
     * @param componentMetadatas per-component offsets/lengths of this segment
     * @param indexFiles         source of the file handles for the index components
     * @param config             supplies the similarity function used for scoring
     * @throws IOException if reading any component fails
     */
    public DiskAnn(SegmentMetadata.ComponentMetadataMap componentMetadatas, PerColumnIndexFiles indexFiles, IndexWriterConfig config) throws IOException {
        this.similarityFunction = config.getSimilarityFunction();
        SegmentMetadata.ComponentMetadata termsMetadata = componentMetadatas.get(IndexComponent.TERMS_DATA);
        this.graphHandle = indexFiles.termsData();
        CachingGraphIndex openedGraph = null;
        try {
            openedGraph = new CachingGraphIndex(new OnDiskGraphIndex(RandomAccessReaderAdapter.createSupplier(this.graphHandle), termsMetadata.offset));
            this.graph = openedGraph;
            this.compressedVectors = loadCompressedVectors(componentMetadatas, indexFiles);
            SegmentMetadata.ComponentMetadata postingListsMetadata = componentMetadatas.get(IndexComponent.POSTING_LISTS);
            this.ordinalsMap = new OnDiskOrdinalsMap(indexFiles.postingLists(), postingListsMetadata.offset, postingListsMetadata.length);
        } catch (Throwable t) {
            // Release what was already opened; otherwise the graph and its file handle would leak.
            if (openedGraph != null) {
                try {
                    openedGraph.close();
                } catch (Throwable suppressed) {
                    t.addSuppressed(suppressed);
                }
            }
            try {
                this.graphHandle.close();
            } catch (Throwable suppressed) {
                t.addSuppressed(suppressed);
            }
            throw t;
        }
    }

    /**
     * Reads the compressed-vectors section of the segment.
     *
     * <p>A leading boolean at the segment offset records whether compressed vectors were written.
     * The file handle and reader are closed before returning.
     *
     * @return the loaded compressed vectors, or {@code null} if the segment has none
     */
    private static CompressedVectors loadCompressedVectors(SegmentMetadata.ComponentMetadataMap componentMetadatas, PerColumnIndexFiles indexFiles) throws IOException {
        long pqSegmentOffset = componentMetadatas.get(IndexComponent.COMPRESSED_VECTORS).offset;
        try (FileHandle pqFileHandle = indexFiles.compressedVectors();
             RandomAccessReaderAdapter reader = new RandomAccessReaderAdapter(pqFileHandle)) {
            reader.seek(pqSegmentOffset);
            boolean containsCompressedVectors = reader.readBoolean();
            if (!containsCompressedVectors) {
                return null;
            }
            return CompressedVectors.load((RandomAccessReader) reader, reader.getFilePointer());
        }
    }

    /**
     * Returns the memory estimate reported by the graph index.
     *
     * <p>NOTE(review): only the graph is counted. The compressed vectors and the ordinals map are not
     * included — confirm this is intended.
     */
    public long ramBytesUsed() {
        return this.graph.ramBytesUsed();
    }

    /** Returns {@code size()} of the underlying graph index (presumably its node count). */
    public int size() {
        return this.graph.size();
    }

    /**
     * Searches the graph for the vectors nearest to {@code queryVector}.
     *
     * <p>Scoring depends on whether the segment has compressed vectors:
     * <ul>
     *   <li>with compressed vectors, traversal uses their approximate score function and results are
     *       reranked with the exact similarity;</li>
     *   <li>without them, traversal scores exactly against the graph's stored vectors and no reranker is used.</li>
     * </ul>
     *
     * @param queryVector the query; validated by {@code OnHeapGraph.validateIndexable} first
     * @param topK        number of results requested from the graph search
     * @param limit       passed through to {@link VectorPostingList}
     * @param acceptBits  ordinals filter, combined with deleted-ordinal filtering by the ordinals map
     * @return the segment row ids of the matched graph nodes
     */
    public VectorPostingList search(float[] queryVector, int topK, int limit, Bits acceptBits) {
        OnHeapGraph.validateIndexable(queryVector, this.similarityFunction);
        GraphIndex.View view = this.graph.getView();
        GraphSearcher searcher = new GraphSearcher.Builder(view).build();
        // Declared as the common supertype: the branches produce exact vs approximate score functions.
        NeighborSimilarity.ScoreFunction scoreFunction;
        NeighborSimilarity.ReRanker reRanker;
        if (this.compressedVectors == null) {
            scoreFunction = (NeighborSimilarity.ExactScoreFunction) i -> this.similarityFunction.compare(queryVector, (float[]) view.getVector(i));
            reRanker = null;
        } else {
            scoreFunction = this.compressedVectors.approximateScoreFunctionFor(queryVector, this.similarityFunction);
            reRanker = (i, map) -> this.similarityFunction.compare(queryVector, (float[]) map.get(i));
        }
        SearchResult result = searcher.search(scoreFunction, reRanker, topK, this.ordinalsMap.ignoringDeleted(acceptBits));
        Tracing.trace("DiskANN search visited {} nodes to return {} results", result.getVisitedCount(), result.getNodes().length);
        return this.annRowIdsToPostings(result, limit);
    }

    /**
     * Maps the graph nodes in {@code results} to segment row ids and wraps them in a posting list.
     *
     * <p>NOTE(review): the row-ids view is closed as soon as the {@link VectorPostingList} constructor returns.
     * This assumes that constructor drains the iterator eagerly — confirm.
     */
    private VectorPostingList annRowIdsToPostings(SearchResult results, int limit) {
        try (RowIdIterator iterator = new RowIdIterator(results.getNodes())) {
            return new VectorPostingList(iterator, limit, results.getVisitedCount());
        }
    }

    /**
     * Closes the ordinals map, the graph and the graph file handle, in that order.
     *
     * <p>Each close is attempted even if an earlier one throws.
     */
    @Override
    public void close() throws IOException {
        try {
            this.ordinalsMap.close();
        } finally {
            try {
                this.graph.close();
            } finally {
                this.graphHandle.close();
            }
        }
    }

    /** Returns an ordinals view obtained from the ordinals map. */
    public OnDiskOrdinalsMap.OrdinalsView getOrdinalsView() {
        return this.ordinalsMap.getOrdinalsView();
    }

    /**
     * Flattens search results into segment row ids.
     *
     * <p>Iteration follows the order of the {@code NodeScore}s. For each node, every segment row id reported
     * by the row-ids view for that ordinal is emitted. Closing the iterator closes the row-ids view.
     */
    private class RowIdIterator implements PrimitiveIterator.OfInt, AutoCloseable {
        private final Iterator<SearchResult.NodeScore> it;
        private final OnDiskOrdinalsMap.RowIdsView rowIdsView;
        /** Row ids of the current node; starts empty so the first {@link #hasNext()} advances. */
        private PrimitiveIterator.OfInt segmentRowIdIterator;

        public RowIdIterator(SearchResult.NodeScore[] results) {
            this.rowIdsView = DiskAnn.this.ordinalsMap.getRowIdsView();
            this.segmentRowIdIterator = IntStream.empty().iterator();
            this.it = Arrays.stream(results).iterator();
        }

        /**
         * Advances past nodes that map to no row ids.
         *
         * @throws UncheckedIOException if looking up a node's row ids fails
         */
        @Override
        public boolean hasNext() {
            while (!this.segmentRowIdIterator.hasNext() && this.it.hasNext()) {
                int ordinal = this.it.next().node;
                try {
                    this.segmentRowIdIterator = Arrays.stream(this.rowIdsView.getSegmentRowIdsMatching(ordinal)).iterator();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }
            return this.segmentRowIdIterator.hasNext();
        }

        @Override
        public int nextInt() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            return this.segmentRowIdIterator.nextInt();
        }

        @Override
        public void close() {
            this.rowIdsView.close();
        }
    }
}

