/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.disk.feature.FusedFeature;
import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.agrona.collections.Int2ObjectHashMap;

public abstract class FusedPQDecoder
implements ScoreFunction.ApproximateScoreFunction {
    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
    protected final ProductQuantization pq;
    Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures;
    protected final VectorFloat<?> query;
    protected final ScoreFunction.ExactScoreFunction esf;
    protected final FusedPQ.PackedNeighbors packedNeighbors;
    protected final ByteSequence<?> neighborCodes;
    protected final VectorFloat<?> partialSums;
    protected final VectorSimilarityFunction vsf;
    protected int origin;

    protected FusedPQDecoder(ProductQuantization pq, Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query, FusedPQ.PackedNeighbors packedNeighbors, ByteSequence<?> neighborCodes, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf, VectorSimilarityFunction vsf) {
        this.pq = pq;
        this.hierarchyCachedFeatures = hierarchyCachedFeatures;
        this.query = query;
        this.esf = esf;
        this.packedNeighbors = packedNeighbors;
        this.neighborCodes = neighborCodes;
        this.vsf = vsf;
        this.origin = -1;
        this.partialSums = pq.reusablePartialSums();
        if (vsf != VectorSimilarityFunction.COSINE) {
            VectorFloat<?> center = pq.globalCentroid;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
            for (int i = 0; i < pq.getSubspaceCount(); ++i) {
                int offset = pq.subvectorSizesAndOffsets[i][1];
                int size = pq.subvectorSizesAndOffsets[i][0];
                VectorFloat<?> codebook = pq.codebooks[i];
                VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, vsf, this.partialSums);
            }
        }
    }

    @Override
    public boolean supportsSimilarityToNeighbors() {
        return true;
    }

    @Override
    public void enableSimilarityToNeighbors(int origin) {
        if (this.origin != origin) {
            this.origin = origin;
            this.packedNeighbors.readInto(origin, this.neighborCodes);
        }
    }

    @Override
    public float similarityTo(int node2) {
        if (!this.hierarchyCachedFeatures.containsKey(node2)) {
            throw new IllegalArgumentException("Node " + node2 + " is not in the hierarchy");
        }
        FusedPQ.FusedPQInlineSource code2 = (FusedPQ.FusedPQInlineSource)this.hierarchyCachedFeatures.get(node2);
        float sim = VectorUtil.assembleAndSum(this.partialSums, this.pq.getClusterCount(), code2.getCode());
        return this.distanceToScore(sim);
    }

    @Override
    public float similarityToNeighbor(int origin, int neighborIndex) {
        if (this.origin != origin) {
            throw new IllegalArgumentException("origin must be the same as the origin used to enable similarityToNeighbor");
        }
        int position = neighborIndex * this.pq.getSubspaceCount();
        float sim = VectorUtil.assembleAndSum(this.partialSums, this.pq.getClusterCount(), this.neighborCodes, position, this.pq.getSubspaceCount());
        return this.distanceToScore(sim);
    }

    protected abstract float distanceToScore(float var1);

    public static FusedPQDecoder newDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query, ByteSequence<?> reusableNeighborCodes, VectorFloat<?> results, VectorSimilarityFunction similarityFunction, ScoreFunction.ExactScoreFunction esf) {
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
            }
            case EUCLIDEAN: {
                return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
            }
            case COSINE: {
                return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function: " + String.valueOf((Object)similarityFunction));
    }

    static class DotProductDecoder
    extends FusedPQDecoder {
        public DotProductDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.DOT_PRODUCT);
        }

        @Override
        protected float distanceToScore(float distance) {
            return (distance + 1.0f) / 2.0f;
        }
    }

    static class EuclideanDecoder
    extends FusedPQDecoder {
        public EuclideanDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.EUCLIDEAN);
        }

        @Override
        protected float distanceToScore(float distance) {
            return 1.0f / (1.0f + distance);
        }
    }

    static class CosineDecoder
    extends FusedPQDecoder {
        private final float queryMagnitudeSquared;
        private final VectorFloat<?> partialSquaredMagnitudes;

        protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.COSINE);
            this.partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> {
                if (current != null) {
                    return current;
                }
                VectorFloat<?> partialSquaredMagnitudes = vts.createFloatVector(pq.getSubspaceCount() * pq.getClusterCount());
                for (int m = 0; m < pq.getSubspaceCount(); ++m) {
                    int size = pq.subvectorSizesAndOffsets[m][0];
                    VectorFloat<?> codebook = pq.codebooks[m];
                    float minPartialMagnitude = Float.POSITIVE_INFINITY;
                    float maxPartialMagnitude = 0.0f;
                    for (int j = 0; j < pq.getClusterCount(); ++j) {
                        float partialMagnitude = VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size);
                        minPartialMagnitude = Math.min(minPartialMagnitude, partialMagnitude);
                        maxPartialMagnitude = Math.max(maxPartialMagnitude, partialMagnitude);
                        partialSquaredMagnitudes.set(m * pq.getClusterCount() + j, partialMagnitude);
                    }
                }
                return partialSquaredMagnitudes;
            });
            VectorFloat<?> center = pq.globalCentroid;
            float queryMagSum = 0.0f;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
            for (int i = 0; i < pq.getSubspaceCount(); ++i) {
                int offset = pq.subvectorSizesAndOffsets[i][1];
                int size = pq.subvectorSizesAndOffsets[i][0];
                VectorFloat<?> codebook = pq.codebooks[i];
                VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, VectorSimilarityFunction.DOT_PRODUCT, this.partialSums);
                queryMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, size);
            }
            this.queryMagnitudeSquared = queryMagSum;
        }

        @Override
        public float similarityTo(int node2) {
            if (!this.hierarchyCachedFeatures.containsKey(node2)) {
                throw new IllegalArgumentException("Node " + node2 + " is not in the hierarchy");
            }
            FusedPQ.FusedPQInlineSource code2 = (FusedPQ.FusedPQInlineSource)this.hierarchyCachedFeatures.get(node2);
            float cos = VectorUtil.pqDecodedCosineSimilarity(code2.getCode(), 0, this.pq.getSubspaceCount(), this.pq.getClusterCount(), this.partialSums, this.partialSquaredMagnitudes, this.queryMagnitudeSquared);
            return this.distanceToScore(cos);
        }

        @Override
        public float similarityToNeighbor(int origin, int neighborIndex) {
            if (this.origin != origin) {
                throw new IllegalArgumentException("origin must be the same as the origin used to enable similarityToNeighbor");
            }
            int position = neighborIndex * this.pq.getSubspaceCount();
            float cos = VectorUtil.pqDecodedCosineSimilarity(this.neighborCodes, position, this.pq.getSubspaceCount(), this.pq.getClusterCount(), this.partialSums, this.partialSquaredMagnitudes, this.queryMagnitudeSquared);
            return this.distanceToScore(cos);
        }

        @Override
        protected float distanceToScore(float distance) {
            return (1.0f + distance) / 2.0f;
        }
    }
}

