/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.disk.FusedADC;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.util.Arrays;

/**
 * Approximate score function that scores all neighbors of a graph node in a single call, using the
 * product-quantization codes returned by {@code FusedADC.PackedNeighbors#getPackedNeighbors(int)}.
 * <p>
 * The packed codes are read subspace-major: the code of neighbor {@code j} in subspace {@code m} is
 * the unsigned byte at index {@code m * nodeCount + j}. It selects one of {@code pq.getClusterCount()}
 * partial sums precomputed for that subspace, and a neighbor's raw distance is the sum of its
 * selected partials.
 * <p>
 * Scoring starts in float mode. Once at least {@code invocationThreshold} neighbor scores have been
 * produced, the partial sums are quantized using the spread between {@code bestDistance} (the sum of
 * the per-subspace best partials) and the worst raw distance seen so far. All later calls then use
 * the quantized bulk kernels in {@link VectorUtil}.
 * <p>
 * NOTE(review): instances carry mutable per-query state (invocation count, worst distance, quantized
 * partials) without synchronization, so they are presumably confined to one search thread. Confirm
 * this with the callers.
 */
public abstract class FusedADCPQDecoder implements ScoreFunction.ApproximateScoreFunction {
    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** 65535 is the largest unsigned 16-bit value; the quantization step is the observed range divided by it. */
    private static final float QUANTIZATION_LEVELS = 65535.0f;

    protected final ProductQuantization pq;
    protected final VectorFloat<?> query;
    /** Exact scorer that {@link #similarityTo(int)} delegates to. */
    protected final ScoreFunction.ExactScoreFunction esf;
    /** Quantized copy of {@link #partialSums}; filled once the invocation threshold is reached. */
    protected final ByteSequence<?> partialQuantizedSums;
    protected final FusedADC.PackedNeighbors neighbors;
    /** Output buffer reused by every {@code edgeLoadingSimilarityTo} call; its length is the neighbor slot count. */
    protected final VectorFloat<?> results;
    /** Query-to-centroid partial sums, indexed by {@code subspace * clusterCount + code}. */
    protected final VectorFloat<?> partialSums;
    /** Best partial sum per subspace, as produced by {@link VectorUtil#calculatePartialSums}. */
    protected final VectorFloat<?> partialBestDistances;
    /** Number of float-mode neighbor scores to compute before switching to the quantized path. */
    protected final int invocationThreshold;
    protected int invocations = 0;
    protected float bestDistance;
    /** Worst raw distance observed in float mode; the direction of "worst" is set by {@link #updateWorstDistance}. */
    protected float worstDistance;
    /** Quantization step for {@link #partialQuantizedSums}; valid once {@link #supportsQuantizedSimilarity} is set. */
    protected float delta;
    protected boolean supportsQuantizedSimilarity = false;
    protected final VectorSimilarityFunction vsf;

