/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link PQVectors} whose PQ codes are supplied in full at construction time as an array of
 * {@link ByteSequence} chunks, with a vector count fixed at construction.
 *
 * <p>The chunk array is stored by reference, not copied. Callers should not modify it after
 * passing it to the constructor. Codebook partial sums are computed lazily, once per
 * {@link VectorSimilarityFunction}, and cached in a {@link ConcurrentHashMap}.
 */
public class ImmutablePQVectors
extends PQVectors {
    /** Value reported by {@link #count()}; fixed at construction. */
    private final int vectorCount;
    /** Lazily populated cache of codebook partial sums, keyed by similarity function. */
    private final Map<VectorSimilarityFunction, VectorFloat<?>> codebookPartialSumsMap;

    /**
     * Wraps already-encoded PQ codes.
     *
     * @param pq the {@link ProductQuantization} passed to the superclass; also used to obtain the
     *     subspace/cluster counts and to build codebook partial sums
     * @param compressedDataChunks the encoded data chunks; stored without copying
     * @param vectorCount the value to report from {@link #count()}
     * @param vectorsPerChunk value assigned to the inherited {@code vectorsPerChunk} field
     *     (presumably the number of vectors per chunk; confirm against {@code PQVectors})
     */
    public ImmutablePQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk) {
        super(pq);
        this.compressedDataChunks = compressedDataChunks;
        this.vectorCount = vectorCount;
        this.vectorsPerChunk = vectorsPerChunk;
        // Use the diamond rather than the raw type to avoid an unchecked conversion.
        this.codebookPartialSumsMap = new ConcurrentHashMap<>();
    }

    /**
     * Returns the number of chunks holding data.
     *
     * <p>Every supplied chunk counts, so this is simply the length of the chunk array.
     */
    @Override
    protected int validChunkCount() {
        return this.compressedDataChunks.length;
    }

    /** Returns the vector count supplied to the constructor. */
    @Override
    public int count() {
        return this.vectorCount;
    }

    /**
     * Returns the codebook partial sums for {@code vsf}.
     *
     * <p>On first request they are computed via {@code pq.createCodebookPartialSums(vsf)}, and the
     * result is cached. {@link ConcurrentHashMap#computeIfAbsent} applies the mapping function at
     * most once per key, so concurrent callers share a single computation.
     */
    private VectorFloat<?> getOrCreateCodebookPartialSums(VectorSimilarityFunction vsf) {
        return this.codebookPartialSumsMap.computeIfAbsent(vsf, this.pq::createCodebookPartialSums);
    }

    /**
     * Builds a function that scores {@code node1} against any other node, using only their PQ
     * codes and the cached codebook partial sums for {@code similarityFunction}.
     *
     * <p>Let {@code sum} be the value that {@link VectorUtil#assembleAndSumPQ} returns for the pair
     * of codes. The returned function computes:
     * <ul>
     *   <li>{@code DOT_PRODUCT}: {@code (1 + sum) / 2}</li>
     *   <li>{@code COSINE}: {@code (1 + sum / sqrt(norm1 * norm2)) / 2}, where each norm is the
     *       same sum taken for a node against itself</li>
     *   <li>{@code EUCLIDEAN}: {@code 1 / (1 + sum)}. Here {@code sum} is presumably a squared
     *       distance, depending on {@code createCodebookPartialSums}; TODO confirm.</li>
     * </ul>
     *
     * @param node1 the fixed node being compared
     * @param similarityFunction selects the partial sums and the scoring formula
     * @return a score function over the second node's ordinal
     * @throws IllegalArgumentException if {@code similarityFunction} is not one of the three
     *     handled cases
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, VectorSimilarityFunction similarityFunction) {
        int subspaceCount = this.pq.getSubspaceCount();
        ByteSequence<?> node1Chunk = this.getChunk(node1);
        int node1Offset = this.getOffsetInChunk(node1);
        int clusterCount = this.pq.getClusterCount();
        VectorFloat<?> codebookPartialSums = this.getOrCreateCodebookPartialSums(similarityFunction);
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return node2 -> {
                    ByteSequence<?> node2Chunk = this.getChunk(node2);
                    int node2Offset = this.getOffsetInChunk(node2);
                    float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
                    return (1.0f + sum) / 2.0f;
                };
            }
            case COSINE: {
                // node1's self-term is hoisted out of the lambda; it is the same for every node2.
                float norm1 = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node1Chunk, node1Offset, clusterCount);
                return node2 -> {
                    ByteSequence<?> node2Chunk = this.getChunk(node2);
                    int node2Offset = this.getOffsetInChunk(node2);
                    float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
                    float norm2 = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node2Chunk, node2Offset, node2Chunk, node2Offset, clusterCount);
                    // NOTE(review): there is no guard here. If either self-term is zero, this divides
                    // by zero and the score becomes NaN or an infinity.
                    float cosine = sum / (float)Math.sqrt(norm1 * norm2);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            case EUCLIDEAN: {
                return node2 -> {
                    ByteSequence<?> node2Chunk = this.getChunk(node2);
                    int node2Offset = this.getOffsetInChunk(node2);
                    float sum = VectorUtil.assembleAndSumPQ(codebookPartialSums, subspaceCount, node1Chunk, node1Offset, node2Chunk, node2Offset, clusterCount);
                    return 1.0f / (1.0f + sum);
                };
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }
}

