/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.CompressedVectors;
import io.github.jbellis.jvector.quantization.ImmutablePQVectors;
import io.github.jbellis.jvector.quantization.PQDecoder;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/**
 * Base class for product-quantized (PQ) vectors: a {@link ProductQuantization} plus the PQ codes of every
 * vector, stored in one or more {@link ByteSequence} chunks.
 * <p>
 * Each vector occupies {@code pq.getSubspaceCount()} consecutive bytes, one unsigned centroid index per
 * subspace. Vector {@code ordinal} lives in chunk {@code ordinal / vectorsPerChunk} at byte offset
 * {@code (ordinal % vectorsPerChunk) * pq.getSubspaceCount()}. {@link PQLayout} sizes the chunks so that each
 * one stays addressable with an {@code int} index.
 * <p>
 * Subclasses own the chunk array and decide how many leading chunks hold data ({@link #validChunkCount()}).
 */
public abstract class PQVectors implements CompressedVectors {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Quantizer whose codebooks decode the stored codes. */
    final ProductQuantization pq;

    /**
     * Chunks holding the PQ codes. Only the first {@link #validChunkCount()} entries are serialized by
     * {@link #write(DataOutput, int)} and counted by {@link #ramBytesUsed()}.
     */
    protected ByteSequence<?>[] compressedDataChunks;

    /** Vectors per chunk; maps an ordinal to its chunk ({@code ordinal / vectorsPerChunk}) and slot within it. */
    protected int vectorsPerChunk;

    protected PQVectors(ProductQuantization pq) {
        this.pq = pq;
    }

    /**
     * Reads PQ vectors in the format produced by {@link #write(DataOutput, int)}. The format is the serialized
     * {@link ProductQuantization}, the vector count, the per-vector code length, and then the codes of all
     * vectors back to back.
     *
     * @param in reader positioned at the start of the serialized PQ vectors
     * @return the loaded vectors
     * @throws IOException if reading fails or the stored code length does not match the quantizer's subspace count
     * @throws IllegalArgumentException if the stored vector count or code length is not positive
     */
    public static ImmutablePQVectors load(RandomAccessReader in) throws IOException {
        ProductQuantization pq = ProductQuantization.load(in);
        int vectorCount = in.readInt();
        int compressedDimension = in.readInt();
        PQLayout layout = new PQLayout(vectorCount, compressedDimension);
        // Every accessor strides through the chunks by pq.getSubspaceCount() bytes per vector, so a header that
        // disagrees would silently decode garbage instead of failing.
        if (compressedDimension != pq.getSubspaceCount()) {
            throw new IOException("Stored compressed dimension " + compressedDimension
                    + " does not match the quantizer's subspace count " + pq.getSubspaceCount());
        }
        ByteSequence<?>[] chunks = new ByteSequence<?>[layout.totalChunks];
        for (int i = 0; i < layout.fullSizeChunks; i++) {
            chunks[i] = vectorTypeSupport.readByteSequence(in, layout.fullChunkBytes);
        }
        if (layout.lastChunkVectors > 0) {
            chunks[layout.fullSizeChunks] = vectorTypeSupport.readByteSequence(in, layout.lastChunkBytes);
        }
        return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
    }

