/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer;
import io.github.jbellis.jvector.vector.VectorUtil;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class ProductQuantization {
    private static final int CLUSTERS = 256;
    private static final int K_MEANS_ITERATIONS = 15;
    private static final int MAX_PQ_TRAINING_SET_SIZE = 256000;
    private final float[][][] codebooks;
    private final int M;
    private final int originalDimension;
    private final float[] globalCentroid;
    private final int[][] subvectorSizesAndOffsets;

    public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv, int M, boolean globallyCenter) {
        float[] globalCentroid;
        float P = Math.min(1.0f, 256000.0f / (float)ravv.size());
        int[][] subvectorSizesAndOffsets = ProductQuantization.getSubvectorSizesAndOffsets(ravv.dimension(), M);
        List<Object> vectors = IntStream.range(0, ravv.size()).parallel().filter(i -> ThreadLocalRandom.current().nextFloat() < P).mapToObj(ravv::vectorValue).collect(Collectors.toList());
        if (globallyCenter) {
            globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
            vectors = ((Stream)vectors.stream().parallel()).map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList());
        } else {
            globalCentroid = null;
        }
        float[][][] codebooks = ProductQuantization.createCodebooks(vectors, M, subvectorSizesAndOffsets);
        return new ProductQuantization(codebooks, globalCentroid);
    }

    ProductQuantization(float[][][] codebooks, float[] globalCentroid) {
        this.codebooks = codebooks;
        this.globalCentroid = globalCentroid;
        this.M = codebooks.length;
        this.subvectorSizesAndOffsets = new int[this.M][];
        int offset = 0;
        for (int i = 0; i < this.M; ++i) {
            int size = codebooks[i][0].length;
            this.subvectorSizesAndOffsets[i] = new int[]{size, offset};
            offset += size;
        }
        this.originalDimension = Arrays.stream(this.subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum();
    }

    public List<byte[]> encodeAll(List<float[]> vectors) {
        return ((Stream)vectors.stream().parallel()).map(this::encode).collect(Collectors.toList());
    }

    public byte[] encode(float[] vector) {
        if (this.globalCentroid != null) {
            vector = VectorUtil.sub(vector, this.globalCentroid);
        }
        float[] finalVector = vector;
        byte[] encoded = new byte[this.M];
        for (int m = 0; m < this.M; ++m) {
            encoded[m] = (byte)ProductQuantization.closetCentroidIndex(ProductQuantization.getSubVector(finalVector, m, this.subvectorSizesAndOffsets), this.codebooks[m]);
        }
        return encoded;
    }

    public float decodedDotProduct(byte[] encoded, float[] other) {
        float sum = 0.0f;
        for (int m = 0; m < this.M; ++m) {
            int offset = this.subvectorSizesAndOffsets[m][1];
            int centroidIndex = Byte.toUnsignedInt(encoded[m]);
            float[] centroidSubvector = this.codebooks[m][centroidIndex];
            sum += VectorUtil.dotProduct(centroidSubvector, 0, other, offset, centroidSubvector.length);
        }
        return sum;
    }

    public void decode(byte[] encoded, float[] target) {
        for (int m = 0; m < this.M; ++m) {
            int centroidIndex = Byte.toUnsignedInt(encoded[m]);
            float[] centroidSubvector = this.codebooks[m][centroidIndex];
            System.arraycopy(centroidSubvector, 0, target, this.subvectorSizesAndOffsets[m][1], this.subvectorSizesAndOffsets[m][0]);
        }
        if (this.globalCentroid != null) {
            VectorUtil.addInPlace(target, this.globalCentroid);
        }
    }

    public int getOriginalDimension() {
        return this.originalDimension;
    }

    public int getSubspaceCount() {
        return this.M;
    }

    static void printCodebooks(List<List<float[]>> codebooks) {
        List strings = codebooks.stream().map(L -> L.stream().map(ProductQuantization::arraySummary).collect(Collectors.toList())).collect(Collectors.toList());
        System.out.printf("Codebooks: [%s]%n", String.join((CharSequence)"\n ", strings.stream().map(L -> "[" + String.join((CharSequence)", ", L) + "]").collect(Collectors.toList())));
    }

    private static String arraySummary(float[] a) {
        ArrayList<Object> b = new ArrayList<Object>();
        for (int i = 0; i < Math.min(4, a.length); ++i) {
            b.add(String.valueOf(a[i]));
        }
        if (a.length > 4) {
            b.set(3, "... (" + a.length + ")");
        }
        return "[" + String.join((CharSequence)", ", b) + "]";
    }

    static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset) {
        return (float[][][])IntStream.range(0, M).parallel().mapToObj(m -> {
            float[][] subvectors = (float[][])((Stream)vectors.stream().parallel()).map(vector -> ProductQuantization.getSubVector(vector, m, subvectorSizeAndOffset)).toArray(s -> new float[s][]);
            KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(subvectors, 256, VectorUtil::squareDistance);
            return clusterer.cluster(15);
        }).toArray(s -> new float[s][][]);
    }

    static int closetCentroidIndex(float[] subvector, float[][] codebook) {
        int index = 0;
        float minDist = 2.1474836E9f;
        for (int i = 0; i < codebook.length; ++i) {
            float dist = VectorUtil.squareDistance(subvector, codebook[i]);
            if (!(dist < minDist)) continue;
            minDist = dist;
            index = i;
        }
        return index;
    }

    static float[] getSubVector(float[] vector, int m, int[][] subvectorSizeAndOffset) {
        float[] subvector = new float[subvectorSizeAndOffset[m][0]];
        System.arraycopy(vector, subvectorSizeAndOffset[m][1], subvector, 0, subvectorSizeAndOffset[m][0]);
        return subvector;
    }

    static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) {
        int[][] sizes = new int[M][];
        int baseSize = dimensions / M;
        int remainder = dimensions % M;
        int offset = 0;
        for (int i = 0; i < M; ++i) {
            int size = baseSize + (i < remainder ? 1 : 0);
            sizes[i] = new int[]{size, offset};
            offset += size;
        }
        return sizes;
    }

    public void write(DataOutput out) throws IOException {
        if (this.globalCentroid == null) {
            out.writeInt(0);
        } else {
            out.writeInt(this.globalCentroid.length);
            Io.writeFloats(out, this.globalCentroid);
        }
        out.writeInt(this.M);
        assert (Arrays.stream(this.subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum() == this.originalDimension);
        assert (this.M == this.subvectorSizesAndOffsets.length);
        for (int[] a : this.subvectorSizesAndOffsets) {
            out.writeInt(a[0]);
        }
        assert (this.codebooks.length == this.M);
        assert (this.codebooks[0].length == 256);
        out.writeInt(this.codebooks[0].length);
        float[][][] fArray = this.codebooks;
        int n = fArray.length;
        for (int i = 0; i < n; ++i) {
            float[][] codebook;
            for (float[] centroid : codebook = fArray[i]) {
                Io.writeFloats(out, centroid);
            }
        }
    }

    public static ProductQuantization load(RandomAccessReader in) throws IOException {
        int globalCentroidLength = in.readInt();
        float[] globalCentroid = null;
        if (globalCentroidLength > 0) {
            globalCentroid = new float[globalCentroidLength];
            in.readFully(globalCentroid);
        }
        int M = in.readInt();
        int[][] subvectorSizes = new int[M][];
        int offset = 0;
        for (int i = 0; i < M; ++i) {
            int size;
            subvectorSizes[i] = new int[2];
            subvectorSizes[i][0] = size = in.readInt();
            subvectorSizes[i][1] = offset += size;
        }
        int clusters = in.readInt();
        float[][][] codebooks = new float[M][][];
        for (int m = 0; m < M; ++m) {
            float[][] codebook = new float[clusters][];
            for (int i = 0; i < clusters; ++i) {
                int n = subvectorSizes[m][0];
                float[] centroid = new float[n];
                in.readFully(centroid);
                codebook[i] = centroid;
            }
            codebooks[m] = codebook;
        }
        return new ProductQuantization(codebooks, globalCentroid);
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ProductQuantization that = (ProductQuantization)o;
        return this.M == that.M && this.originalDimension == that.originalDimension && Arrays.deepEquals((Object[])this.codebooks, (Object[])that.codebooks) && Arrays.equals(this.globalCentroid, that.globalCentroid) && Arrays.deepEquals((Object[])this.subvectorSizesAndOffsets, (Object[])that.subvectorSizesAndOffsets);
    }

    public int hashCode() {
        int result = Objects.hash(this.M, this.originalDimension);
        result = 31 * result + Arrays.deepHashCode((Object[])this.codebooks);
        result = 31 * result + Arrays.hashCode(this.globalCentroid);
        result = 31 * result + Arrays.deepHashCode((Object[])this.subvectorSizesAndOffsets);
        return result;
    }

    public float[] getCenter() {
        return this.globalCentroid;
    }
}

