/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;

/**
 * Base class for {@link NeighborSimilarity.ApproximateScoreFunction} implementations that score a
 * query against product-quantized vectors held in a {@link CompressedVectors} instance.
 *
 * <p>Each subclass builds, once per query, a lookup table with one entry per (subspace, centroid)
 * pair, so that scoring an encoded vector reduces to combining the table entries selected by its
 * codes.
 */
abstract class CompressedDecoder
implements NeighborSimilarity.ApproximateScoreFunction {
    /** Number of centroids per subspace codebook; each code is read as one unsigned byte. */
    private static final int CLUSTERS = 256;

    protected final CompressedVectors cv;

    protected CompressedDecoder(CompressedVectors cv) {
        this.cv = cv;
    }

    /**
     * Returns the query with the quantizer's global center subtracted (via {@link VectorUtil#sub}),
     * or {@code query} itself when the quantizer has no center.
     */
    private static float[] centerQuery(ProductQuantization pq, float[] query) {
        float[] center = pq.getCenter();
        return center == null ? query : VectorUtil.sub(query, center);
    }

    /**
     * Approximate cosine similarity. For every subspace, the constructor caches the dot product of
     * each centroid with the matching slice of the centered query, plus each centroid's squared
     * magnitude. Scoring sums the entries selected by a vector's codes.
     */
    static class CosineDecoder
    extends CompressedDecoder {
        /** Dot product of centroid j of subspace m with the query slice, stored at m * 256 + j. */
        protected final float[] partialSums;
        /** Squared magnitude of centroid j of subspace m, stored at m * 256 + j. */
        protected final float[] aMagnitude;
        /** Squared magnitude of the centered query, summed over all subspace slices. */
        protected final float bMagnitude;

        public CosineDecoder(CompressedVectors cv, float[] query) {
            super(cv);
            ProductQuantization pq = this.cv.pq;
            // Both tables are buffers obtained from cv and overwritten here; presumably they are
            // reused across decoders -- confirm before retaining a decoder beyond one query.
            this.partialSums = cv.reusablePartialSums();
            this.aMagnitude = cv.reusablePartialMagnitudes();
            float[] centeredQuery = centerQuery(pq, query);
            float bMagSum = 0.0f;
            for (int m = 0; m < pq.getSubspaceCount(); ++m) {
                // subvectorSizesAndOffsets[m] = {length, offset} of subspace m within the full vector.
                int offset = pq.subvectorSizesAndOffsets[m][1];
                int base = m * CLUSTERS;
                for (int j = 0; j < CLUSTERS; ++j) {
                    float[] centroidSubvector = pq.codebooks[m][j];
                    this.partialSums[base + j] = VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
                    this.aMagnitude[base + j] = VectorUtil.dotProduct(centroidSubvector, 0, centroidSubvector, 0, centroidSubvector.length);
                }
                bMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, pq.subvectorSizesAndOffsets[m][0]);
            }
            this.bMagnitude = bMagSum;
        }

        /** Maps the approximate cosine from [-1, 1] onto [0, 1]. */
        @Override
        public float similarityTo(int node2) {
            return (1.0f + this.decodedCosine(node2)) / 2.0f;
        }

        /**
         * Returns the approximate cosine between the query and the encoded vector of {@code node2}.
         *
         * <p>The loop runs over the vector's codes, one per subspace. It must not run over
         * {@code partialSums}, which holds 256 entries per subspace and would index past the end of
         * {@code encoded}.
         */
        protected float decodedCosine(int node2) {
            float sum = 0.0f;
            float aMag = 0.0f;
            byte[] encoded = this.cv.get(node2);
            for (int m = 0; m < encoded.length; ++m) {
                int index = m * CLUSTERS + Byte.toUnsignedInt(encoded[m]);
                sum += this.partialSums[index];
                aMag += this.aMagnitude[index];
            }
            return (float) (sum / Math.sqrt(aMag * this.bMagnitude));
        }
    }

    /** Approximate Euclidean similarity: {@code 1 / (1 + squaredDistance)}. */
    static class EuclideanDecoder
    extends CachingDecoder {
        public EuclideanDecoder(CompressedVectors cv, float[] query) {
            super(cv, query, VectorSimilarityFunction.EUCLIDEAN);
        }

        @Override
        public float similarityTo(int node2) {
            return 1.0f / (1.0f + this.decodedSimilarity(this.cv.get(node2)));
        }
    }

    /** Approximate dot-product similarity, mapped as {@code (1 + dot) / 2}. */
    static class DotProductDecoder
    extends CachingDecoder {
        public DotProductDecoder(CompressedVectors cv, float[] query) {
            super(cv, query, VectorSimilarityFunction.DOT_PRODUCT);
        }

        @Override
        public float similarityTo(int node2) {
            return (1.0f + this.decodedSimilarity(this.cv.get(node2))) / 2.0f;
        }
    }

    /**
     * Shared implementation for similarity functions whose per-query table holds a single value per
     * (subspace, centroid) pair, later combined by {@link VectorUtil#assembleAndSum}.
     */
    protected static abstract class CachingDecoder
    extends CompressedDecoder {
        /** Partial score of centroid j of subspace i against the query, stored at i * 256 + j. */
        protected final float[] partialSums;

        /**
         * Builds the per-query lookup table.
         *
         * @param cv compressed vectors supplying the quantizer and the reusable table buffer
         * @param query the (uncentered) query vector
         * @param vsf {@code DOT_PRODUCT} or {@code EUCLIDEAN}
         * @throws UnsupportedOperationException for any other similarity function
         */
        protected CachingDecoder(CompressedVectors cv, float[] query, VectorSimilarityFunction vsf) {
            super(cv);
            // Resolve the similarity function once rather than re-dispatching for every centroid.
            boolean useDotProduct;
            switch (vsf) {
                case DOT_PRODUCT:
                    useDotProduct = true;
                    break;
                case EUCLIDEAN:
                    useDotProduct = false;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
            }
            ProductQuantization pq = this.cv.pq;
            // Buffer obtained from cv and overwritten here; presumably reused across decoders -- confirm.
            this.partialSums = cv.reusablePartialSums();
            float[] centeredQuery = centerQuery(pq, query);
            for (int i = 0; i < pq.getSubspaceCount(); ++i) {
                int offset = pq.subvectorSizesAndOffsets[i][1];
                int baseOffset = i * CLUSTERS;
                for (int j = 0; j < CLUSTERS; ++j) {
                    float[] centroidSubvector = pq.codebooks[i][j];
                    this.partialSums[baseOffset + j] = useDotProduct
                            ? VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length)
                            : VectorUtil.squareDistance(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
                }
            }
        }

        /**
         * Combines the table entries selected by the codes in {@code encoded}. The exact combination is
         * defined by {@link VectorUtil#assembleAndSum}; presumably it is the sum of
         * partialSums[m * 256 + code[m]] -- confirm there.
         */
        protected float decodedSimilarity(byte[] encoded) {
            return VectorUtil.assembleAndSum(this.partialSums, CLUSTERS, encoded);
        }
    }
}

