/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.quantization.CompressedVectors;
import io.github.jbellis.jvector.quantization.NVQVectors;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/**
 * Non-uniform vector quantization (NVQ).
 *
 * <p>Encoding subtracts {@link #globalMean} from a vector, splits the centered vector into
 * contiguous sub-vectors described by {@link #subvectorSizesAndOffsets}, and quantizes each
 * sub-vector independently with its own min/max and non-uniform quantization parameters
 * (see {@link QuantizedSubVector#quantizeTo}). Only {@link BitsPerDimension#EIGHT} is currently
 * produced by this class.
 */
public class NVQuantization
implements VectorCompressor<NVQuantization.QuantizedVector>,
Accountable {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Bits used per quantized component; fixed at {@link BitsPerDimension#EIGHT}. */
    public final BitsPerDimension bitsPerDimension = BitsPerDimension.EIGHT;
    /** Mean vector subtracted from every vector before quantization (computed by {@link #compute} or read by {@link #load}). */
    public final VectorFloat<?> globalMean;
    /** Dimensionality of uncompressed vectors: the sum of all sub-vector sizes. */
    public final int originalDimension;
    /** One {@code {size, offset}} pair per sub-vector, in order. */
    public final int[][] subvectorSizesAndOffsets;
    /**
     * When {@code true}, a growth rate is searched for each sub-vector during encoding; when
     * {@code false}, a fixed default growth rate is used.
     */
    @VisibleForTesting
    public boolean learn = true;

    /**
     * @param subvectorSizesAndOffsets {@code {size, offset}} pair per sub-vector
     * @param globalMean mean vector; its length must equal the sum of the sub-vector sizes
     * @throws IllegalArgumentException if the global mean length does not match the sub-vector layout
     */
    private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat<?> globalMean) {
        this.globalMean = globalMean;
        this.subvectorSizesAndOffsets = subvectorSizesAndOffsets;
        this.originalDimension = Arrays.stream(subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum();
        if (globalMean.length() != this.originalDimension) {
            String msg = String.format("Global mean length %d does not match vector dimensionality %d", globalMean.length(), this.originalDimension);
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Builds a quantizer whose global mean is the component-wise mean of all vectors in {@code ravv}.
     *
     * @param ravv the vectors to learn the mean from; must be non-empty
     * @param nSubVectors number of contiguous sub-vectors to split each vector into
     * @return a new quantizer
     * @throws IllegalArgumentException if {@code ravv} is empty or {@code nSubVectors} is not in
     *         {@code [1, ravv.dimension()]}
     */
    public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) {
        int dimension = ravv.dimension();
        int[][] subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(dimension, nSubVectors);
        RandomAccessVectorValues ravvCopy = ravv.threadLocalSupplier().get();
        int count = ravvCopy.size();
        if (count == 0) {
            // Without this guard the mean would be scaled by 1/0.
            throw new IllegalArgumentException("Cannot compute the NVQ global mean of an empty vector set");
        }
        VectorFloat<?> globalMean = vectorTypeSupport.createFloatVector(dimension);
        for (int i = 0; i < count; ++i) {
            VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i));
        }
        VectorUtil.scale(globalMean, 1.0f / (float) count);
        return new NVQuantization(subvectorSizesAndOffsets, globalMean);
    }

    /** Wraps already-encoded vectors; the array's runtime type must be {@code QuantizedVector[]}. */
    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        return new NVQVectors(this, (QuantizedVector[]) compressedVectors);
    }

    /**
     * Encodes every vector of {@code ravv} using a parallel stream run inside {@code parallelExecutor}.
     * Each vector is read through a view obtained from {@code ravv.threadLocalSupplier()}.
     */
    @Override
    public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
        Supplier<RandomAccessVectorValues> ravvCopy = ravv.threadLocalSupplier();
        QuantizedVector[] encoded = parallelExecutor.submit(() -> IntStream.range(0, ravv.size())
                        .parallel()
                        .mapToObj(i -> encode(ravvCopy.get().getVector(i)))
                        .toArray(QuantizedVector[]::new))
                .join();
        return new NVQVectors(this, encoded);
    }

    /** Encodes a single vector into a newly allocated {@link QuantizedVector}. */
    @Override
    public QuantizedVector encode(VectorFloat<?> vector) {
        QuantizedVector qv = QuantizedVector.createEmpty(this.subvectorSizesAndOffsets, this.bitsPerDimension);
        this.encodeTo(vector, qv);
        return qv;
    }

    /** Centers {@code v} by the global mean and quantizes each sub-vector into {@code dest}. */
    @Override
    public void encodeTo(VectorFloat<?> v, QuantizedVector dest) {
        VectorFloat<?> centered = VectorUtil.sub(v, this.globalMean);
        QuantizedVector.quantizeTo(this.getSubVectors(centered), this.bitsPerDimension, this.learn, dest);
    }

    /**
     * Copies {@code vector} into one freshly allocated vector per sub-vector, following
     * {@link #subvectorSizesAndOffsets}.
     */
    public VectorFloat<?>[] getSubVectors(VectorFloat<?> vector) {
        VectorFloat<?>[] subvectors = new VectorFloat<?>[this.subvectorSizesAndOffsets.length];
        for (int i = 0; i < this.subvectorSizesAndOffsets.length; ++i) {
            int size = this.subvectorSizesAndOffsets[i][0];
            int offset = this.subvectorSizesAndOffsets[i][1];
            VectorFloat<?> subvector = vectorTypeSupport.createFloatVector(size);
            subvector.copyFrom(vector, offset, 0, size);
            subvectors[i] = subvector;
        }
        return subvectors;
    }

    /**
     * Splits {@code dimensions} into {@code nSubVectors} contiguous ranges whose sizes differ by at
     * most one; the first {@code dimensions % nSubVectors} ranges get the extra component.
     *
     * @return one {@code {size, offset}} pair per sub-vector
     * @throws IllegalArgumentException if {@code nSubVectors <= 0} or {@code nSubVectors > dimensions}
     */
    static int[][] getSubvectorSizesAndOffsets(int dimensions, int nSubVectors) {
        if (nSubVectors <= 0) {
            // Otherwise this fails with ArithmeticException / NegativeArraySizeException below.
            throw new IllegalArgumentException("Number of subspaces must be positive, got " + nSubVectors);
        }
        if (nSubVectors > dimensions) {
            throw new IllegalArgumentException("Number of subspaces must be less than or equal to the vector dimension");
        }
        int[][] sizes = new int[nSubVectors][];
        int baseSize = dimensions / nSubVectors;
        int remainder = dimensions % nSubVectors;
        int offset = 0;
        for (int i = 0; i < nSubVectors; ++i) {
            int size = baseSize + (i < remainder ? 1 : 0);
            sizes[i] = new int[]{size, offset};
            offset += size;
        }
        return sizes;
    }

    /**
     * Serializes this quantizer: version, global mean (length + components), bits per dimension,
     * sub-vector count, then each sub-vector size. Offsets are not written; {@link #load}
     * reconstructs them as running sums.
     *
     * @throws IllegalArgumentException if {@code version > 4}
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        if (version > 4) {
            throw new IllegalArgumentException("Unsupported serialization version " + version);
        }
        out.writeInt(version);
        out.writeInt(this.globalMean.length());
        vectorTypeSupport.writeFloatVector(out, this.globalMean);
        this.bitsPerDimension.write(out);
        out.writeInt(this.subvectorSizesAndOffsets.length);
        assert (Arrays.stream(this.subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum() == this.originalDimension);
        for (int[] sizeAndOffset : this.subvectorSizesAndOffsets) {
            out.writeInt(sizeAndOffset[0]);
        }
    }

    /** Number of bytes {@link #write} emits. */
    @Override
    public int compressorSize() {
        return Integer.BYTES                                      // version
                + Integer.BYTES                                   // global mean length
                + Float.BYTES * this.globalMean.length()          // global mean components
                + Integer.BYTES                                   // bits per dimension
                + Integer.BYTES                                   // sub-vector count
                + Integer.BYTES * this.subvectorSizesAndOffsets.length; // sub-vector sizes
    }

    /**
     * Reads a quantizer written by {@link #write}.
     *
     * @throws IOException if the stream is malformed (non-positive global mean length or negative
     *         sub-vector count)
     * @throws IllegalArgumentException if the stored bits-per-dimension is unsupported or the stored
     *         sizes do not sum to the global mean length
     */
    public static NVQuantization load(RandomAccessReader in) throws IOException {
        // The serialization version is read to advance the stream but is not otherwise used.
        in.readInt();
        int globalMeanLength = in.readInt();
        if (globalMeanLength <= 0) {
            // A missing mean would otherwise surface later as a NullPointerException in the constructor.
            throw new IOException("Invalid NVQ global mean length " + globalMeanLength);
        }
        VectorFloat<?> globalMean = vectorTypeSupport.readFloatVector(in, globalMeanLength);
        // Read to validate the stored value; the instance field is fixed at EIGHT.
        BitsPerDimension.load(in);
        int nSubVectors = in.readInt();
        if (nSubVectors < 0) {
            throw new IOException("Invalid NVQ sub-vector count " + nSubVectors);
        }
        int[][] subvectorSizes = new int[nSubVectors][];
        int offset = 0;
        for (int i = 0; i < nSubVectors; ++i) {
            int size = in.readInt();
            subvectorSizes[i] = new int[]{size, offset};
            offset += size;
        }
        return new NVQuantization(subvectorSizes, globalMean);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        NVQuantization that = (NVQuantization) o;
        return this.originalDimension == that.originalDimension
                && Objects.equals(this.globalMean, that.globalMean)
                && Arrays.deepEquals(this.subvectorSizesAndOffsets, that.subvectorSizesAndOffsets);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(this.originalDimension);
        result = 31 * result + Objects.hashCode(this.globalMean);
        result = 31 * result + Arrays.deepHashCode(this.subvectorSizesAndOffsets);
        return result;
    }

    /** Number of bytes {@link QuantizedVector#write} emits for one vector encoded by this quantizer. */
    @Override
    public int compressedVectorSize() {
        int size = Integer.BYTES; // sub-vector count
        for (int[] sizeAndOffset : this.subvectorSizesAndOffsets) {
            size += QuantizedSubVector.compressedVectorSize(sizeAndOffset[0], this.bitsPerDimension);
        }
        return size;
    }

    /** Accounts only for the global mean vector. */
    @Override
    public long ramBytesUsed() {
        return this.globalMean.ramBytesUsed();
    }

    @Override
    public String toString() {
        return String.format("NVQuantization(sub-vectors=%d)", this.subvectorSizesAndOffsets.length);
    }

    /** Number of bits used to store each quantized component. */
    public static enum BitsPerDimension {
        EIGHT {
            @Override
            public int getInt() {
                return 8;
            }

            @Override
            public ByteSequence<?> createByteSequence(int nDimensions) {
                return vectorTypeSupport.createByteSequence(nDimensions);
            }
        },
        FOUR {
            @Override
            public int getInt() {
                return 4;
            }

            @Override
            public ByteSequence<?> createByteSequence(int nDimensions) {
                // Two 4-bit components per byte.
                return vectorTypeSupport.createByteSequence((int) Math.ceil((double) nDimensions / 2.0));
            }
        };

        /** Writes {@link #getInt()} as a single int. */
        public void write(DataOutput out) throws IOException {
            out.writeInt(this.getInt());
        }

        /** @return the number of bits per component */
        public abstract int getInt();

        /** Allocates a byte sequence large enough to hold {@code nDimensions} quantized components. */
        public abstract ByteSequence<?> createByteSequence(int nDimensions);

        /**
         * Reads a value written by {@link #write}.
         *
         * @throws IllegalArgumentException for any value other than 8 (FOUR is not loadable yet)
         */
        public static BitsPerDimension load(RandomAccessReader in) throws IOException {
            int nBitsPerDimension = in.readInt();
            switch (nBitsPerDimension) {
                case 8:
                    return EIGHT;
                default:
                    throw new IllegalArgumentException("Unsupported BitsPerDimension " + nBitsPerDimension);
            }
        }
    }

    /** A quantized vector: one {@link QuantizedSubVector} per sub-vector of the layout. */
    public static class QuantizedVector {
        public final QuantizedSubVector[] subVectors;

        /** Quantizes {@code subVectors[i]} into {@code dest.subVectors[i]} for every i. */
        public static void quantizeTo(VectorFloat<?>[] subVectors, BitsPerDimension bitsPerDimension, boolean learn, QuantizedVector dest) {
            for (int i = 0; i < subVectors.length; ++i) {
                QuantizedSubVector.quantizeTo(subVectors[i], bitsPerDimension, learn, dest.subVectors[i]);
            }
        }

        private QuantizedVector(QuantizedSubVector[] subVectors) {
            this.subVectors = subVectors;
        }

        /** Allocates a zeroed vector matching the given {@code {size, offset}} layout. */
        public static QuantizedVector createEmpty(int[][] subvectorSizesAndOffsets, BitsPerDimension bitsPerDimension) {
            QuantizedSubVector[] subVectors = new QuantizedSubVector[subvectorSizesAndOffsets.length];
            for (int i = 0; i < subvectorSizesAndOffsets.length; ++i) {
                subVectors[i] = QuantizedSubVector.createEmpty(bitsPerDimension, subvectorSizesAndOffsets[i][0]);
            }
            return new QuantizedVector(subVectors);
        }

        /** Writes the sub-vector count followed by each sub-vector. */
        public void write(DataOutput out) throws IOException {
            out.writeInt(this.subVectors.length);
            for (QuantizedSubVector sv : this.subVectors) {
                sv.write(out);
            }
        }

        /** Reads a vector written by {@link #write} into a new instance. */
        public static QuantizedVector load(RandomAccessReader in) throws IOException {
            int length = in.readInt();
            QuantizedSubVector[] subVectors = new QuantizedSubVector[length];
            for (int i = 0; i < length; ++i) {
                subVectors[i] = QuantizedSubVector.load(in);
            }
            return new QuantizedVector(subVectors);
        }

        /**
         * Reads a vector written by {@link #write} into an existing instance, reusing its buffers.
         *
         * @throws IOException if the stored sub-vector count differs from {@code qvector}'s, which
         *         would otherwise leave the stream misaligned
         */
        public static void loadInto(RandomAccessReader in, QuantizedVector qvector) throws IOException {
            int length = in.readInt();
            if (length != qvector.subVectors.length) {
                throw new IOException(String.format("Stored vector has %d sub-vectors but destination has %d", length, qvector.subVectors.length));
            }
            for (int i = 0; i < qvector.subVectors.length; ++i) {
                QuantizedSubVector.loadInto(in, qvector.subVectors[i]);
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            QuantizedVector that = (QuantizedVector) o;
            return Arrays.deepEquals(this.subVectors, that.subVectors);
        }

        /** Consistent with {@link #equals}: element-wise over the sub-vectors. */
        @Override
        public int hashCode() {
            return Arrays.hashCode(this.subVectors);
        }
    }

    /** One quantized sub-vector plus the parameters needed to dequantize it. */
    public static class QuantizedSubVector {
        public ByteSequence<?> bytes;
        public BitsPerDimension bitsPerDimension;
        public float growthRate;
        public float midpoint;
        public float maxValue;
        public float minValue;
        public int originalDimensions;

        /**
         * Number of bytes {@link #write} emits for a sub-vector of {@code nDims} components.
         *
         * @throws IllegalArgumentException for any bit width other than EIGHT
         */
        public static int compressedVectorSize(int nDims, BitsPerDimension bitsPerDimension) {
            switch (bitsPerDimension) {
                case EIGHT:
                    // payload + 4 floats (min, max, growthRate, midpoint)
                    // + 3 ints (bits per dimension, originalDimensions, payload length)
                    return nDims + 4 * Float.BYTES + 3 * Integer.BYTES;
                default:
                    throw new IllegalArgumentException("Unsupported bits per dimension: " + bitsPerDimension);
            }
        }

        /**
         * Quantizes {@code vector} into {@code dest}, replacing all of its fields (including
         * allocating a new {@link #bytes} sequence).
         *
         * @param learn if {@code true}, search for a growth rate; otherwise use 0.01
         * @throws IllegalArgumentException for any bit width other than EIGHT
         */
        public static void quantizeTo(VectorFloat<?> vector, BitsPerDimension bitsPerDimension, boolean learn, QuantizedSubVector dest) {
            float minValue = VectorUtil.min(vector);
            float maxValue = VectorUtil.max(vector);
            float growthRate = 0.01f;
            float midpoint = 0.0f;
            if (learn) {
                growthRate = searchGrowthRate(vector, bitsPerDimension, minValue, maxValue);
            }
            ByteSequence<?> quantized = bitsPerDimension.createByteSequence(vector.length());
            switch (bitsPerDimension) {
                case EIGHT:
                    VectorUtil.nvqQuantize8bit(vector, growthRate, midpoint, minValue, maxValue, quantized);
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported bits per dimension: " + bitsPerDimension);
            }
            dest.bitsPerDimension = bitsPerDimension;
            dest.minValue = minValue;
            dest.maxValue = maxValue;
            dest.growthRate = growthRate;
            dest.midpoint = midpoint;
            dest.bytes = quantized;
            dest.originalDimensions = vector.length();
        }

        /**
         * Grid-searches the growth rate (with midpoint fixed at 0) that maximizes
         * {@link NonuniformQuantizationLossFunction#compute}. The coarse pass samples
         * 1e-6, 1.000001, ... below 20 in steps of 1. The fine pass samples +-1 around the
         * coarse best in steps of 0.1.
         */
        private static float searchGrowthRate(VectorFloat<?> vector, BitsPerDimension bitsPerDimension, float minValue, float maxValue) {
            NonuniformQuantizationLossFunction lossFunction = new NonuniformQuantizationLossFunction(bitsPerDimension);
            lossFunction.setVector(vector, minValue, maxValue);
            float growthRateCoarse = 0.01f;
            // NOTE(review): Float.MIN_VALUE is the smallest *positive* float, so candidates scoring
            // <= ~0 (or NaN) are never selected and the 0.01 default is kept. Preserved as-is.
            float bestLossValue = Float.MIN_VALUE;
            float[] tempSolution = {growthRateCoarse, 0.0f};
            for (float gr = 1.0E-6f; gr < 20.0f; gr += 1.0f) {
                tempSolution[0] = gr;
                float lossValue = lossFunction.compute(tempSolution);
                if (lossValue > bestLossValue) {
                    bestLossValue = lossValue;
                    growthRateCoarse = gr;
                }
            }
            float growthRateFineTuned = growthRateCoarse;
            // NOTE(review): when the coarse best is ~1e-6 this range starts below zero.
            for (float gr = growthRateCoarse - 1.0f; gr < growthRateCoarse + 1.0f; gr += 0.1f) {
                tempSolution[0] = gr;
                float lossValue = lossFunction.compute(tempSolution);
                if (lossValue > bestLossValue) {
                    bestLossValue = lossValue;
                    growthRateFineTuned = gr;
                }
            }
            return growthRateFineTuned;
        }

        private QuantizedSubVector(ByteSequence<?> bytes, int originalDimensions, BitsPerDimension bitsPerDimension, float minValue, float maxValue, float growthRate, float midpoint) {
            this.bitsPerDimension = bitsPerDimension;
            this.bytes = bytes;
            this.minValue = minValue;
            this.maxValue = maxValue;
            this.growthRate = growthRate;
            this.midpoint = midpoint;
            this.originalDimensions = originalDimensions;
        }

        /** Writes bits, min, max, growth rate, midpoint, original dimensions, payload length, payload. */
        public void write(DataOutput out) throws IOException {
            this.bitsPerDimension.write(out);
            out.writeFloat(this.minValue);
            out.writeFloat(this.maxValue);
            out.writeFloat(this.growthRate);
            out.writeFloat(this.midpoint);
            out.writeInt(this.originalDimensions);
            out.writeInt(this.bytes.length());
            vectorTypeSupport.writeByteSequence(out, this.bytes);
        }

        /** Allocates a zeroed sub-vector of {@code length} components. */
        public static QuantizedSubVector createEmpty(BitsPerDimension bitsPerDimension, int length) {
            ByteSequence<?> bytes = bitsPerDimension.createByteSequence(length);
            return new QuantizedSubVector(bytes, length, bitsPerDimension, 0.0f, 0.0f, 0.0f, 0.0f);
        }

        /** Reads a sub-vector written by {@link #write} into a new instance. */
        public static QuantizedSubVector load(RandomAccessReader in) throws IOException {
            BitsPerDimension bitsPerDimension = BitsPerDimension.load(in);
            float minValue = in.readFloat();
            float maxValue = in.readFloat();
            float growthRate = in.readFloat();
            float midpoint = in.readFloat();
            int originalDimensions = in.readInt();
            int compressedDimension = in.readInt();
            ByteSequence<?> bytes = vectorTypeSupport.readByteSequence(in, compressedDimension);
            return new QuantizedSubVector(bytes, originalDimensions, bitsPerDimension, minValue, maxValue, growthRate, midpoint);
        }

        /**
         * Reads a sub-vector written by {@link #write} into {@code quantizedSubVector}, reusing its
         * existing {@link #bytes} buffer.
         *
         * @throws IOException if the stored payload length differs from the buffer's length
         */
        public static void loadInto(RandomAccessReader in, QuantizedSubVector quantizedSubVector) throws IOException {
            quantizedSubVector.bitsPerDimension = BitsPerDimension.load(in);
            quantizedSubVector.minValue = in.readFloat();
            quantizedSubVector.maxValue = in.readFloat();
            quantizedSubVector.growthRate = in.readFloat();
            quantizedSubVector.midpoint = in.readFloat();
            quantizedSubVector.originalDimensions = in.readInt();
            int compressedLength = in.readInt();
            if (compressedLength != quantizedSubVector.bytes.length()) {
                // Assumes readByteSequence fills the whole destination buffer; a mismatch would misalign the stream.
                throw new IOException(String.format("Stored sub-vector has %d bytes but destination buffer has %d", compressedLength, quantizedSubVector.bytes.length()));
            }
            vectorTypeSupport.readByteSequence(in, quantizedSubVector.bytes);
        }

        /** Compares quantization parameters (floats with {@code ==}), bit width, and payload; ignores originalDimensions. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            QuantizedSubVector that = (QuantizedSubVector) o;
            return this.maxValue == that.maxValue && this.minValue == that.minValue && this.growthRate == that.growthRate && this.midpoint == that.midpoint && this.bitsPerDimension == that.bitsPerDimension && this.bytes.equals(that.bytes);
        }

        /**
         * Consistent with {@link #equals}. The payload is deliberately excluded because its hash
         * contract is not visible here; equal objects still hash equally.
         */
        @Override
        public int hashCode() {
            int result = Objects.hashCode(this.bitsPerDimension);
            result = 31 * result + hashFloat(this.minValue);
            result = 31 * result + hashFloat(this.maxValue);
            result = 31 * result + hashFloat(this.growthRate);
            result = 31 * result + hashFloat(this.midpoint);
            return result;
        }

        /** Float hash that maps -0.0f and 0.0f together, matching the {@code ==} comparison in equals. */
        private static int hashFloat(float value) {
            return value == 0.0f ? 0 : Float.hashCode(value);
        }
    }

    /**
     * Scores candidate {@code {growthRate, midpoint}} parameters for one vector as
     * {@code baseline / nvqLoss(...)}, where the baseline is {@code VectorUtil.nvqUniformLoss}
     * (presumably the error of plain uniform quantization). Higher scores are better.
     */
    private static class NonuniformQuantizationLossFunction {
        private final BitsPerDimension bitsPerDimension;
        private VectorFloat<?> vector;
        private float minValue;
        private float maxValue;
        private float baseline;

        public NonuniformQuantizationLossFunction(BitsPerDimension bitsPerDimension) {
            this.bitsPerDimension = bitsPerDimension;
        }

        /** Sets the vector to score and precomputes the baseline loss. */
        public void setVector(VectorFloat<?> vector, float minValue, float maxValue) {
            this.vector = vector;
            this.minValue = minValue;
            this.maxValue = maxValue;
            this.baseline = VectorUtil.nvqUniformLoss(vector, minValue, maxValue, this.bitsPerDimension.getInt());
        }

        /** @param x {@code {growthRate, midpoint}} */
        public float computeRaw(float[] x) {
            return VectorUtil.nvqLoss(this.vector, x[0], x[1], this.minValue, this.maxValue, this.bitsPerDimension.getInt());
        }

        /** @param x {@code {growthRate, midpoint}}; returns baseline / raw loss (higher is better) */
        public float compute(float[] x) {
            return this.baseline / this.computeRaw(x);
        }
    }
}

