/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.graph.disk.FusedADC;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

/**
 * Approximate PQ score function that scores an entire packed neighbor list of a node in one call
 * ({@link #edgeLoadingSimilarityTo(int)}), delegating exact scoring to a supplied
 * {@link ScoreFunction.ExactScoreFunction}.
 *
 * <p>Instances are obtained from {@link #newDecoder}; only {@code DOT_PRODUCT} and
 * {@code EUCLIDEAN} are supported.
 */
public abstract class QuickADCPQDecoder implements ScoreFunction.ApproximateScoreFunction {
    protected final ProductQuantization pq;
    protected final VectorFloat<?> query;
    protected final ScoreFunction.ExactScoreFunction esf;

    /**
     * @param pq    the product quantization whose codebooks are used for approximate scoring
     * @param query the query vector
     * @param esf   exact score function that {@code similarityTo} delegates to
     */
    protected QuickADCPQDecoder(ProductQuantization pq, VectorFloat<?> query, ScoreFunction.ExactScoreFunction esf) {
        this.pq = pq;
        this.query = query;
        this.esf = esf;
    }

    /**
     * Creates a decoder for the given similarity function.
     *
     * @param neighbors          packed neighbor codes; its {@code maxDegree()} is used as the
     *                           number of scored nodes after which the quantized path is considered
     * @param pq                 product quantization providing codebooks and reusable buffers
     * @param query              the query vector
     * @param results            output vector, overwritten and returned by each edge-loading call
     * @param similarityFunction must be {@code DOT_PRODUCT} or {@code EUCLIDEAN}
     * @param esf                exact score function used by {@code similarityTo}
     * @return a decoder for {@code similarityFunction}
     * @throws IllegalArgumentException if {@code similarityFunction} is not supported
     */
    public static QuickADCPQDecoder newDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, VectorSimilarityFunction similarityFunction, ScoreFunction.ExactScoreFunction esf) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new DotProductDecoder(neighbors, pq, query, results, esf);
            case EUCLIDEAN:
                return new EuclideanDecoder(neighbors, pq, query, results, esf);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Decoder for {@link VectorSimilarityFunction#DOT_PRODUCT}: larger raw sums are better, so the
     * worst observed value is the minimum.
     */
    static class DotProductDecoder extends CachingDecoder {
        public DotProductDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(neighbors, results, pq, query, neighbors.maxDegree(), VectorSimilarityFunction.DOT_PRODUCT, esf);
            // +Infinity is the identity for Math.min, so the first observed value always wins.
            this.worstDistance = Float.POSITIVE_INFINITY;
        }

        /** Maps a raw dot product to a score via {@code (d + 1) / 2}. */
        @Override
        protected float distanceToScore(float distance) {
            return (distance + 1.0f) / 2.0f;
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.min(this.worstDistance, distance);
        }
    }

    /**
     * Decoder for {@link VectorSimilarityFunction#EUCLIDEAN}: smaller raw sums are better, so the
     * worst observed value is the maximum.
     */
    static class EuclideanDecoder extends CachingDecoder {
        public EuclideanDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(neighbors, results, pq, query, neighbors.maxDegree(), VectorSimilarityFunction.EUCLIDEAN, esf);
            // -Infinity is the identity for Math.max. (Float.MIN_VALUE is the smallest *positive*
            // float, not a lower bound, so it is not a correct starting value here.)
            this.worstDistance = Float.NEGATIVE_INFINITY;
        }

        /** Maps a raw distance to a score via {@code 1 / (1 + d)}. */
        @Override
        protected float distanceToScore(float distance) {
            return 1.0f / (1.0f + distance);
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.max(this.worstDistance, distance);
        }
    }

    /**
     * Base decoder that scores neighbors from per-subspace partial sums precomputed for the query.
     * Once {@code invocationThreshold} nodes have been scored, it quantizes the partial sums and
     * switches to {@link VectorUtil#bulkShuffleQuantizedSimilarity} for subsequent calls.
     */
    protected static abstract class CachingDecoder extends QuickADCPQDecoder {
        /** Number of steps the observed range [bestDistance, worstDistance] is divided into. */
        private static final float QUANTIZATION_STEPS = 65535.0f;

        protected final FusedADC.PackedNeighbors neighbors;
        protected final VectorFloat<?> results;
        protected final VectorFloat<?> partialSums;
        protected final ByteSequence<?> partialQuantizedSums;
        protected final VectorFloat<?> partialBestDistances;
        private final VectorSimilarityFunction vsf;
        /** Sum over subspaces of the per-subspace best partial distance. */
        protected final float bestDistance;
        /** Number of scored nodes after which the quantized path is considered. */
        protected final int invocationThreshold;
        /** Worst raw distance observed so far; subclasses define "worst" and its initial value. */
        protected float worstDistance;
        /** Total number of nodes scored on the float path. */
        protected int invocations;
        /** True once partial sums have been quantized and the quantized path is in use. */
        protected boolean supportsQuantizedSimilarity;
        /** Quantization step size; only meaningful once {@code supportsQuantizedSimilarity} is set. */
        protected float delta;

        protected CachingDecoder(FusedADC.PackedNeighbors neighbors, VectorFloat<?> results, ProductQuantization pq, VectorFloat<?> query, int invocationThreshold, VectorSimilarityFunction vsf, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, esf);
            this.neighbors = neighbors;
            this.results = results;
            this.vsf = vsf;
            this.invocationThreshold = invocationThreshold;
            this.partialSums = pq.reusablePartialSums();
            this.partialBestDistances = pq.reusablePartialBestDistances();

            // Partial sums are computed against the query minus the global centroid, when one exists.
            VectorFloat<?> center = pq.globalCentroid;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
            int clusterCount = pq.getClusterCount();
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, clusterCount, centeredQuery, offset, vsf, this.partialSums, this.partialBestDistances);
            }
            this.bestDistance = VectorUtil.sum(this.partialBestDistances);

            this.partialQuantizedSums = pq.reusablePartialQuantizedSums();
            this.delta = 0.0f;
            this.worstDistance = 0.0f;
            this.invocations = 0;
            this.supportsQuantizedSimilarity = false;
        }

        /**
         * Scores every neighbor of {@code origin}, writing one score per node into {@code results}.
         *
         * @param origin node whose packed neighbor codes are scored
         * @return the shared {@code results} vector (overwritten on every call)
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
            ByteSequence<?> permutedNodes = this.neighbors.getPackedNeighbors(origin);
            this.results.zero();
            if (this.supportsQuantizedSimilarity) {
                VectorUtil.bulkShuffleQuantizedSimilarity(permutedNodes, this.pq.compressedVectorSize(), this.partialQuantizedSums, this.delta, this.bestDistance, this.results, this.vsf);
                return this.results;
            }

            // Codes are laid out subspace-major: nodeCount one-byte codes per subspace.
            int nodeCount = this.results.length();
            int clusterCount = this.pq.getClusterCount();
            for (int m = 0; m < this.pq.getSubspaceCount(); m++) {
                int codeBase = m * nodeCount;
                int sumBase = m * clusterCount;
                for (int j = 0; j < nodeCount; j++) {
                    int code = Byte.toUnsignedInt(permutedNodes.get(codeBase + j));
                    this.results.set(j, this.results.get(j) + this.partialSums.get(sumBase + code));
                }
            }

            // Convert raw distances to scores, tracking the worst raw distance seen so far.
            for (int j = 0; j < nodeCount; j++) {
                float distance = this.results.get(j);
                ++this.invocations;
                this.updateWorstDistance(distance);
                this.results.set(j, this.distanceToScore(distance));
            }

            if (this.invocations >= this.invocationThreshold) {
                this.maybeEnableQuantizedSimilarity();
            }
            return this.results;
        }

        /**
         * Quantizes the partial sums and switches to the quantized path, but only when the
         * observed range gives a usable step. A zero step (every observed distance equals
         * {@code bestDistance}) or a non-finite step (worstDistance never updated) cannot serve
         * as the quantization scale passed to quantizePartialSums and
         * bulkShuffleQuantizedSimilarity. In that case the exact float path is kept and the check
         * is retried on the next call.
         */
        private void maybeEnableQuantizedSimilarity() {
            float step = (this.worstDistance - this.bestDistance) / QUANTIZATION_STEPS;
            if (step == 0.0f || !Float.isFinite(step)) {
                return;
            }
            this.delta = step;
            VectorUtil.quantizePartialSums(this.delta, this.partialSums, this.partialBestDistances, this.partialQuantizedSums);
            this.supportsQuantizedSimilarity = true;
        }

        @Override
        public boolean supportsEdgeLoadingSimilarity() {
            return true;
        }

        /** Exact similarity, delegated to the exact score function supplied at construction. */
        @Override
        public float similarityTo(int node2) {
            return this.esf.similarityTo(node2);
        }

        /** Converts a raw accumulated distance into a similarity score. */
        protected abstract float distanceToScore(float distance);

        /** Folds {@code distance} into {@link #worstDistance}. */
        protected abstract void updateWorstDistance(float distance);
    }
}

