/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.graph.disk.feature;

import io.github.jbellis.jvector.disk.IndexWriter;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.disk.CommonHeader;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.disk.feature.AbstractFeature;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.feature.FusedFeature;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.FusedPQDecoder;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.IntFunction;
import org.agrona.collections.Int2ObjectHashMap;

/**
 * Graph feature that stores, inline with each node's record, the product-quantization (PQ) codes of
 * that node's layer-0 neighbors, and builds approximate score functions that read those packed
 * codes back through {@link PackedNeighbors}.
 *
 * <p>The inline payload for a node is exactly {@code maxDegree} codes of
 * {@code pq.compressedVectorSize()} bytes each (see {@link #featureSize()}). Unused neighbor slots
 * are zero-filled so that every node record has the same size. Only 256-cluster PQ is accepted.
 */
public class FusedPQ
extends AbstractFeature
implements FusedFeature {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
    private final ProductQuantization pq;
    private final int maxDegree;
    // Per-thread scratch buffers. All use ExplicitThreadLocal so they are managed uniformly.
    private final ExplicitThreadLocal<VectorFloat<?>> reusableResults;
    private final ExplicitThreadLocal<ByteSequence<?>> reusableNeighborCodes;
    private final ExplicitThreadLocal<ByteSequence<?>> pqCodeScratch;

    /**
     * Creates a fused-PQ feature.
     *
     * @param maxDegree number of neighbor slots stored per node; also sizes the per-thread buffers
     * @param pq the product quantization whose codes are stored inline; must have 256 clusters
     * @throws IllegalArgumentException if {@code pq} does not have exactly 256 clusters
     */
    public FusedPQ(int maxDegree, ProductQuantization pq) {
        if (pq.getClusterCount() != 256) {
            throw new IllegalArgumentException("FusedPQ requires a 256-cluster PQ. This limitation may be removed in future releases");
        }
        this.maxDegree = maxDegree;
        this.pq = pq;
        this.reusableResults = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(maxDegree));
        this.reusableNeighborCodes = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree));
        this.pqCodeScratch = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize()));
    }

    @Override
    public FeatureId id() {
        return FeatureId.FUSED_PQ;
    }

    /** Returns the product quantization whose codes this feature stores. */
    public ProductQuantization getPQ() {
        return this.pq;
    }

    /** Header size in bytes: the serialized size of the PQ codebooks. */
    @Override
    public int headerSize() {
        return this.pq.compressorSize();
    }

    /** Inline size in bytes per node: {@code maxDegree} codes of {@code compressedVectorSize()} bytes. */
    @Override
    public int featureSize() {
        return this.pq.compressedVectorSize() * this.maxDegree;
    }

    /**
     * Reads a {@code FusedPQ} from an index header, using the layer-0 degree as the slot count.
     *
     * @throws UncheckedIOException if the PQ cannot be read
     */
    static FusedPQ load(CommonHeader header, RandomAccessReader reader) {
        try {
            return new FusedPQ(header.layerInfo.get(0).degree, ProductQuantization.load(reader));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Builds an approximate score function that reads neighbor PQ codes packed inline in {@code view}.
     *
     * <p>The returned function is handed this thread's reusable neighbor-code and result buffers.
     * NOTE(review): two such functions created on the same thread therefore share buffers; presumably
     * they must not be used interleaved. Confirm against FusedPQDecoder.
     */
    public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) {
        PackedNeighbors neighbors = new PackedNeighbors(view);
        Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures = view.getInlineSourceFeatures();
        return FusedPQDecoder.newDecoder(neighbors, this.pq, hierarchyCachedFeatures, queryVector,
                this.reusableNeighborCodes.get(), this.reusableResults.get(), vsf, esf);
    }

    /** Writes the PQ codebooks as this feature's header. */
    @Override
    public void writeHeader(IndexWriter out) throws IOException {
        // NOTE(review): 6 is the version argument passed to ProductQuantization.write. It was probably an
        // inlined named constant before decompilation; confirm which constant it should reference.
        this.pq.write(out, 6);
    }

    /**
     * Writes the PQ codes of the node's layer-0 neighbors, zero-padding the remaining slots so that
     * exactly {@link #featureSize()} bytes are written.
     *
     * @throws IllegalStateException if the node has more than {@code maxDegree} neighbors. Writing
     *     them would overflow the fixed-size inline slot and corrupt subsequent records.
     */
    @Override
    public void writeInline(IndexWriter out, Feature.State state_) throws IOException {
        State state = (State) state_;
        NodesIterator neighbors = state.view.getNeighborsIterator(0, state.nodeId);
        int count = 0;
        while (neighbors.hasNext()) {
            if (count >= this.maxDegree) {
                throw new IllegalStateException(String.format(
                        "Node %d has more than %d neighbors; FusedPQ inline data would overflow its fixed-size slot",
                        state.nodeId, this.maxDegree));
            }
            int node = neighbors.nextInt();
            ByteSequence<?> compressed = state.compressedVectorFunction.apply(node);
            // NOTE(review): the copy is kept from the original. Presumably writeByteSequence needs a
            // standalone (non-slice) sequence; confirm before removing.
            vectorTypeSupport.writeByteSequence(out, compressed.copy());
            ++count;
        }
        ByteSequence<?> padding = this.pqCodeScratch.get();
        padding.zero();
        for (; count < this.maxDegree; ++count) {
            vectorTypeSupport.writeByteSequence(out, padding);
        }
    }

    /**
     * Writes the PQ code of the node itself, as read back by {@link #loadSourceFeature}.
     *
     * @throws IllegalArgumentException if the node's code is not {@code pq.compressedVectorSize()}
     *     bytes long. A shorter code would otherwise leave stale scratch bytes in the output.
     */
    @Override
    public void writeSourceFeature(IndexWriter out, Feature.State state_) throws IOException {
        State state = (State) state_;
        ByteSequence<?> compressed = state.compressedVectorFunction.apply(state.nodeId);
        int codeSize = this.pq.compressedVectorSize();
        if (compressed.length() != codeSize) {
            throw new IllegalArgumentException(String.format(
                    "PQ code for node %d has length %d, expected %d", state.nodeId, compressed.length(), codeSize));
        }
        // Copy into the per-thread scratch buffer so a standalone sequence is written.
        ByteSequence<?> temp = this.pqCodeScratch.get();
        for (int i = 0; i < codeSize; ++i) {
            temp.set(i, compressed.get(i));
        }
        vectorTypeSupport.writeByteSequence(out, temp);
    }

    /**
     * Reads one node's PQ code as written by {@link #writeSourceFeature}.
     *
     * <p>The read length uses the same {@code compressedVectorSize()} that sizes the writer's buffer,
     * so reader and writer stay symmetric.
     */
    @Override
    public FusedFeature.InlineSource loadSourceFeature(RandomAccessReader in) throws IOException {
        int length = this.pq.compressedVectorSize();
        ByteSequence<?> code = vectorTypeSupport.createByteSequence(length);
        vectorTypeSupport.readByteSequence(in, code);
        return new FusedPQInlineSource(code);
    }

    /** Reads the packed neighbor PQ codes stored inline for a node of an on-disk graph view. */
    public class PackedNeighbors {
        private final OnDiskGraphIndex.View view;

        public PackedNeighbors(OnDiskGraphIndex.View view) {
            this.view = view;
        }

        /**
         * Reads {@code node}'s packed neighbor codes into {@code neighborCodes}.
         *
         * @throws UncheckedIOException if the underlying read fails
         */
        public void readInto(int node, ByteSequence<?> neighborCodes) {
            try {
                this.view.getPackedNeighbors(node, FeatureId.FUSED_PQ, reader -> {
                    try {
                        vectorTypeSupport.readByteSequence((RandomAccessReader) reader, neighborCodes);
                    } catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                });
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /** Returns the number of neighbor slots per node. */
        public int maxDegree() {
            return FusedPQ.this.maxDegree;
        }
    }

    /** Per-node write state: the graph view, a node-to-PQ-code lookup, and the node being written. */
    public static class State
    implements Feature.State {
        public final ImmutableGraphIndex.View view;
        public final IntFunction<ByteSequence<?>> compressedVectorFunction;
        public final int nodeId;

        /** Uses {@code pqVectors::get} as the node-to-code lookup. */
        public State(ImmutableGraphIndex.View view, PQVectors pqVectors, int nodeId) {
            this(view, pqVectors::get, nodeId);
        }

        public State(ImmutableGraphIndex.View view, IntFunction<ByteSequence<?>> compressedVectorFunction, int nodeId) {
            this.view = view;
            this.compressedVectorFunction = compressedVectorFunction;
            this.nodeId = nodeId;
        }
    }

    /** A node's own PQ code, as loaded by {@link FusedPQ#loadSourceFeature}. */
    public static class FusedPQInlineSource
    implements FusedFeature.InlineSource {
        private final ByteSequence<?> code;

        public FusedPQInlineSource(ByteSequence<?> code) {
            this.code = code;
        }

        /** Counts only the code's payload bytes, not object overhead. */
        @Override
        public long ramBytesUsed() {
            return this.code.length();
        }

        public ByteSequence<?> getCode() {
            return this.code;
        }
    }
}

