/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.IndexWriter;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.quantization.ImmutablePQVectors;
import io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.MathUtil;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.agrona.collections.IntHashSet;

/**
 * Product quantization (PQ) vector compressor.
 *
 * <p>A vector of dimension {@code originalDimension} is split into {@code M} contiguous subspaces
 * (see {@link #getSubvectorSizesAndOffsets}). Each subspace has its own codebook of
 * {@code clusterCount} centroids, and a vector is encoded as {@code M} bytes, one centroid index
 * per subspace. Optionally, a global centroid is subtracted before encoding and added back when
 * decoding.
 *
 * <p>When {@code anisotropicThreshold} is {@code -1}, training is plain k-means and encoding picks
 * the nearest centroid per subspace. Any other value adds anisotropic training iterations, and
 * encoding then uses a coordinate-descent search over centroid assignments.
 */
public class ProductQuantization
implements VectorCompressor<ByteSequence<?>>,
Accountable {
    private static final int MAGIC = 1978417170;
    protected static final Logger LOG = Logger.getLogger(ProductQuantization.class.getName());
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
    static final int DEFAULT_CLUSTERS = 256;
    static final int K_MEANS_ITERATIONS = 6;
    public static final int MAX_PQ_TRAINING_SET_SIZE = 128000;
    /** Codes are stored as one unsigned byte, so at most 256 centroids per subspace are addressable. */
    private static final int MAX_CLUSTERS = 256;
    /** Sentinel threshold meaning "unweighted": plain k-means training and nearest-centroid encoding. */
    private static final float UNWEIGHTED = -1.0f;
    /** Upper bound on coordinate-descent passes performed by {@link #encodeAnisotropic}. */
    private static final int MAX_ANISOTROPIC_ENCODE_ITERATIONS = 10;
    /** Highest serialization version accepted by {@link #write}. */
    private static final int MAX_SERIALIZATION_VERSION = 6;

    /** {@code codebooks[m]} holds the {@code clusterCount} centroids of subspace {@code m}, back to back. */
    final VectorFloat<?>[] codebooks;
    /** Number of subspaces (equals {@code codebooks.length}). */
    final int M;
    private final int clusterCount;
    /** Sum of all subspace sizes. */
    final int originalDimension;
    /** Centroid subtracted before encoding; {@code null} when vectors are not globally centered. */
    final VectorFloat<?> globalCentroid;
    /** {@code subvectorSizesAndOffsets[m] = {size, offset}} of subspace {@code m}. */
    final int[][] subvectorSizesAndOffsets;
    /** {@code -1} for unweighted PQ, otherwise the anisotropic threshold. */
    final float anisotropicThreshold;
    /** {@code centroidNormsSquared[m][c]} is the squared L2 norm of centroid {@code c} of subspace {@code m}. */
    private final float[][] centroidNormsSquared;
    /** Per-thread scratch vector of {@code M * clusterCount} floats, handed out by {@link #reusablePartialSums()}. */
    private final ThreadLocal<VectorFloat<?>> partialSums;
    /** Initially empty holder exposed via {@link #partialSquaredMagnitudes()}. */
    private final AtomicReference<VectorFloat<?>> partialSquaredMagnitudes;

    /**
     * Trains an unweighted PQ using the default executors.
     *
     * @see #compute(RandomAccessVectorValues, int, int, boolean, float, ForkJoinPool, ForkJoinPool)
     */
    public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, int clusterCount, boolean globallyCenter) {
        return compute(ravv, M, clusterCount, globallyCenter, UNWEIGHTED, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Trains a PQ with the given anisotropic threshold using the default executors.
     *
     * @see #compute(RandomAccessVectorValues, int, int, boolean, float, ForkJoinPool, ForkJoinPool)
     */
    public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, int clusterCount, boolean globallyCenter, float anisotropicThreshold) {
        return compute(ravv, M, clusterCount, globallyCenter, anisotropicThreshold, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Trains a new PQ on (a sample of at most {@link #MAX_PQ_TRAINING_SET_SIZE}) vectors from {@code ravv}.
     *
     * @param ravv                 source of training vectors
     * @param M                    number of subspaces; must be in {@code [1, ravv.dimension()]}
     * @param clusterCount         centroids per subspace; must be in {@code [1, 256]}
     * @param globallyCenter       if true, the mean training vector is subtracted before clustering
     * @param anisotropicThreshold {@code -1} for unweighted training, otherwise the anisotropic threshold
     * @param simdExecutor         pool used for vector arithmetic and clustering
     * @param parallelExecutor     pool used to load training vectors
     * @return the trained quantizer
     * @throws IllegalArgumentException if {@code M} or {@code clusterCount} is out of range
     */
    public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, int clusterCount, boolean globallyCenter, float anisotropicThreshold, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        checkClusterCount(clusterCount);
        int[][] subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
        List<VectorFloat<?>> vectors = extractTrainingVectors(ravv, parallelExecutor);
        VectorFloat<?> globalCentroid = null;
        if (globallyCenter) {
            globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
            vectors = subtractFromAll(vectors, globalCentroid, simdExecutor);
        }
        VectorFloat<?>[] codebooks = createCodebooks(vectors, subvectorSizesAndOffsets, clusterCount, anisotropicThreshold, simdExecutor);
        return new ProductQuantization(codebooks, clusterCount, subvectorSizesAndOffsets, globalCentroid, anisotropicThreshold);
    }

    /**
     * Loads the training set for {@code ravv}. It contains every vector when there are at most
     * {@link #MAX_PQ_TRAINING_SET_SIZE}, and otherwise a fixed-seed sample of that many distinct
     * ordinals. Shared vector instances are copied so the returned vectors are independent.
     */
    static List<VectorFloat<?>> extractTrainingVectors(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
        int size = ravv.size();
        IntStream ordinalStream = size <= MAX_PQ_TRAINING_SET_SIZE
                ? IntStream.range(0, size)
                : IntStream.of(sampleOrdinals(size, MAX_PQ_TRAINING_SET_SIZE));
        Supplier<RandomAccessVectorValues> ravvCopy = ravv.threadLocalSupplier();
        return parallelExecutor.submit(() -> ordinalStream.parallel()
                .<VectorFloat<?>>mapToObj(ordinal -> loadTrainingVector(ravvCopy, ordinal))
                .collect(Collectors.toList())).join();
    }

    /**
     * Draws {@code sampleSize} distinct ordinals from {@code [0, populationSize)} with Floyd's
     * sampling algorithm. The seed is fixed at 1, so the chosen set is reproducible.
     */
    private static int[] sampleOrdinals(int populationSize, int sampleSize) {
        SplittableRandom rng = new SplittableRandom(1L);
        IntHashSet ordinals = new IntHashSet(sampleSize);
        for (int j = populationSize - sampleSize; j < populationSize; j++) {
            int t = rng.nextInt(j + 1);
            // Floyd's step: if t is already taken, j cannot be (it is new this round), so take j instead
            ordinals.add(ordinals.contains(t) ? j : t);
        }
        int[] ordinalArray = new int[ordinals.size()];
        int i = 0;
        IntHashSet.IntIterator it = ordinals.iterator();
        while (it.hasNext()) {
            ordinalArray[i++] = it.next();
        }
        return ordinalArray;
    }

    /** Reads one vector through a per-thread view, copying it if the view reuses its returned instance. */
    private static VectorFloat<?> loadTrainingVector(Supplier<RandomAccessVectorValues> ravvSupplier, int ordinal) {
        RandomAccessVectorValues localRavv = ravvSupplier.get();
        VectorFloat<?> v = localRavv.getVector(ordinal);
        return localRavv.isValueShared() ? v.copy() : v;
    }

    /** Returns a new list holding {@code v - centroid} for every {@code v}, computed in parallel on {@code executor}. */
    private static List<VectorFloat<?>> subtractFromAll(List<VectorFloat<?>> vectors, VectorFloat<?> centroid, ForkJoinPool executor) {
        return executor.submit(() -> vectors.stream()
                .parallel()
                .<VectorFloat<?>>map(v -> VectorUtil.sub(v, centroid))
                .collect(Collectors.toList())).join();
    }

    /**
     * Runs one unweighted Lloyd's round on vectors from {@code ravv}, starting from the current codebooks.
     * The result is always unweighted, even if this quantizer is anisotropic.
     */
    public ProductQuantization refine(RandomAccessVectorValues ravv) {
        return refine(ravv, 1, UNWEIGHTED, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Returns a new quantizer whose codebooks come from running {@code lloydsRounds} more clustering
     * rounds, seeded with this quantizer's codebooks, on training vectors drawn from {@code ravv}.
     * The global centroid (if any), subspace layout and cluster count are kept.
     *
     * @param lloydsRounds         number of rounds; unweighted rounds if {@code anisotropicThreshold == -1},
     *                             anisotropic rounds otherwise
     * @param anisotropicThreshold threshold of the returned quantizer ({@code -1} for unweighted)
     * @throws IllegalArgumentException if {@code lloydsRounds < 0} or the dimension of {@code ravv}
     *                                  differs from this quantizer's
     */
    public ProductQuantization refine(RandomAccessVectorValues ravv, int lloydsRounds, float anisotropicThreshold, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        if (lloydsRounds < 0) {
            throw new IllegalArgumentException("lloydsRounds must be non-negative");
        }
        if (ravv.dimension() != this.originalDimension) {
            throw new IllegalArgumentException(String.format("Vector dimension %d does not match quantizer dimension %d", ravv.dimension(), this.originalDimension));
        }
        // Reuse the stored layout: the existing codebooks are sized for it, which matters for
        // quantizers loaded from disk whose layout was not produced by getSubvectorSizesAndOffsets.
        int[][] layout = this.subvectorSizesAndOffsets;
        List<VectorFloat<?>> trainingVectors = extractTrainingVectors(ravv, parallelExecutor);
        List<VectorFloat<?>> vectors = this.globalCentroid == null
                ? trainingVectors
                : subtractFromAll(trainingVectors, this.globalCentroid, simdExecutor);
        boolean unweighted = anisotropicThreshold == UNWEIGHTED;
        int unweightedRounds = unweighted ? lloydsRounds : 0;
        int anisotropicRounds = unweighted ? 0 : lloydsRounds;
        Callable<VectorFloat<?>[]> task = () -> IntStream.range(0, this.M).parallel().mapToObj(m -> {
            VectorFloat<?>[] subvectors = extractSubvectors(vectors, m, layout);
            KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(subvectors, this.codebooks[m], anisotropicThreshold);
            return clusterer.cluster(unweightedRounds, anisotropicRounds);
        }).toArray(VectorFloat<?>[]::new);
        VectorFloat<?>[] refinedCodebooks = simdExecutor.submit(task).join();
        return new ProductQuantization(refinedCodebooks, this.clusterCount, layout, this.globalCentroid, anisotropicThreshold);
    }

    /**
     * Creates a quantizer from trained codebooks and precomputes each centroid's squared norm.
     *
     * @throws IllegalArgumentException if {@code clusterCount} is out of range, the number of codebooks
     *                                  differs from the number of subspaces, or the global centroid's
     *                                  length differs from the total dimension
     */
    ProductQuantization(VectorFloat<?>[] codebooks, int clusterCount, int[][] subvectorSizesAndOffsets, VectorFloat<?> globalCentroid, float anisotropicThreshold) {
        checkClusterCount(clusterCount);
        if (codebooks.length != subvectorSizesAndOffsets.length) {
            throw new IllegalArgumentException(String.format("Codebook count %d does not match subspace count %d", codebooks.length, subvectorSizesAndOffsets.length));
        }
        this.codebooks = codebooks;
        this.globalCentroid = globalCentroid;
        this.M = codebooks.length;
        this.clusterCount = clusterCount;
        this.subvectorSizesAndOffsets = subvectorSizesAndOffsets;
        this.originalDimension = Arrays.stream(subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum();
        if (globalCentroid != null && globalCentroid.length() != this.originalDimension) {
            String msg = String.format("Global centroid length %d does not match vector dimensionality %d", globalCentroid.length(), this.originalDimension);
            throw new IllegalArgumentException(msg);
        }
        this.anisotropicThreshold = anisotropicThreshold;
        int partialSumsLength = codebooks.length * clusterCount;
        this.partialSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(partialSumsLength));
        this.partialSquaredMagnitudes = new AtomicReference<>(null);
        this.centroidNormsSquared = new float[this.M][clusterCount];
        for (int m = 0; m < this.M; m++) {
            int size = subvectorSizesAndOffsets[m][0];
            for (int c = 0; c < clusterCount; c++) {
                int start = c * size;
                this.centroidNormsSquared[m][c] = VectorUtil.dotProduct(codebooks[m], start, codebooks[m], start, size);
            }
        }
    }

    @Override
    public ImmutablePQVectors createCompressedVectors(Object[] compressedVectors) {
        return new ImmutablePQVectors(this, (ByteSequence<?>[]) compressedVectors, compressedVectors.length, 1);
    }

    @Override
    public PQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        return PQVectors.encodeAndBuild(this, ravv.size(), ravv, simdExecutor);
    }

    /**
     * Anisotropic encoding of an already-centered vector. Each subspace starts at the centroid with
     * the smallest residual norm. Then up to {@link #MAX_ANISOTROPIC_ENCODE_ITERATIONS} passes of
     * per-subspace coordinate descent run, stopping early once a full pass changes nothing.
     */
    private void encodeAnisotropic(VectorFloat<?> vector, ByteSequence<?> result) {
        Residual[][] residuals = computeResiduals(vector);
        assert residuals.length == this.M : "Residuals length mismatch " + residuals.length + " != " + this.M;
        initializeToMinResidualNorms(residuals, result);
        float parallelResidualComponentSum = 0.0f;
        for (int i = 0; i < residuals.length; i++) {
            int centroidIdx = Byte.toUnsignedInt(result.get(i));
            parallelResidualComponentSum += residuals[i][centroidIdx].parallelResidualComponent;
        }
        // depends only on the threshold and dimension, so compute once per encode
        float pcm = KMeansPlusPlusClusterer.computeParallelCostMultiplier(this.anisotropicThreshold, this.originalDimension);
        for (int iter = 0; iter < MAX_ANISOTROPIC_ENCODE_ITERATIONS; iter++) {
            boolean changed = false;
            for (int i = 0; i < residuals.length; i++) {
                int oldIdx = Byte.toUnsignedInt(result.get(i));
                CoordinateDescentResult cdr = optimizeSingleSubspace(residuals[i], oldIdx, parallelResidualComponentSum, pcm);
                if (cdr.newCenterIdx != oldIdx) {
                    parallelResidualComponentSum = cdr.newParallelResidualComponent;
                    result.set(i, (byte) cdr.newCenterIdx);
                    changed = true;
                }
            }
            if (!changed) {
                break;
            }
        }
    }

    /**
     * Picks the centroid for one subspace that most lowers the weighted cost
     * {@code pcm * parallelNormDelta + perpendicularNormDelta}. Only candidates that do not increase
     * the squared parallel-residual sum are considered. If no candidate has a negative cost delta,
     * the current assignment is kept.
     */
    private CoordinateDescentResult optimizeSingleSubspace(Residual[] residuals, int oldIdx, float oldParallelResidualSum, float pcm) {
        float oldResidualNormSquared = residuals[oldIdx].residualNormSquared;
        float oldParallelComponent = residuals[oldIdx].parallelResidualComponent;
        float bestCostDelta = 0.0f;
        int bestIndex = oldIdx;
        float bestParallelResidualSum = oldParallelResidualSum;
        for (int thisIdx = 0; thisIdx < residuals.length; thisIdx++) {
            if (thisIdx == oldIdx) {
                continue;
            }
            Residual rs = residuals[thisIdx];
            float thisParallelResidualSum = oldParallelResidualSum - oldParallelComponent + rs.parallelResidualComponent;
            float parallelNormDelta = MathUtil.square(thisParallelResidualSum) - MathUtil.square(oldParallelResidualSum);
            if (parallelNormDelta > 0.0f) {
                continue;
            }
            float residualNormDelta = rs.residualNormSquared - oldResidualNormSquared;
            float perpendicularNormDelta = residualNormDelta - parallelNormDelta;
            float costDelta = pcm * parallelNormDelta + perpendicularNormDelta;
            if (costDelta < bestCostDelta) {
                bestCostDelta = costDelta;
                bestIndex = thisIdx;
                bestParallelResidualSum = thisParallelResidualSum;
            }
        }
        return new CoordinateDescentResult(bestIndex, bestParallelResidualSum);
    }

    /** Sets each subspace's code to the centroid with the smallest residual norm. */
    private void initializeToMinResidualNorms(Residual[][] residualStats, ByteSequence<?> dest) {
        for (int i = 0; i < residualStats.length; i++) {
            // Default to centroid 0, so a subspace whose norms are all NaN/infinite still gets a
            // valid code rather than (byte) -1 == 255.
            int minIndex = 0;
            float minNormSquared = Float.POSITIVE_INFINITY;
            for (int j = 0; j < residualStats[i].length; j++) {
                if (residualStats[i][j].residualNormSquared < minNormSquared) {
                    minNormSquared = residualStats[i][j].residualNormSquared;
                    minIndex = j;
                }
            }
            dest.set(i, (byte) minIndex);
        }
    }

    /** Computes the residual statistics of {@code vector} against every centroid of every subspace. */
    private Residual[][] computeResiduals(VectorFloat<?> vector) {
        Residual[][] residuals = new Residual[this.codebooks.length][];
        float inverseNorm = (float) (1.0 / Math.sqrt(VectorUtil.dotProduct(vector, vector)));
        for (int i = 0; i < this.codebooks.length; i++) {
            VectorFloat<?> x = getSubVector(vector, i, this.subvectorSizesAndOffsets);
            float xNormSquared = VectorUtil.dotProduct(x, x);
            residuals[i] = new Residual[this.clusterCount];
            for (int j = 0; j < this.clusterCount; j++) {
                residuals[i][j] = computeResidual(x, this.codebooks[i], j, this.centroidNormsSquared[i][j], xNormSquared, inverseNorm);
            }
        }
        return residuals;
    }

    /**
     * Residual statistics for subvector {@code x} against one centroid.
     * {@code residualNormSquared = ||c||^2 - 2 c.x + ||x||^2 = ||c - x||^2}.
     * {@code parallelResidualComponent = ((c - x).x)^2 * inverseNorm}, where {@code inverseNorm} is
     * {@code 1/||vector||} of the full vector.
     * NOTE(review): a projection onto x would scale by {@code 1/||x||^2}, and the per-subspace
     * squares are later summed and squared again; confirm this matches the intended cost.
     */
    private Residual computeResidual(VectorFloat<?> x, VectorFloat<?> centroids, int centroid, float cNormSquared, float xNormSquared, float inverseNorm) {
        float cDotX = VectorUtil.dotProduct(centroids, centroid * x.length(), x, 0, x.length());
        float residualNormSquared = cNormSquared - 2.0f * cDotX + xNormSquared;
        float parallelErrorSubtotal = cDotX - xNormSquared;
        float parallelResidualComponent = MathUtil.square(parallelErrorSubtotal) * inverseNorm;
        return new Residual(residualNormSquared, parallelResidualComponent);
    }

    /** Nearest-centroid (squared L2) encoding of an already-centered vector. */
    private void encodeUnweighted(VectorFloat<?> vector, ByteSequence<?> dest) {
        for (int m = 0; m < this.M; m++) {
            dest.set(m, (byte) closestCentroidIndex(vector, m, this.codebooks[m]));
        }
    }

    /** Returns {@code vector - globalCentroid}, or {@code vector} itself when there is no global centroid. */
    private VectorFloat<?> center(VectorFloat<?> vector) {
        return this.globalCentroid == null ? vector : VectorUtil.sub(vector, this.globalCentroid);
    }

    /** Encodes an already-centered vector with the encoder selected by {@code anisotropicThreshold}. */
    private void encodeCentered(VectorFloat<?> centered, ByteSequence<?> dest) {
        // NOTE(review): training tests `== -1` while encoding tests `> -1`, so thresholds below -1
        // train anisotropically but encode unweighted; kept as-is pending confirmation.
        if (this.anisotropicThreshold > UNWEIGHTED) {
            encodeAnisotropic(centered, dest);
        } else {
            encodeUnweighted(centered, dest);
        }
    }

    @Override
    public ByteSequence<?> encode(VectorFloat<?> vector) {
        ByteSequence<?> result = vectorTypeSupport.createByteSequence(this.M);
        encodeTo(vector, result);
        return result;
    }

    /** Encodes {@code vector} into {@code dest}, one byte per subspace; {@code vector} itself is not modified. */
    @Override
    public void encodeTo(VectorFloat<?> vector, ByteSequence<?> dest) {
        encodeCentered(center(vector), dest);
    }

    /** Writes the approximate reconstruction of {@code encoded} into {@code target}, including the global centroid. */
    public void decode(ByteSequence<?> encoded, VectorFloat<?> target) {
        decodeCentered(encoded, target);
        if (this.globalCentroid != null) {
            VectorUtil.addInPlace(target, this.globalCentroid);
        }
    }

    /** Copies each subspace's selected centroid into {@code target}, without adding the global centroid. */
    void decodeCentered(ByteSequence<?> encoded, VectorFloat<?> target) {
        for (int m = 0; m < this.M; m++) {
            int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
            int size = this.subvectorSizesAndOffsets[m][0];
            int offset = this.subvectorSizesAndOffsets[m][1];
            target.copyFrom(this.codebooks[m], centroidIndex * size, offset, size);
        }
    }

    public int getSubspaceCount() {
        return this.M;
    }

    public int getClusterCount() {
        return this.clusterCount;
    }

    /**
     * Trains one codebook per subspace with k-means++. It runs {@link #K_MEANS_ITERATIONS} unweighted
     * iterations, plus the same number of anisotropic iterations unless the threshold is {@code -1}.
     */
    static VectorFloat<?>[] createCodebooks(List<VectorFloat<?>> vectors, int[][] subvectorSizeAndOffset, int clusters, float anisotropicThreshold, ForkJoinPool simdExecutor) {
        int M = subvectorSizeAndOffset.length;
        int anisotropicIterations = anisotropicThreshold == UNWEIGHTED ? 0 : K_MEANS_ITERATIONS;
        Callable<VectorFloat<?>[]> task = () -> IntStream.range(0, M).parallel().mapToObj(m -> {
            VectorFloat<?>[] subvectors = extractSubvectors(vectors, m, subvectorSizeAndOffset);
            KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(subvectors, clusters, anisotropicThreshold);
            return clusterer.cluster(K_MEANS_ITERATIONS, anisotropicIterations);
        }).toArray(VectorFloat<?>[]::new);
        return simdExecutor.submit(task).join();
    }

    /** Returns the subspace-{@code m} slice of every vector, as freshly allocated vectors. */
    private static VectorFloat<?>[] extractSubvectors(List<VectorFloat<?>> vectors, int m, int[][] subvectorSizeAndOffset) {
        return vectors.stream()
                .map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
                .toArray(VectorFloat<?>[]::new);
    }

    /**
     * Returns the index of the centroid in {@code codebook} closest (squared L2) to subspace {@code m}
     * of {@code subvector}. Note that {@code subvector} is the full vector; the slice is read at the
     * subspace's offset.
     */
    int closestCentroidIndex(VectorFloat<?> subvector, int m, VectorFloat<?> codebook) {
        int index = 0;
        float minDist = Float.MAX_VALUE;
        int subvectorSize = this.subvectorSizesAndOffsets[m][0];
        int subvectorOffset = this.subvectorSizesAndOffsets[m][1];
        for (int i = 0; i < this.clusterCount; i++) {
            float dist = VectorUtil.squareL2Distance(subvector, subvectorOffset, codebook, i * subvectorSize, subvectorSize);
            if (dist < minDist) {
                minDist = dist;
                index = i;
            }
        }
        return index;
    }

    /** Copies subspace {@code m} of {@code vector} into a new vector. */
    static VectorFloat<?> getSubVector(VectorFloat<?> vector, int m, int[][] subvectorSizeAndOffset) {
        int size = subvectorSizeAndOffset[m][0];
        VectorFloat<?> subvector = vectorTypeSupport.createFloatVector(size);
        subvector.copyFrom(vector, subvectorSizeAndOffset[m][1], 0, size);
        return subvector;
    }

    /**
     * Splits {@code dimensions} into {@code M} contiguous subspaces whose sizes differ by at most one.
     * The first {@code dimensions % M} subspaces get the extra element.
     *
     * @return {@code result[m] = {size, offset}}
     * @throws IllegalArgumentException if {@code M > dimensions} or {@code M <= 0}
     */
    @VisibleForTesting
    static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) {
        if (M > dimensions) {
            throw new IllegalArgumentException("Number of subspaces must be less than or equal to the vector dimension");
        }
        if (M <= 0) {
            throw new IllegalArgumentException("Number of subspaces must be positive: " + M);
        }
        int[][] sizes = new int[M][];
        int baseSize = dimensions / M;
        int remainder = dimensions % M;
        int offset = 0;
        for (int i = 0; i < M; i++) {
            int size = baseSize + (i < remainder ? 1 : 0);
            sizes[i] = new int[]{size, offset};
            offset += size;
        }
        return sizes;
    }

    VectorFloat<?> reusablePartialSums() {
        return this.partialSums.get();
    }

    AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
        return this.partialSquaredMagnitudes;
    }

    /**
     * Serializes this quantizer. Versions 3 and up are prefixed with {@code MAGIC} and the version,
     * and include the anisotropic threshold.
     *
     * @throws IllegalArgumentException if {@code version} is above 6, or below 3 while the quantizer
     *                                  is anisotropic
     */
    @Override
    public void write(IndexWriter out, int version) throws IOException {
        if (version > MAX_SERIALIZATION_VERSION) {
            throw new IllegalArgumentException("Unsupported serialization version " + version);
        }
        if (version < 3 && this.anisotropicThreshold != UNWEIGHTED) {
            throw new IllegalArgumentException("Anisotropic threshold is only supported in serialization version 3 and above");
        }
        if (version >= 3) {
            out.writeInt(MAGIC);
            out.writeInt(version);
        }
        if (this.globalCentroid == null) {
            out.writeInt(0);
        } else {
            out.writeInt(this.globalCentroid.length());
            vectorTypeSupport.writeFloatVector(out, this.globalCentroid);
        }
        out.writeInt(this.M);
        assert Arrays.stream(this.subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum() == this.originalDimension;
        assert this.M == this.subvectorSizesAndOffsets.length;
        for (int[] sizeAndOffset : this.subvectorSizesAndOffsets) {
            out.writeInt(sizeAndOffset[0]);
        }
        if (version >= 3) {
            out.writeFloat(this.anisotropicThreshold);
        }
        assert this.codebooks.length == this.M;
        out.writeInt(this.clusterCount);
        for (int i = 0; i < this.M; i++) {
            VectorFloat<?> codebook = this.codebooks[i];
            assert codebook.length() == this.clusterCount * this.subvectorSizesAndOffsets[i][0];
            vectorTypeSupport.writeFloatVector(out, codebook);
        }
    }

    /**
     * Precomputes, for every subspace and every unordered centroid pair {@code (i, j)} with
     * {@code i <= j}, the squared L2 distance (for EUCLIDEAN) or the dot product (for any other
     * function). Pairs are stored in row-major upper-triangular order, subspace by subspace.
     */
    public VectorFloat<?> createCodebookPartialSums(VectorSimilarityFunction vectorSimilarityFunction) {
        // K*(K+1) is always even; dividing before multiplying by M avoids intermediate int overflow
        int pairsPerSubspace = this.clusterCount * (this.clusterCount + 1) / 2;
        VectorFloat<?> sums = vectorTypeSupport.createFloatVector(this.M * pairsPerSubspace);
        boolean euclidean = vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
        int index = 0;
        for (int m = 0; m < this.M; m++) {
            int size = this.subvectorSizesAndOffsets[m][0];
            VectorFloat<?> codebook = this.codebooks[m];
            for (int i = 0; i < this.clusterCount; i++) {
                for (int j = i; j < this.clusterCount; j++) {
                    float sum = euclidean
                            ? VectorUtil.squareL2Distance(codebook, i * size, codebook, j * size, size)
                            : VectorUtil.dotProduct(codebook, i * size, codebook, j * size, size);
                    sums.set(index++, sum);
                }
            }
        }
        return sums;
    }

    /** Size in bytes of the version 3+ serialized form produced by {@link #write}. */
    @Override
    public int compressorSize() {
        int size = 0;
        size += Integer.BYTES; // MAGIC
        size += Integer.BYTES; // version
        size += Integer.BYTES; // global centroid length
        if (this.globalCentroid != null) {
            size += Float.BYTES * this.globalCentroid.length();
        }
        size += Integer.BYTES;          // M
        size += Integer.BYTES * this.M; // subvector sizes
        size += Float.BYTES;            // anisotropic threshold
        size += Integer.BYTES;          // cluster count
        for (int i = 0; i < this.M; i++) {
            size += Float.BYTES * this.codebooks[i].length();
        }
        return size;
    }

    /**
     * Reads a quantizer written by {@link #write}. Both the headerless pre-v3 layout and the
     * {@code MAGIC}-prefixed layout are accepted.
     */
    public static ProductQuantization load(RandomAccessReader in) throws IOException {
        int maybeMagic = in.readInt();
        int version;
        int globalCentroidLength;
        if (maybeMagic == MAGIC) {
            version = in.readInt();
            globalCentroidLength = in.readInt();
        } else {
            // pre-v3 files have no header: the first int is already the centroid length
            version = 0;
            globalCentroidLength = maybeMagic;
        }
        VectorFloat<?> globalCentroid = null;
        if (globalCentroidLength > 0) {
            globalCentroid = vectorTypeSupport.readFloatVector(in, globalCentroidLength);
        }
        int M = in.readInt();
        int[][] subvectorSizes = new int[M][];
        int offset = 0;
        for (int i = 0; i < M; i++) {
            int size = in.readInt();
            subvectorSizes[i] = new int[]{size, offset};
            offset += size;
        }
        float anisotropicThreshold = version < 3 ? UNWEIGHTED : in.readFloat();
        int clusters = in.readInt();
        VectorFloat<?>[] codebooks = new VectorFloat<?>[M];
        for (int m = 0; m < M; m++) {
            codebooks[m] = vectorTypeSupport.readFloatVector(in, clusters * subvectorSizes[m][0]);
        }
        return new ProductQuantization(codebooks, clusters, subvectorSizes, globalCentroid, anisotropicThreshold);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ProductQuantization that = (ProductQuantization) o;
        return this.M == that.M
                && this.originalDimension == that.originalDimension
                && Objects.equals(this.globalCentroid, that.globalCentroid)
                && Arrays.deepEquals(this.subvectorSizesAndOffsets, that.subvectorSizesAndOffsets)
                && Arrays.deepEquals(this.codebooks, that.codebooks)
                && this.anisotropicThreshold == that.anisotropicThreshold;
    }

    @Override
    public int hashCode() {
        // anisotropicThreshold is deliberately omitted: equals compares it with ==, under which
        // 0.0f == -0.0f although Float.hashCode differs for them.
        int result = Objects.hash(this.M, this.originalDimension);
        result = 31 * result + Arrays.deepHashCode(this.codebooks);
        result = 31 * result + Objects.hashCode(this.globalCentroid);
        result = 31 * result + Arrays.deepHashCode(this.subvectorSizesAndOffsets);
        return result;
    }

    /**
     * Returns the global centroid if there is one. Otherwise it returns, per coordinate, the
     * unweighted mean of all {@code clusterCount} centroids of the owning subspace.
     */
    public VectorFloat<?> getOrComputeCentroid() {
        if (this.globalCentroid != null) {
            return this.globalCentroid;
        }
        VectorFloat<?> centroid = vectorTypeSupport.createFloatVector(this.originalDimension);
        for (int m = 0; m < this.M; m++) {
            int subspaceSize = this.subvectorSizesAndOffsets[m][0];
            int subspaceOffset = this.subvectorSizesAndOffsets[m][1];
            VectorFloat<?> codebook = this.codebooks[m];
            for (int i = 0; i < this.clusterCount; i++) {
                for (int j = 0; j < subspaceSize; j++) {
                    int k = subspaceOffset + j;
                    centroid.set(k, centroid.get(k) + codebook.get(i * subspaceSize + j));
                }
            }
        }
        // every coordinate accumulated clusterCount centroid values, so average over clusterCount
        VectorUtil.scale(centroid, 1.0f / (float) this.clusterCount);
        return centroid;
    }

    @Override
    public int compressedVectorSize() {
        return this.codebooks.length;
    }

    /** Counts only the codebook storage. */
    @Override
    public long ramBytesUsed() {
        long size = 0L;
        for (VectorFloat<?> codebook : this.codebooks) {
            size += codebook.ramBytesUsed();
        }
        return size;
    }

    @Override
    public String toString() {
        if (this.anisotropicThreshold == UNWEIGHTED) {
            return String.format("ProductQuantization(M=%d, clusters=%d, centered=%s)", this.M, this.clusterCount, this.globalCentroid != null);
        }
        return String.format("ProductQuantization(M=%d, clusters=%d, centered=%s, anisotropicT=%.3f, eta=%.1f)", this.M, this.clusterCount, this.globalCentroid != null, Float.valueOf(this.anisotropicThreshold), Float.valueOf(KMeansPlusPlusClusterer.computeParallelCostMultiplier(this.anisotropicThreshold, this.originalDimension)));
    }

    /**
     * Validates that {@code clusterCount} is in {@code [1, 256]} and warns when it is below 256.
     *
     * @throws IllegalArgumentException if {@code clusterCount} is out of range
     */
    private static void checkClusterCount(int clusterCount) {
        if (clusterCount > MAX_CLUSTERS) {
            throw new IllegalArgumentException("Too many PQ clusters: " + clusterCount + " > 256");
        }
        if (clusterCount < 1) {
            throw new IllegalArgumentException("PQ cluster count must be positive: " + clusterCount);
        }
        if (clusterCount < MAX_CLUSTERS) {
            LOG.warning("Using less than 256 PQ clusters will not reduce the memory footprint.");
        }
    }

    public int getOriginalDimension() {
        return this.originalDimension;
    }

    /**
     * Encodes {@code vector} and returns the squared L2 distance between the centered vector and its
     * reconstruction, divided by the vector's dimension.
     */
    @Override
    public double reconstructionError(VectorFloat<?> vector) {
        VectorFloat<?> centered = center(vector);
        ByteSequence<?> code = vectorTypeSupport.createByteSequence(this.M);
        encodeCentered(centered, code);
        float sum = 0.0f;
        for (int m = 0; m < this.M; m++) {
            int centroidIndex = Byte.toUnsignedInt(code.get(m));
            int centroidLength = this.subvectorSizesAndOffsets[m][0];
            int centroidOffset = this.subvectorSizesAndOffsets[m][1];
            sum += VectorUtil.squareL2Distance(this.codebooks[m], centroidIndex * centroidLength, centered, centroidOffset, centroidLength);
        }
        return sum / (float) centered.length();
    }

    /** Residual statistics of a subvector against one centroid. */
    private static final class Residual {
        final float residualNormSquared;
        final float parallelResidualComponent;

        Residual(float residualNormSquared, float parallelResidualComponent) {
            this.residualNormSquared = residualNormSquared;
            this.parallelResidualComponent = parallelResidualComponent;
        }
    }

    /** Outcome of one coordinate-descent step: chosen centroid and resulting parallel-residual sum. */
    private static final class CoordinateDescentResult {
        final int newCenterIdx;
        final float newParallelResidualComponent;

        CoordinateDescentResult(int newCenterIdx, float newParallelResidualComponent) {
            this.newCenterIdx = newCenterIdx;
            this.newParallelResidualComponent = newParallelResidualComponent;
        }
    }
}

