/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.LVQPackedVectors;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer;
import io.github.jbellis.jvector.pq.VectorCompressor;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.Stream;

/**
 * Locally-adaptive vector quantization (LVQ).
 *
 * <p>Every vector is first centered by subtracting a single global mean. Each centered vector is
 * then linearly quantized to unsigned 8-bit codes using its own minimum as the bias. The scale is
 * {@code (max - min) / 255}.
 */
public class LocallyAdaptiveVectorQuantization implements VectorCompressor<QuantizedVector> {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Number of bytes in one packed block; packed vectors are padded to a multiple of this. */
    private static final int PACKED_BLOCK_SIZE = 64;
    /** Number of interleaved lanes in a packed block. */
    private static final int PACKED_LANES = 16;
    /** Number of consecutive bytes each lane owns inside a packed block (64 / 16 = 4). */
    private static final int PACKED_BYTES_PER_LANE = PACKED_BLOCK_SIZE / PACKED_LANES;

    /** Mean subtracted from every vector before quantization. */
    public final VectorFloat<?> globalMean;

    public LocallyAdaptiveVectorQuantization(VectorFloat<?> globalMean) {
        this.globalMean = globalMean;
    }

    /**
     * Builds a quantizer whose global mean is the centroid of all vectors in {@code ravv}.
     *
     * @param ravv source vectors; a thread-local view is obtained via {@code threadLocalSupplier()}
     * @return a quantizer centered on the centroid of the vectors
     */
    public static LocallyAdaptiveVectorQuantization compute(RandomAccessVectorValues ravv) {
        RandomAccessVectorValues ravvCopy = ravv.threadLocalSupplier().get();
        int count = ravvCopy.size();
        boolean shared = ravvCopy.isValueShared();
        List<VectorFloat<?>> vectors = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            VectorFloat<?> v = ravvCopy.vectorValue(i);
            // A shared value is a reused buffer. Without a copy, every list entry would alias the
            // same storage and the centroid would be computed from the last vector only.
            vectors.add(shared ? v.copy() : v);
        }
        return new LocallyAdaptiveVectorQuantization(KMeansPlusPlusClusterer.centroidOf(vectors));
    }

    /**
     * Encodes all vectors in parallel on {@code simdExecutor}.
     *
     * @return the encoded vectors, in the same order as {@code vectors}
     */
    public QuantizedVector[] encodeAll(List<VectorFloat<?>> vectors, ForkJoinPool simdExecutor) {
        return simdExecutor.submit(() -> vectors.stream()
                .parallel()
                .map(this::encode)
                .toArray(QuantizedVector[]::new))
                .join();
    }

    /**
     * Centers {@code v} on the global mean and quantizes each component to 8 bits.
     *
     * <p>The returned bias is the centered vector's minimum, and the scale is the size of one
     * quantization step. Component {@code i} therefore dequantizes to
     * {@code code[i] * scale + bias}.
     */
    @Override
    public QuantizedVector encode(VectorFloat<?> v) {
        VectorFloat<?> centered = VectorUtil.sub(v, globalMean);
        float lower = VectorUtil.min(centered);
        float upper = VectorUtil.max(centered);
        // 255 steps span [lower, upper]; computed once instead of once per component.
        float delta = (upper - lower) / 255.0f;
        int dimension = centered.length();
        ByteSequence<?> quantized = vectorTypeSupport.createByteSequence(dimension);
        for (int i = 0; i < dimension; i++) {
            quantized.set(i, quantizeFloatToByte(centered.get(i), lower, delta));
        }
        return new QuantizedVector(quantized, lower, delta);
    }

    /**
     * Maps {@code value} to the nearest step above {@code minFloat}, clamped to [0, 255].
     * The result is returned as a byte holding the unsigned code.
     */
    private static byte quantizeFloatToByte(float value, float minFloat, float delta) {
        if (delta == 0.0f) {
            // Constant vector: every component equals minFloat, so the code is 0.
            return 0;
        }
        int quantizedValue = Math.round((value - minFloat) / delta);
        return (byte) Math.max(0, Math.min(255, quantizedValue));
    }

    /** Writes the dimension of the global mean as an int, followed by the mean vector itself. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(globalMean.length());
        vectorTypeSupport.writeFloatVector(out, globalMean);
    }

    /**
     * Returns the size in bytes of one packed vector: the quantized codes, padded up to a whole
     * number of blocks, plus the float bias and float scale written ahead of them.
     */
    @Override
    public int compressedVectorSize() {
        int dimension = globalMean.length();
        int paddedDimension = (dimension + PACKED_BLOCK_SIZE - 1) / PACKED_BLOCK_SIZE * PACKED_BLOCK_SIZE;
        return paddedDimension + 2 * Float.BYTES;
    }

    /** Size in bytes of {@link #write}: one int dimension plus 4 bytes per mean component. */
    @Override
    public int compressorSize() {
        return Integer.BYTES + Float.BYTES * globalMean.length();
    }

    /**
     * Dot-product scoring of a full-precision query against packed vectors.
     *
     * <p>Each stored vector is a centered code, so {@code <q, x> = <q, centered> + <q, globalMean>}.
     * The second term is per-query and is computed once. The raw dot product is mapped to a
     * similarity as {@code (1 + dot) / 2}.
     */
    private ScoreFunction.ExactScoreFunction dotProductScoreFunctionFrom(final VectorFloat<?> query, final LVQPackedVectors packedVectors) {
        // NOTE(review): querySum is presumably used by lvqDotProduct to apply the per-vector bias
        // term (bias * sum(query)) — confirm against VectorUtil.
        final float querySum = VectorUtil.sum(query);
        final float queryGlobalBias = VectorUtil.dotProduct(query, globalMean);
        return new ScoreFunction.ExactScoreFunction() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                PackedVector vector = packedVectors.getPackedVector(node);
                float dot = VectorUtil.lvqDotProduct(query, vector, querySum) + queryGlobalBias;
                return (1.0f + dot) / 2.0f;
            }
        };
    }

    /**
     * Euclidean scoring. The query is shifted by the global mean once, so it can be compared
     * directly with the centered codes. The squared distance is mapped to {@code 1 / (1 + d)}.
     */
    private ScoreFunction.ExactScoreFunction euclideanScoreFunctionFrom(VectorFloat<?> query, final LVQPackedVectors packedVectors) {
        final VectorFloat<?> shiftedQuery = VectorUtil.sub(query, globalMean);
        return new ScoreFunction.ExactScoreFunction() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                PackedVector vector = packedVectors.getPackedVector(node);
                float squaredDistance = VectorUtil.lvqSquareL2Distance(shiftedQuery, vector);
                return 1.0f / (1.0f + squaredDistance);
            }
        };
    }

    /**
     * Cosine scoring. The global mean is passed through to {@code lvqCosine}, and the cosine is
     * mapped to {@code (1 + cos) / 2}.
     */
    private ScoreFunction.ExactScoreFunction cosineScoreFunctionFrom(final VectorFloat<?> query, final LVQPackedVectors packedVectors) {
        final VectorFloat<?> mean = globalMean;
        return new ScoreFunction.ExactScoreFunction() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                PackedVector vector = packedVectors.getPackedVector(node);
                float cosine = VectorUtil.lvqCosine(query, vector, mean);
                return (1.0f + cosine) / 2.0f;
            }
        };
    }

    /**
     * Creates an exact score function comparing {@code query} against {@code packedVectors}.
     *
     * @throws IllegalArgumentException if {@code similarityFunction} is not DOT_PRODUCT,
     *         EUCLIDEAN or COSINE
     */
    public ScoreFunction.ExactScoreFunction scoreFunctionFrom(VectorFloat<?> query, VectorSimilarityFunction similarityFunction, LVQPackedVectors packedVectors) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return dotProductScoreFunctionFrom(query, packedVectors);
            case EUCLIDEAN:
                return euclideanScoreFunctionFrom(query, packedVectors);
            case COSINE:
                return cosineScoreFunctionFrom(query, packedVectors);
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
        }
    }

    /** LVQ has no {@link CompressedVectors} implementation; always throws. */
    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        throw new UnsupportedOperationException("LVQ does not produce a compressed vectors implementation");
    }

    /** Reads a quantizer in the format produced by {@link #write}. */
    public static LocallyAdaptiveVectorQuantization load(RandomAccessReader in) throws IOException {
        int length = in.readInt();
        VectorFloat<?> globalMean = vectorTypeSupport.readFloatVector(in, length);
        return new LocallyAdaptiveVectorQuantization(globalMean);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        LocallyAdaptiveVectorQuantization that = (LocallyAdaptiveVectorQuantization) o;
        return globalMean.equals(that.globalMean);
    }

    /** Consistent with {@link #equals}: equality is defined solely by the global mean. */
    @Override
    public int hashCode() {
        return globalMean.hashCode();
    }

    /** A centered vector quantized to one unsigned byte per component, plus its bias and scale. */
    public static class QuantizedVector {
        private final ByteSequence<?> bytes;
        private final float bias;
        private final float scale;

        public QuantizedVector(ByteSequence<?> bytes, float bias, float scale) {
            this.bytes = bytes;
            this.bias = bias;
            this.scale = scale;
        }

        /** Writes {@code encodedVector[index]}, or a zero padding byte when index is past the end. */
        private void writeByteSafely(DataOutput out, ByteSequence<?> encodedVector, int index) throws IOException {
            out.writeByte(index < encodedVector.length() ? encodedVector.get(index) : 0);
        }

        /**
         * Writes the bias and scale as floats, followed by the codes in the interleaved layout
         * that {@link PackedVector#getQuantized} reads.
         *
         * <p>The codes are processed in 64-byte blocks, and the final block is zero-padded. Within
         * a block, lane {@code l} is written as the 4 bytes at logical offsets
         * {@code l, l + 16, l + 32, l + 48}.
         */
        public void writePacked(DataOutput out) throws IOException {
            out.writeFloat(bias);
            out.writeFloat(scale);
            int blockCount = (bytes.length() + PACKED_BLOCK_SIZE - 1) / PACKED_BLOCK_SIZE;
            for (int block = 0; block < blockCount; block++) {
                int blockStart = block * PACKED_BLOCK_SIZE;
                for (int lane = 0; lane < PACKED_LANES; lane++) {
                    for (int offset = 0; offset < PACKED_BYTES_PER_LANE; offset++) {
                        writeByteSafely(out, bytes, blockStart + offset * PACKED_LANES + lane);
                    }
                }
            }
        }
    }

    /** A quantized vector whose codes are stored in the interleaved layout written by {@code writePacked}. */
    public static class PackedVector {
        public final ByteSequence<?> bytes;
        public final float bias;
        public final float scale;

        public PackedVector(ByteSequence<?> bytes, float bias, float scale) {
            this.bytes = bytes;
            this.bias = bias;
            this.scale = scale;
        }

        /** Returns the unsigned 8-bit code of logical component {@code index}, undoing the lane interleave. */
        public int getQuantized(int index) {
            int blockId = index / PACKED_BLOCK_SIZE;
            int inBlockId = index % PACKED_BLOCK_SIZE;
            int laneId = inBlockId % PACKED_LANES;
            int laneOffset = inBlockId / PACKED_LANES;
            int packedIndex = blockId * PACKED_BLOCK_SIZE + laneId * PACKED_BYTES_PER_LANE + laneOffset;
            return Byte.toUnsignedInt(bytes.get(packedIndex));
        }

        /** Returns the centered (mean-subtracted) value of component {@code index}: {@code code * scale + bias}. */
        public float getDequantized(int index) {
            return (float) getQuantized(index) * scale + bias;
        }

        /** Returns a copy that owns its own byte storage. */
        public PackedVector copy() {
            return new PackedVector(bytes.copy(), bias, scale);
        }
    }
}