    /**
     * Seeks to {@code offset} and then behaves like {@link #load(RandomAccessReader)}.
     *
     * @param in     reader to read from
     * @param offset absolute position of the serialized PQ vectors
     * @return the loaded vectors
     * @throws IOException if seeking or reading fails
     */
    public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
        in.seek(offset);
        return load(in);
    }

    /**
     * Encodes every vector of {@code ravv} with {@code pq}, in parallel on {@code simdExecutor}. Ordinals whose
     * vector is {@code null} get an all-zero code.
     *
     * @param pq           quantizer used to encode
     * @param vectorCount  number of vector slots to allocate; must be at least {@code ravv.size()}
     * @param ravv         source vectors, read through a per-thread copy
     * @param simdExecutor pool the parallel encoding runs in
     * @return the encoded vectors
     * @throws IllegalArgumentException if {@code vectorCount} is not positive or is smaller than {@code ravv.size()}
     */
    public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        int compressedDimension = pq.compressedVectorSize();
        PQLayout layout = new PQLayout(vectorCount, compressedDimension);
        int sourceCount = ravv.size();
        // Ordinals past vectorCount have no slot in the chunks; fail here rather than deep inside a pool task.
        if (sourceCount > vectorCount) {
            throw new IllegalArgumentException("Cannot encode " + sourceCount
                    + " vectors into storage sized for " + vectorCount);
        }

        ByteSequence<?>[] chunks = new ByteSequence<?>[layout.totalChunks];
        for (int i = 0; i < layout.fullSizeChunks; i++) {
            chunks[i] = vectorTypeSupport.createByteSequence(layout.fullChunkBytes);
        }
        if (layout.lastChunkVectors > 0) {
            chunks[layout.fullSizeChunks] = vectorTypeSupport.createByteSequence(layout.lastChunkBytes);
        }

        Supplier<RandomAccessVectorValues> ravvCopy = ravv.threadLocalSupplier();
        int subspaceCount = pq.getSubspaceCount();
        // Each ordinal writes only its own slice, so the concurrent writes never overlap.
        // join() blocks until every slice has been written.
        simdExecutor.submit(() -> IntStream.range(0, sourceCount).parallel().forEach(ordinal -> {
            RandomAccessVectorValues localRavv = ravvCopy.get();
            ByteSequence<?> slice = get(chunks, ordinal, layout.fullChunkVectors, subspaceCount);
            VectorFloat<?> vector = localRavv.getVector(ordinal);
            if (vector != null) {
                pq.encodeTo(vector, slice);
            } else {
                slice.zero();
            }
        })).join();

        return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
    }

    /**
     * Serializes the quantizer, the vector count, the per-vector code length, and the codes of all
     * {@link #count()} vectors in ordinal order, which is the format read by {@link #load(RandomAccessReader)}.
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        pq.write(out, version);

        int vectorCount = count();
        int subspaceCount = pq.getSubspaceCount();
        out.writeInt(vectorCount);
        out.writeInt(subspaceCount);

        // load() reads back exactly vectorCount * subspaceCount code bytes. Emit only the bytes that belong to
        // valid vectors: a chunk that is only partially filled must not contribute its unused tail, or the reader
        // would misinterpret whatever follows in the stream.
        int chunkCount = validChunkCount();
        int remaining = vectorCount;
        for (int i = 0; i < chunkCount && remaining > 0; i++) {
            ByteSequence<?> chunk = compressedDataChunks[i];
            int vectorsInChunk = Math.min(vectorsPerChunk, remaining);
            int validBytes = vectorsInChunk * subspaceCount;
            if (validBytes == chunk.length()) {
                vectorTypeSupport.writeByteSequence(out, chunk);
            } else {
                vectorTypeSupport.writeByteSequence(out, chunk.slice(0, validBytes));
            }
            remaining -= vectorsInChunk;
        }
    }

    /**
     * @return the number of leading entries of {@link #compressedDataChunks} that hold vector codes
     */
    protected abstract int validChunkCount();

    /** Two instances are equal when they have the same class, an equal quantizer, and identical codes. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        PQVectors that = (PQVectors) o;
        if (!Objects.equals(pq, that.pq) || count() != that.count()) {
            return false;
        }
        for (int i = 0; i < count(); i++) {
            if (!get(i).equals(that.get(i))) {
                return false;
            }
        }
        return true;
    }

    /** Combines the quantizer, the count, and every vector's code; consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = 31 * result + Objects.hashCode(pq);
        result = 31 * result + count();
        for (int i = 0; i < count(); i++) {
            result = 31 * result + get(i).hashCode();
        }
        return result;
    }

    /**
     * Returns a score function for {@code q} backed by one of the {@link PQDecoder} implementations.
     *
     * @throws IllegalArgumentException for an unsupported similarity function
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return new PQDecoder.DotProductDecoder(this, q);
            }
            case EUCLIDEAN: {
                return new PQDecoder.EuclideanDecoder(this, q);
            }
            case COSINE: {
                return new PQDecoder.CosineDecoder(this, q);
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * Returns a score function that compares {@code q} against each node's codebook centroids directly. When the
     * quantizer has a global centroid, {@code q} is first translated by it.
     * <p>
     * Scores are mapped as follows: dot product becomes {@code (1 + dp) / 2}, cosine becomes
     * {@code (1 + cos) / 2}, and Euclidean becomes {@code 1 / (1 + squaredDistance)}.
     *
     * @throws IllegalArgumentException for an unsupported similarity function
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        VectorFloat<?> centeredQuery = pq.globalCentroid == null ? q : VectorUtil.sub(q, pq.globalCentroid);
        int subspaceCount = pq.getSubspaceCount();
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return node2 -> {
                    ByteSequence<?> encodedChunk = getChunk(node2);
                    int encodedOffset = getOffsetInChunk(node2);
                    float dp = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
                    }
                    return (1.0f + dp) / 2.0f;
                };
            }
            case COSINE: {
                // The query's squared norm does not depend on the node, so compute it once.
                float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
                return node2 -> {
                    ByteSequence<?> encodedChunk = getChunk(node2);
                    int encodedOffset = getOffsetInChunk(node2);
                    float sum = 0.0f;
                    float norm2 = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        int codebookOffset = centroidIndex * centroidLength;
                        sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
                        norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
                    }
                    float cosine = sum / (float) Math.sqrt(norm1 * norm2);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            case EUCLIDEAN: {
                return node2 -> {
                    ByteSequence<?> encodedChunk = getChunk(node2);
                    int encodedOffset = getOffsetInChunk(node2);
                    float sum = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
                    }
                    return 1.0f / (1.0f + sum);
                };
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * Returns a score function comparing the decoded centroids of {@code node1} with those of another node. It
     * uses the same score mappings as {@link #scoreFunctionFor}.
     *
     * @throws IndexOutOfBoundsException if {@code node1} is not a valid ordinal
     * @throws IllegalArgumentException  for an unsupported similarity function
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, VectorSimilarityFunction similarityFunction) {
        int subspaceCount = pq.getSubspaceCount();
        ByteSequence<?> node1Chunk = getChunk(node1);
        int node1Offset = getOffsetInChunk(node1);
        switch (similarityFunction) {
            case DOT_PRODUCT: {
                return node2 -> {
                    ByteSequence<?> node2Chunk = getChunk(node2);
                    int node2Offset = getOffsetInChunk(node2);
                    float dp = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
                        int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
                    }
                    return (1.0f + dp) / 2.0f;
                };
            }
            case COSINE: {
                // node1's squared norm is fixed, so compute it once outside the returned function.
                float norm1 = 0.0f;
                for (int m = 0; m < subspaceCount; m++) {
                    int centroidIndex = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
                    int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                    int codebookOffset = centroidIndex * centroidLength;
                    norm1 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
                }
                float norm1Final = norm1;
                return node2 -> {
                    ByteSequence<?> node2Chunk = getChunk(node2);
                    int node2Offset = getOffsetInChunk(node2);
                    float sum = 0.0f;
                    float norm2 = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
                        int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int codebookOffset = centroidIndex2 * centroidLength;
                        sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], centroidIndex1 * centroidLength, centroidLength);
                        norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
                    }
                    float cosine = sum / (float) Math.sqrt(norm1Final * norm2);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            case EUCLIDEAN: {
                return node2 -> {
                    ByteSequence<?> node2Chunk = getChunk(node2);
                    int node2Offset = getOffsetInChunk(node2);
                    float sum = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
                        int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
                    }
                    return 1.0f / (1.0f + sum);
                };
            }
        }
        throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
    }

    /**
     * Returns a view of the {@code pq.getSubspaceCount()} code bytes of vector {@code ordinal}.
     *
     * @throws IndexOutOfBoundsException if {@code ordinal} is outside {@code [0, count())}
     */
    public ByteSequence<?> get(int ordinal) {
        checkOrdinal(ordinal);
        return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
    }

    /** Returns a slice of {@code chunks} holding the {@code subspaceCount} code bytes of {@code ordinal}; no bounds check. */
    static ByteSequence<?> get(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) {
        int start = (ordinal % vectorsPerChunk) * subspaceCount;
        return getChunk(chunks, ordinal, vectorsPerChunk).slice(start, subspaceCount);
    }

    /**
     * Returns the whole chunk containing {@code ordinal}; combine it with {@link #getOffsetInChunk(int)}.
     *
     * @throws IndexOutOfBoundsException if {@code ordinal} is outside {@code [0, count())}
     */
    ByteSequence<?> getChunk(int ordinal) {
        checkOrdinal(ordinal);
        return getChunk(compressedDataChunks, ordinal, vectorsPerChunk);
    }

    /**
     * Returns the byte offset of {@code ordinal}'s code within its chunk.
     *
     * @throws IndexOutOfBoundsException if {@code ordinal} is outside {@code [0, count())}
     */
    int getOffsetInChunk(int ordinal) {
        checkOrdinal(ordinal);
        return (ordinal % vectorsPerChunk) * pq.getSubspaceCount();
    }

    /** Returns the chunk of {@code chunks} that holds {@code ordinal}; no bounds check. */
    static ByteSequence<?> getChunk(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk) {
        return chunks[ordinal / vectorsPerChunk];
    }

    /** Throws {@link IndexOutOfBoundsException} unless {@code 0 <= ordinal < count()}. */
    private void checkOrdinal(int ordinal) {
        if (ordinal < 0 || ordinal >= count()) {
            throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
        }
    }

    /** Delegates to {@link ProductQuantization#reusablePartialSums()}. */
    VectorFloat<?> reusablePartialSums() {
        return pq.reusablePartialSums();
    }

    /** Delegates to {@link ProductQuantization#partialSquaredMagnitudes()}. */
    AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
        return pq.partialSquaredMagnitudes();
    }

    /** Size in bytes of an uncompressed vector ({@code float} per original dimension). */
    @Override
    public int getOriginalSize() {
        return pq.originalDimension * Float.BYTES;
    }

    /** Size in bytes of one compressed vector. */
    @Override
    public int getCompressedSize() {
        return pq.compressedVectorSize();
    }

    /** Returns the quantizer used to encode these vectors. */
    public ProductQuantization getCompressor() {
        return pq;
    }

    /**
     * Estimates heap usage. The estimate is the quantizer, plus the chunk reference array sized by
     * {@link #validChunkCount()}, plus the data of each valid chunk.
     */
    @Override
    public long ramBytesUsed() {
        int refBytes = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        int objectHeaderBytes = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
        int arrayHeaderBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
        int chunkCount = validChunkCount();
        long codebooksSize = pq.ramBytesUsed();
        long chunksArraySize = (long) (objectHeaderBytes + arrayHeaderBytes) + (long) chunkCount * (long) refBytes;
        long dataSize = 0L;
        for (int i = 0; i < chunkCount; i++) {
            dataSize += compressedDataChunks[i].ramBytesUsed();
        }
        return codebooksSize + chunksArraySize + dataSize;
    }

    @Override
    public String toString() {
        return "PQVectors{pq=" + pq + ", count=" + count() + "}";
    }

    /**
     * Chunking plan for {@code vectorCount} codes of {@code compressedDimension} bytes each. There are
     * {@code fullSizeChunks} chunks of {@code fullChunkVectors} vectors each, optionally followed by one chunk of
     * {@code lastChunkVectors} vectors.
     */
    static class PQLayout {
        public final int vectorCount;
        /** {@code fullSizeChunks}, plus one if there is a partial last chunk. */
        public final int totalChunks;
        public final int fullSizeChunks;
        public final int fullChunkVectors;
        /** Vectors in the trailing partial chunk; 0 when there is none. */
        public final int lastChunkVectors;
        public final int compressedDimension;
        public final int fullChunkBytes;
        public final int lastChunkBytes;

        /**
         * @throws IllegalArgumentException if either argument is not positive
         */
        public PQLayout(int vectorCount, int compressedDimension) {
            if (vectorCount <= 0) {
                throw new IllegalArgumentException("Invalid vector count " + vectorCount);
            }
            this.vectorCount = vectorCount;
            if (compressedDimension <= 0) {
                throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
            }
            this.compressedDimension = compressedDimension;
            // Round the per-vector size up to a power of two. Since that is >= compressedDimension, a chunk of
            // MAX_VALUE / layoutBytesPerVector vectors never exceeds Integer.MAX_VALUE bytes.
            int layoutBytesPerVector = compressedDimension == 1 ? 1 : Integer.highestOneBit(compressedDimension - 1) << 1;
            int addressableVectorsPerChunk = Integer.MAX_VALUE / layoutBytesPerVector;
            this.fullChunkVectors = Math.min(vectorCount, addressableVectorsPerChunk);
            this.lastChunkVectors = vectorCount % this.fullChunkVectors;
            this.fullChunkBytes = this.fullChunkVectors * compressedDimension;
            this.lastChunkBytes = this.lastChunkVectors * compressedDimension;
            this.fullSizeChunks = vectorCount / this.fullChunkVectors;
            this.totalChunks = this.fullSizeChunks + (this.lastChunkVectors == 0 ? 0 : 1);
        }
    }
}

