/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.PQDecoder;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Product-quantized vectors: a {@link ProductQuantization} together with one PQ-encoded
 * {@link ByteSequence} per vector ordinal.
 *
 * <p>Each encoded vector is read one byte per subspace; each byte is interpreted as an unsigned
 * centroid index into that subspace's codebook.
 */
public class PQVectors implements CompressedVectors {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
    final ProductQuantization pq;
    private final List<ByteSequence<?>> compressedVectors;

    /**
     * Wraps already-encoded vectors.
     *
     * @param pq                the quantizer whose codebooks decode {@code compressedVectors}
     * @param compressedVectors encoded vectors indexed by ordinal; stored by reference, not copied
     */
    public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors) {
        this.pq = pq;
        this.compressedVectors = compressedVectors;
    }

    /**
     * Wraps an array of already-encoded vectors.
     *
     * <p>The array is copied into an unmodifiable list via {@link List#of(Object[])}, which throws
     * {@link NullPointerException} if any element is {@code null}.
     *
     * @param pq                the quantizer whose codebooks decode {@code compressedVectors}
     * @param compressedVectors encoded vectors indexed by ordinal
     */
    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors) {
        this(pq, List.of(compressedVectors));
    }

    /** @return the number of encoded vectors held */
    @Override
    public int count() {
        return compressedVectors.size();
    }

    /**
     * Serializes the quantizer followed by the encoded vectors.
     *
     * <p>Layout: the {@link ProductQuantization} (via {@code pq.write}), the vector count as an
     * {@code int}, the per-vector byte length as an {@code int} (the subspace count), then each
     * encoded vector. {@link #load(RandomAccessReader)} reads this same layout back.
     *
     * @param out     destination
     * @param version serialization version, forwarded to {@code pq.write}
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        pq.write(out, version);
        out.writeInt(compressedVectors.size());
        out.writeInt(pq.getSubspaceCount());
        for (ByteSequence<?> vector : compressedVectors) {
            vectorTypeSupport.writeByteSequence(out, vector);
        }
    }

    /**
     * Reads a {@code PQVectors} in the layout produced by {@link #write(DataOutput, int)}, starting
     * at the reader's current position.
     *
     * @param in source reader
     * @return the loaded vectors
     * @throws IOException if reading fails, or the stored count or per-vector length is negative
     */
    public static PQVectors load(RandomAccessReader in) throws IOException {
        ProductQuantization pq = ProductQuantization.load(in);

        int size = in.readInt();
        if (size < 0) {
            throw new IOException("Invalid compressed vector count " + size);
        }
        List<ByteSequence<?>> vectors = new ArrayList<>(size);

        int compressedDimension = in.readInt();
        if (compressedDimension < 0) {
            throw new IOException("Invalid compressed vector dimension " + compressedDimension);
        }

        for (int i = 0; i < size; i++) {
            ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
            vectors.add(vector);
        }
        return new PQVectors(pq, vectors);
    }

    /**
     * Seeks to {@code offset} and then behaves like {@link #load(RandomAccessReader)}.
     *
     * @param in     source reader
     * @param offset absolute position to seek to before reading
     * @return the loaded vectors
     * @throws IOException if seeking or reading fails, or the stored data is invalid
     */
    public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
        in.seek(offset);
        return load(in);
    }

    /** Two instances are equal when their quantizers and encoded-vector lists are equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        PQVectors that = (PQVectors) o;
        return Objects.equals(pq, that.pq) && Objects.equals(compressedVectors, that.compressedVectors);
    }

    @Override
    public int hashCode() {
        return Objects.hash(pq, compressedVectors);
    }

    /**
     * Returns a score function backed by a {@link PQDecoder} implementation for the given
     * similarity.
     *
     * @param q                  query vector
     * @param similarityFunction DOT_PRODUCT, EUCLIDEAN or COSINE
     * @return the decoder-based approximate score function
     * @throws IllegalArgumentException for any other similarity function
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new PQDecoder.DotProductDecoder(this, q);
            case EUCLIDEAN:
                return new PQDecoder.EuclideanDecoder(this, q);
            case COSINE:
                return new PQDecoder.CosineDecoder(this, q);
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * Returns a score function that computes similarity directly from the codebooks for each node.
     * It builds no per-query lookup tables.
     *
     * <p>Scores are mapped as follows: DOT_PRODUCT to {@code (1 + dot) / 2}; COSINE to
     * {@code (1 + cosine) / 2}; EUCLIDEAN to {@code 1 / (1 + squaredDistance)}.
     *
     * @param q                  query vector
     * @param similarityFunction DOT_PRODUCT, COSINE or EUCLIDEAN
     * @return the approximate score function
     * @throws IllegalArgumentException for any other similarity function
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return node2 -> {
                    ByteSequence<?> encoded = get(node2);
                    int subspaceCount = pq.getSubspaceCount();
                    float dp = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, q, centroidOffset, centroidLength);
                    }
                    // Shift/scale the dot product; lands in [0, 1] when it lies in [-1, 1]
                    // (presumably unit-length vectors -- confirm with callers).
                    return (1.0f + dp) / 2.0f;
                };
            }
            case COSINE: {
                float norm1 = VectorUtil.dotProduct(q, q);
                return node2 -> {
                    ByteSequence<?> encoded = get(node2);
                    int subspaceCount = pq.getSubspaceCount();
                    float sum = 0.0f;
                    float norm2 = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        int codebookOffset = centroidIndex * centroidLength;
                        sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, q, centroidOffset, centroidLength);
                        norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
                    }
                    // NOTE(review): a zero-norm query or zero-norm reconstruction makes this 0/0 = NaN.
                    float cosine = sum / (float) Math.sqrt(norm1 * norm2);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            case EUCLIDEAN: {
                return node2 -> {
                    ByteSequence<?> encoded = get(node2);
                    int subspaceCount = pq.getSubspaceCount();
                    float sum = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, q, centroidOffset, centroidLength);
                    }
                    // Squared distance in [0, inf) maps to a score in (0, 1].
                    return 1.0f / (1.0f + sum);
                };
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * @param ordinal vector ordinal
     * @return the encoded vector stored at {@code ordinal}
     */
    public ByteSequence<?> get(int ordinal) {
        return compressedVectors.get(ordinal);
    }

    /** @return the quantizer used to encode these vectors */
    public ProductQuantization getProductQuantization() {
        return pq;
    }

    /** Delegates to {@code pq.reusablePartialSums()} (a scratch buffer owned by the quantizer, presumably). */
    VectorFloat<?> reusablePartialSums() {
        return pq.reusablePartialSums();
    }

    /** Delegates to {@code pq.partialSquaredMagnitudes()}. */
    AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
        return pq.partialSquaredMagnitudes();
    }

    /** @return size in bytes of one uncompressed vector, counting {@link Float#BYTES} per dimension */
    @Override
    public int getOriginalSize() {
        return pq.originalDimension * Float.BYTES;
    }

    /** @return size in bytes of one encoded vector, as reported by the quantizer */
    @Override
    public int getCompressedSize() {
        return pq.compressedVectorSize();
    }

    /** @return the quantizer used to encode these vectors (same as {@link #getProductQuantization()}) */
    public ProductQuantization getCompressor() {
        return pq;
    }

    /**
     * Estimates heap usage: the quantizer's own estimate, plus the reference list, plus each
     * encoded vector. Each encoded vector is assumed to be an object wrapping an array of
     * {@code compressedVectorSize()} bytes.
     */
    @Override
    public long ramBytesUsed() {
        int refBytes = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        int objectHeaderBytes = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
        int arrayHeaderBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
        long count = compressedVectors.size();

        long codebooksSize = pq.ramBytesUsed();
        // One reference per element plus one for the list; computed in long to avoid int overflow.
        long listSize = (long) refBytes * (1L + count);
        long dataSize = (long) (objectHeaderBytes + arrayHeaderBytes + pq.compressedVectorSize()) * count;
        return codebooksSize + listSize + dataSize;
    }

    @Override
    public String toString() {
        return "PQVectors{pq=" + pq + ", count=" + compressedVectors.size() + "}";
    }
}