    /**
     * Captures the shared PQ buffers and, for every function except COSINE, precomputes the partial sums
     * between the (globally centered) query and each codebook centroid.
     * <p>
     * For COSINE the partial sums and {@link #bestDistance} are left for the subclass constructor
     * to fill in.
     *
     * @param pq                  the product quantization whose codebooks produced the packed codes
     * @param query               the query vector
     * @param invocationThreshold float-mode scores to compute before quantizing the partials
     * @param neighbors           source of packed neighbor codes
     * @param results             reusable output buffer, one slot per neighbor
     * @param esf                 exact scorer used by {@link #similarityTo(int)}
     * @param vsf                 similarity function these partial sums are computed for
     */
    protected FusedADCPQDecoder(ProductQuantization pq, VectorFloat<?> query, int invocationThreshold, FusedADC.PackedNeighbors neighbors, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf, VectorSimilarityFunction vsf) {
        this.pq = pq;
        this.query = query;
        this.esf = esf;
        this.invocationThreshold = invocationThreshold;
        this.neighbors = neighbors;
        this.results = results;
        this.vsf = vsf;
        this.partialSums = pq.reusablePartialSums();
        this.partialBestDistances = pq.reusablePartialBestDistances();
        if (vsf != VectorSimilarityFunction.COSINE) {
            VectorFloat<?> center = pq.globalCentroid;
            // The codebooks are relative to the global centroid (when present), so center the query too.
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
            int clusterCount = pq.getClusterCount();
            for (int m = 0; m < pq.getSubspaceCount(); ++m) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, clusterCount, centeredQuery, offset, vsf, this.partialSums, this.partialBestDistances);
            }
            this.bestDistance = VectorUtil.sum(this.partialBestDistances);
        }
        this.partialQuantizedSums = pq.reusablePartialQuantizedSums();
    }

    /**
     * Scores every neighbor slot of {@code origin} against the query.
     *
     * @param origin node whose packed neighbor codes are scored
     * @return the shared {@link #results} vector, overwritten with one score per neighbor slot. It is
     *         reused across calls, so callers must consume it before calling again.
     */
    @Override
    public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
        ByteSequence<?> permutedNodes = this.neighbors.getPackedNeighbors(origin);
        this.results.zero();
        if (this.supportsQuantizedSimilarity) {
            VectorUtil.bulkShuffleQuantizedSimilarity(permutedNodes, this.pq.compressedVectorSize(), this.partialQuantizedSums, this.delta, this.bestDistance, this.results, this.vsf);
            return this.results;
        }
        int nodeCount = this.results.length();
        int subspaceCount = this.pq.getSubspaceCount();
        int clusterCount = this.pq.getClusterCount();
        // Accumulate each neighbor's raw distance as the sum of the partials its codes select.
        for (int m = 0; m < subspaceCount; ++m) {
            int codeBase = m * nodeCount;
            int partialBase = m * clusterCount;
            for (int j = 0; j < nodeCount; ++j) {
                int code = Byte.toUnsignedInt(permutedNodes.get(codeBase + j));
                this.results.set(j, this.results.get(j) + this.partialSums.get(partialBase + code));
            }
        }
        // Track the observed range (needed for quantization), then convert distances to scores in place.
        for (int j = 0; j < nodeCount; ++j) {
            float distance = this.results.get(j);
            ++this.invocations;
            this.updateWorstDistance(distance);
            this.results.set(j, this.distanceToScore(distance));
        }
        this.enableQuantizedSimilarityIfReady();
        return this.results;
    }

    /**
     * Switches to the quantized path once enough float-mode scores have been computed.
     * <p>
     * When {@link #invocations} has reached {@link #invocationThreshold}, this quantizes
     * {@link #partialSums} into {@link #partialQuantizedSums} and sets
     * {@link #supportsQuantizedSimilarity}. The step used is {@code (worstDistance - bestDistance) / 65535}.
     */
    protected final void enableQuantizedSimilarityIfReady() {
        if (this.invocations < this.invocationThreshold) {
            return;
        }
        this.delta = (this.worstDistance - this.bestDistance) / QUANTIZATION_LEVELS;
        VectorUtil.quantizePartials(this.delta, this.partialSums, this.partialBestDistances, this.partialQuantizedSums);
        this.supportsQuantizedSimilarity = true;
    }

    /** Always {@code true}: this score function supports bulk neighbor scoring. */
    @Override
    public boolean supportsEdgeLoadingSimilarity() {
        return true;
    }

    /** Delegates to the exact score function supplied at construction. */
    @Override
    public float similarityTo(int node2) {
        return this.esf.similarityTo(node2);
    }

    /** Converts a raw accumulated distance (or, for cosine, a cosine value) into a score. */
    protected abstract float distanceToScore(float var1);

    /** Folds one raw distance into {@link #worstDistance}. */
    protected abstract void updateWorstDistance(float var1);

    /**
     * Creates the decoder matching {@code similarityFunction}.
     *
     * @throws IllegalArgumentException if the similarity function has no fused-ADC decoder
     */
    public static FusedADCPQDecoder newDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, VectorSimilarityFunction similarityFunction, ScoreFunction.ExactScoreFunction esf) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new DotProductDecoder(neighbors, pq, query, results, esf);
            case EUCLIDEAN:
                return new EuclideanDecoder(neighbors, pq, query, results, esf);
            case COSINE:
                return new CosineDecoder(neighbors, pq, query, results, esf);
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
        }
    }

    /**
     * Dot-product decoder. Its score is {@code (dot + 1) / 2}, and the smallest dot product seen counts
     * as the worst distance.
     */
    static class DotProductDecoder extends FusedADCPQDecoder {
        public DotProductDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.DOT_PRODUCT);
            this.worstDistance = Float.MAX_VALUE;
        }

        @Override
        protected float distanceToScore(float distance) {
            return (distance + 1.0f) / 2.0f;
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.min(this.worstDistance, distance);
        }
    }

    /**
     * Euclidean decoder. Its score is {@code 1 / (1 + distance)}, and the largest distance seen counts
     * as the worst.
     */
    static class EuclideanDecoder extends FusedADCPQDecoder {
        public EuclideanDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.EUCLIDEAN);
            this.worstDistance = 0.0f;
        }

        @Override
        protected float distanceToScore(float distance) {
            return 1.0f / (1.0f + distance);
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.max(this.worstDistance, distance);
        }
    }

    /**
     * Cosine decoder.
     * <p>
     * For each neighbor it accumulates both the dot product with the query and the neighbor's squared
     * magnitude. It then scores {@code (1 + dot / sqrt(|x|^2 * |q|^2)) / 2}. The per-centroid squared
     * magnitudes depend only on {@code pq}, so they are computed once and cached on it.
     */
    static class CosineDecoder extends FusedADCPQDecoder {
        private final float queryMagnitudeSquared;
        /** Squared norm of every centroid, indexed by {@code subspace * clusterCount + code}. */
        private final VectorFloat<?> partialSquaredMagnitudes;
        private final ByteSequence<?> partialQuantizedSquaredMagnitudes;
        private final float[] resultSumAggregates;
        private final float[] resultMagnitudeAggregates;
        private final float minSquaredMagnitude;
        private final float squaredMagnitudeDelta;

        protected CosineDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.COSINE);
            this.worstDistance = Float.MAX_VALUE;
            // Per AtomicReference.updateAndGet, the function may be re-applied under contention. It
            // therefore only computes and publishes (idempotently) onto pq and never accumulates into
            // this decoder's fields. The cached scalars are read back from pq after the update.
            this.partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> {
                if (current != null) {
                    return current;
                }
                return computeAndPublishSquaredMagnitudes(pq);
            });
            this.squaredMagnitudeDelta = pq.squaredMagnitudeDelta;
            this.minSquaredMagnitude = pq.minSquaredMagnitude;
            this.partialQuantizedSquaredMagnitudes = pq.partialQuantizedSquaredMagnitudes().get();

            VectorFloat<?> center = pq.globalCentroid;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
            int clusterCount = pq.getClusterCount();
            float queryMagSum = 0.0f;
            // Cosine numerators use dot-product partials; the query's squared norm is accumulated per subspace.
            for (int m = 0; m < pq.getSubspaceCount(); ++m) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, clusterCount, centeredQuery, offset, VectorSimilarityFunction.DOT_PRODUCT, this.partialSums, this.partialBestDistances);
                queryMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, size);
            }
            this.queryMagnitudeSquared = queryMagSum;
            this.bestDistance = VectorUtil.sum(this.partialBestDistances);
            this.resultSumAggregates = new float[results.length()];
            this.resultMagnitudeAggregates = new float[results.length()];
        }

        /**
         * Computes the squared norm of every centroid of every codebook and publishes the derived state
         * on {@code pq}.
         * <p>
         * The published state is {@code squaredMagnitudeDelta}, {@code minSquaredMagnitude} and the
         * quantized magnitudes (assumed to be 2 bytes per entry, from the buffer size; TODO confirm).
         * The result depends only on {@code pq}, so running this more than once is harmless.
         *
         * @return the float squared magnitudes, indexed by {@code subspace * clusterCount + code}
         */
        private static VectorFloat<?> computeAndPublishSquaredMagnitudes(ProductQuantization pq) {
            int subspaceCount = pq.getSubspaceCount();
            int clusterCount = pq.getClusterCount();
            VectorFloat<?> partialMinMagnitudes = vts.createFloatVector(subspaceCount);
            VectorFloat<?> squaredMagnitudes = vts.createFloatVector(subspaceCount * clusterCount);
            float minMagnitude = 0.0f;
            float maxMagnitude = 0.0f;
            for (int m = 0; m < subspaceCount; ++m) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                VectorFloat<?> codebook = pq.codebooks[m];
                float minPartialMagnitude = Float.POSITIVE_INFINITY;
                float maxPartialMagnitude = 0.0f;
                for (int j = 0; j < clusterCount; ++j) {
                    float partialMagnitude = VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size);
                    minPartialMagnitude = Math.min(minPartialMagnitude, partialMagnitude);
                    maxPartialMagnitude = Math.max(maxPartialMagnitude, partialMagnitude);
                    squaredMagnitudes.set(m * clusterCount + j, partialMagnitude);
                }
                partialMinMagnitudes.set(m, minPartialMagnitude);
                maxMagnitude += maxPartialMagnitude;
                minMagnitude += minPartialMagnitude;
            }
            float magnitudeDelta = (maxMagnitude - minMagnitude) / QUANTIZATION_LEVELS;
            ByteSequence<?> quantizedMagnitudes = vts.createByteSequence(subspaceCount * clusterCount * 2);
            VectorUtil.quantizePartials(magnitudeDelta, squaredMagnitudes, partialMinMagnitudes, quantizedMagnitudes);
            pq.squaredMagnitudeDelta = magnitudeDelta;
            pq.minSquaredMagnitude = minMagnitude;
            pq.partialQuantizedSquaredMagnitudes().set(quantizedMagnitudes);
            return squaredMagnitudes;
        }

        /**
         * Scores every neighbor slot of {@code origin} by cosine similarity.
         *
         * @param origin node whose packed neighbor codes are scored
         * @return the shared results vector, overwritten with one score per neighbor slot
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
            ByteSequence<?> permutedNodes = this.neighbors.getPackedNeighbors(origin);
            if (this.supportsQuantizedSimilarity) {
                this.results.zero();
                VectorUtil.bulkShuffleQuantizedSimilarityCosine(permutedNodes, this.pq.compressedVectorSize(), this.partialQuantizedSums, this.delta, this.bestDistance, this.partialQuantizedSquaredMagnitudes, this.squaredMagnitudeDelta, this.minSquaredMagnitude, this.queryMagnitudeSquared, this.results);
                return this.results;
            }
            int nodeCount = this.results.length();
            int subspaceCount = this.pq.getSubspaceCount();
            int clusterCount = this.pq.getClusterCount();
            Arrays.fill(this.resultSumAggregates, 0.0f);
            Arrays.fill(this.resultMagnitudeAggregates, 0.0f);
            for (int m = 0; m < subspaceCount; ++m) {
                int codeBase = m * nodeCount;
                int partialBase = m * clusterCount;
                for (int j = 0; j < nodeCount; ++j) {
                    // The same code selects both the dot-product partial and the centroid's squared norm.
                    int partialIndex = partialBase + Byte.toUnsignedInt(permutedNodes.get(codeBase + j));
                    this.resultSumAggregates[j] += this.partialSums.get(partialIndex);
                    this.resultMagnitudeAggregates[j] += this.partialSquaredMagnitudes.get(partialIndex);
                }
            }
            for (int j = 0; j < nodeCount; ++j) {
                float dot = this.resultSumAggregates[j];
                // The worst distance is tracked on the raw dot product, because the quantized partials are dot products.
                this.updateWorstDistance(dot);
                // NOTE(review): a zero query or neighbor magnitude makes this division non-finite.
                float cosine = dot / (float) Math.sqrt(this.resultMagnitudeAggregates[j] * this.queryMagnitudeSquared);
                ++this.invocations;
                this.results.set(j, this.distanceToScore(cosine));
            }
            this.enableQuantizedSimilarityIfReady();
            return this.results;
        }

        @Override
        protected float distanceToScore(float distance) {
            return (1.0f + distance) / 2.0f;
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.min(this.worstDistance, distance);
        }
    }
}

