/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.util.MathUtil;
import io.github.jbellis.jvector.vector.Matrix;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
 * k-means clustering seeded with k-means++, supporting two kinds of refinement iterations:
 * <ul>
 *   <li>unweighted iterations, which assign points by squared L2 distance and move centroids to
 *       the mean of their members;</li>
 *   <li>anisotropic iterations, which weight the component of each point's residual that is
 *       parallel to the point by a multiplier derived from {@code anisotropicThreshold}.</li>
 * </ul>
 *
 * <p>All k centroids are stored back to back in a single vector; centroid {@code c} occupies
 * components {@code [c * dim, (c + 1) * dim)}. Instances are mutable and not synchronized.
 */
public class KMeansPlusPlusClusterer {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Sentinel threshold meaning "no anisotropic weighting"; anisotropic iterations reject it. */
    public static final float UNWEIGHTED = -1.0f;

    /** {@link #cluster} stops a phase early once at most this fraction of points change cluster. */
    private static final double CONVERGENCE_FRACTION = 0.01;

    private final int k;
    private final VectorFloat<?>[] points;
    /** {@code assignments[i]} is the cluster index currently holding {@code points[i]}. */
    private final int[] assignments;
    /** Packed centroids, {@code k * dim} components. */
    private final VectorFloat<?> centroids;
    private final float anisotropicThreshold;
    /** Number of points currently assigned to each cluster. */
    private final int[] centroidDenoms;
    /** Component-wise sum of the points currently assigned to each cluster. */
    private final VectorFloat<?>[] centroidNums;

    /**
     * Creates an unweighted clusterer with {@code k} centroids chosen by k-means++ seeding.
     *
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > points.length}
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, int k) {
        this(points, chooseInitialCentroids(points, k), UNWEIGHTED);
    }

    /**
     * Creates a clusterer with {@code k} centroids chosen by k-means++ seeding.
     *
     * @throws IllegalArgumentException if {@code k} is out of range or the threshold is invalid
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, int k, float anisotropicThreshold) {
        this(points, chooseInitialCentroids(points, k), anisotropicThreshold);
    }

    /**
     * Creates a clusterer from explicit initial centroids. The centroids are copied, and the
     * number of clusters is {@code centroids.length() / dim}. Every point is immediately assigned
     * to its nearest centroid by squared L2 distance.
     *
     * @param points               the points to cluster; all are assumed to share {@code points[0]}'s dimension
     * @param centroids            packed initial centroids
     * @param anisotropicThreshold must satisfy {@code -1.0 <= t < 1.0}; {@link #UNWEIGHTED} disables anisotropy
     * @throws IllegalArgumentException if the threshold is out of range, {@code points} is empty,
     *                                  or the centroids length is not a positive multiple of the dimension
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, VectorFloat<?> centroids, float anisotropicThreshold) {
        if (Float.isNaN(anisotropicThreshold) || anisotropicThreshold < -1.0f || anisotropicThreshold >= 1.0f) {
            throw new IllegalArgumentException("Valid range for anisotropic threshold T is -1.0 <= t < 1.0");
        }
        if (points.length == 0) {
            throw new IllegalArgumentException("Cannot cluster an empty set of points");
        }
        int dimensions = points[0].length();
        if (dimensions == 0 || centroids.length() == 0 || centroids.length() % dimensions != 0) {
            throw new IllegalArgumentException(String.format(
                    "Centroids length %d must be a positive multiple of the point dimension %d",
                    centroids.length(), dimensions));
        }
        this.points = points;
        this.k = centroids.length() / dimensions;
        this.centroids = centroids.copy();
        this.anisotropicThreshold = anisotropicThreshold;
        this.centroidDenoms = new int[k];
        this.centroidNums = new VectorFloat<?>[k];
        for (int c = 0; c < k; c++) {
            centroidNums[c] = vectorTypeSupport.createFloatVector(dimensions);
        }
        this.assignments = new int[points.length];
        initializeAssignedPoints();
    }

    /**
     * Ratio of parallel to perpendicular cost for the given threshold {@code t}:
     * {@code t^2 / ((1 - t^2) / (dimensions - 1))}, clamped below at 1.
     */
    static float computeParallelCostMultiplier(double threshold, int dimensions) {
        assert (Double.isFinite(threshold)) : "threshold=" + threshold;
        double parallelCost = threshold * threshold;
        double perpendicularCost = (1.0 - parallelCost) / (double) (dimensions - 1);
        return (float) Math.max(1.0, parallelCost / perpendicularCost);
    }

    /**
     * Runs up to {@code unweightedIterations} unweighted iterations, then up to
     * {@code anisotropicIterations} anisotropic iterations. Each phase stops early once an
     * iteration reassigns at most 1% of the points.
     *
     * @return the internal centroid vector (not a copy)
     * @throws IllegalStateException if {@code anisotropicIterations > 0} and the threshold is {@link #UNWEIGHTED}
     */
    public VectorFloat<?> cluster(int unweightedIterations, int anisotropicIterations) {
        for (int i = 0; i < unweightedIterations; i++) {
            if (hasConverged(clusterOnceUnweighted())) {
                break;
            }
        }
        for (int i = 0; i < anisotropicIterations; i++) {
            if (hasConverged(clusterOnceAnisotropic())) {
                break;
            }
        }
        return centroids;
    }

    /** True when few enough points moved that further iterations are not worthwhile. */
    private boolean hasConverged(int changedCount) {
        return changedCount <= CONVERGENCE_FRACTION * points.length;
    }

    /**
     * One unweighted iteration: recompute centroids as member means, then reassign every point to
     * its nearest centroid by squared L2 distance.
     *
     * @return the number of points whose assignment changed
     */
    public int clusterOnceUnweighted() {
        updateCentroidsUnweighted();
        return updateAssignedPointsUnweighted();
    }

    /**
     * One anisotropic iteration: recompute centroids under the anisotropic loss, then reassign
     * every point to the centroid minimizing that loss.
     *
     * @return the number of points whose assignment changed
     * @throws IllegalStateException if this clusterer was built with {@link #UNWEIGHTED}
     */
    public int clusterOnceAnisotropic() {
        if (anisotropicThreshold == UNWEIGHTED) {
            // threshold -1 gives an infinite parallel cost multiplier, producing NaN/inf distances
            // and a singular centroid system; refuse rather than return garbage.
            throw new IllegalStateException("Anisotropic clustering requires a threshold other than UNWEIGHTED");
        }
        updateCentroidsAnisotropic();
        return updateAssignedPointsAnisotropic();
    }

    /** k-means++ seeding: first centroid uniform, each next one sampled proportional to squared distance. */
    private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        if (k > points.length) {
            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
        }
        ThreadLocalRandom random = ThreadLocalRandom.current();
        int dimensions = points[0].length();
        VectorFloat<?> centroids = vectorTypeSupport.createFloatVector(k * dimensions);

        // distances[j]: squared L2 distance from points[j] to its nearest centroid chosen so far
        float[] distances = new float[points.length];
        Arrays.fill(distances, Float.MAX_VALUE);

        VectorFloat<?> firstCentroid = points[random.nextInt(points.length)];
        centroids.copyFrom(firstCentroid, 0, 0, dimensions);
        updateMinDistances(points, firstCentroid, distances);

        for (int c = 1; c < k; c++) {
            VectorFloat<?> nextCentroid = points[sampleProportionalToDistance(distances, random)];
            centroids.copyFrom(nextCentroid, 0, c * dimensions, dimensions);
            updateMinDistances(points, nextCentroid, distances);
        }
        assertFinite(centroids);
        return centroids;
    }

    /** Lowers each {@code distances[j]} to the squared distance between {@code points[j]} and {@code centroid} if smaller. */
    private static void updateMinDistances(VectorFloat<?>[] points, VectorFloat<?> centroid, float[] distances) {
        for (int j = 0; j < points.length; j++) {
            distances[j] = Math.min(distances[j], VectorUtil.squareL2Distance(points[j], centroid));
        }
    }

    /**
     * Picks an index with probability proportional to {@code distances[j]}. Only strictly positive
     * entries are eligible, so a point coinciding with a chosen centroid is not picked while any
     * other candidate exists. Falls back to a uniform pick when no entry is positive.
     */
    private static int sampleProportionalToDistance(float[] distances, ThreadLocalRandom random) {
        // accumulate in double: a float running sum over many points loses precision
        double total = 0.0;
        for (float d : distances) {
            if (d > 0.0f) {
                total += d;
            }
        }
        double r = random.nextDouble() * total;
        int lastPositive = -1;
        for (int j = 0; j < distances.length; j++) {
            if (!(distances[j] > 0.0f)) {
                continue;
            }
            lastPositive = j;
            r -= distances[j];
            if (r <= 0.0) {
                return j;
            }
        }
        // r can stay marginally positive through rounding; the last eligible index absorbs that.
        return lastPositive >= 0 ? lastPositive : random.nextInt(distances.length);
    }

    /** Assigns every point to its nearest centroid and seeds the per-cluster counts and sums. */
    private void initializeAssignedPoints() {
        for (int i = 0; i < points.length; i++) {
            VectorFloat<?> point = points[i];
            int cluster = getNearestCluster(point);
            centroidDenoms[cluster]++;
            VectorUtil.addInPlace(centroidNums[cluster], point);
            assignments[i] = cluster;
        }
    }

    /** Moves {@code points[pointIndex]} from cluster {@code from} to {@code to}, keeping counts and sums in sync. */
    private void reassign(int pointIndex, int from, int to) {
        VectorFloat<?> point = points[pointIndex];
        centroidDenoms[from]--;
        VectorUtil.subInPlace(centroidNums[from], point);
        centroidDenoms[to]++;
        VectorUtil.addInPlace(centroidNums[to], point);
        assignments[pointIndex] = to;
    }

    /** Reassigns each point to its nearest centroid by squared L2 distance; returns how many moved. */
    private int updateAssignedPointsUnweighted() {
        int changedCount = 0;
        for (int i = 0; i < points.length; i++) {
            int oldAssignment = assignments[i];
            int newAssignment = getNearestCluster(points[i]);
            if (newAssignment != oldAssignment) {
                reassign(i, oldAssignment, newAssignment);
                changedCount++;
            }
        }
        return changedCount;
    }

    /**
     * Reassigns each point to the centroid with the smallest anisotropic loss; returns how many moved.
     * Counts and sums are kept in sync so a later unweighted iteration sees correct member means.
     */
    private int updateAssignedPointsAnisotropic() {
        int dimensions = points[0].length();
        float pcm = computeParallelCostMultiplier(anisotropicThreshold, dimensions);
        float[] cNormSquared = new float[k];
        for (int c = 0; c < k; c++) {
            int offset = c * dimensions;
            cNormSquared[c] = VectorUtil.dotProduct(centroids, offset, centroids, offset, dimensions);
        }

        int changedCount = 0;
        for (int i = 0; i < points.length; i++) {
            VectorFloat<?> x = points[i];
            float xNormSquared = VectorUtil.dotProduct(x, x);
            int oldAssignment = assignments[i];
            int newAssignment = oldAssignment;
            float minDist = Float.MAX_VALUE;
            for (int c = 0; c < k; c++) {
                float dist = weightedDistance(x, c, pcm, cNormSquared[c], xNormSquared);
                if (dist < minDist) {
                    minDist = dist;
                    newAssignment = c;
                }
            }
            if (newAssignment != oldAssignment) {
                reassign(i, oldAssignment, newAssignment);
                changedCount++;
            }
        }
        return changedCount;
    }

    /**
     * Anisotropic loss of quantizing {@code x} to centroid {@code centroid}. The squared residual
     * {@code r = x - c} is split into its part parallel to {@code x} and the perpendicular
     * remainder, and the parallel part is weighted by {@code parallelCostMultiplier}.
     *
     * <p>The parallel part is {@code (r.x)^2 / |x|^2}. Dividing by {@code |x|^2} makes it a true
     * squared projection length for non-unit vectors; without it the perpendicular remainder can
     * go negative. It also matches the normalized projections used when updating centroids. A
     * zero {@code x} has no direction, so its whole residual counts as perpendicular.
     */
    private float weightedDistance(VectorFloat<?> x, int centroid, float parallelCostMultiplier, float cNormSquared, float xNormSquared) {
        float cDotX = VectorUtil.dotProduct(centroids, centroid * x.length(), x, 0, x.length());
        float residualSquaredNorm = cNormSquared - 2.0f * cDotX + xNormSquared;
        if (!(xNormSquared > 0.0f)) {
            return residualSquaredNorm;
        }
        // r.x = |x|^2 - c.x ; its sign is irrelevant once squared
        float parallelErrorSubtotal = cDotX - xNormSquared;
        float parallelError = MathUtil.square(parallelErrorSubtotal) / xNormSquared;
        float perpendicularError = residualSquaredNorm - parallelError;
        return parallelCostMultiplier * parallelError + perpendicularError;
    }

    /** Index of the centroid with the smallest squared L2 distance to {@code point} (0 on ties/NaN). */
    private int getNearestCluster(VectorFloat<?> point) {
        float minDistance = Float.MAX_VALUE;
        int nearestCluster = 0;
        for (int c = 0; c < k; c++) {
            float distance = VectorUtil.squareL2Distance(point, 0, centroids, c * point.length(), point.length());
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = c;
            }
        }
        return nearestCluster;
    }

    /** Asserts every component of {@code vector} is finite; the scan runs only when assertions are enabled. */
    private static void assertFinite(VectorFloat<?> vector) {
        boolean assertsEnabled = false;
        // Intentional side effect: this assignment executes only with -ea, so the O(n) scan below
        // is skipped entirely when assertions are disabled.
        assert assertsEnabled = true;
        if (assertsEnabled) {
            for (int i = 0; i < vector.length(); i++) {
                assert Float.isFinite(vector.get(i)) : "vector " + vector + " contains non-finite value";
            }
        }
    }

    /** Sets each centroid to the mean of its members; an empty cluster is re-seeded at a random point. */
    private void updateCentroidsUnweighted() {
        for (int c = 0; c < k; c++) {
            int count = centroidDenoms[c];
            if (count == 0) {
                initializeCentroidToRandomPoint(c);
                continue;
            }
            VectorFloat<?> mean = centroidNums[c].copy();
            VectorUtil.scale(mean, 1.0f / (float) count);
            centroids.copyFrom(mean, 0, c * mean.length(), mean.length());
        }
    }

    /** Overwrites centroid {@code c} with a uniformly chosen point. */
    private void initializeCentroidToRandomPoint(int c) {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        int dimensions = points[0].length();
        centroids.copyFrom(points[random.nextInt(points.length)], 0, c * dimensions, dimensions);
    }

    /**
     * Sets each centroid to the minimizer of the anisotropic loss over its members. That minimizer
     * solves {@code (h_perp * I + (1 - h_perp) * avg(x x^T / |x|^2)) c = mean}, where
     * {@code h_perp = 1 / pcm}. An empty cluster is re-seeded at a random point.
     */
    private void updateCentroidsAnisotropic() {
        int dimensions = points[0].length();
        float pcm = computeParallelCostMultiplier(anisotropicThreshold, dimensions);
        float orthogonalCostMultiplier = 1.0f / pcm;

        List<List<Integer>> membersByCluster = new ArrayList<>(k);
        for (int c = 0; c < k; c++) {
            membersByCluster.add(new ArrayList<>());
        }
        for (int i = 0; i < assignments.length; i++) {
            membersByCluster.get(assignments[i]).add(i);
        }

        for (int c = 0; c < k; c++) {
            List<Integer> members = membersByCluster.get(c);
            if (members.isEmpty()) {
                initializeCentroidToRandomPoint(c);
                continue;
            }
            VectorFloat<?> mean = vectorTypeSupport.createFloatVector(dimensions);
            Matrix outerProdSums = new Matrix(dimensions, dimensions);
            for (int j : members) {
                VectorFloat<?> point = points[j];
                VectorUtil.addInPlace(mean, point);
                float normSquared = VectorUtil.dotProduct(point, point);
                if (!(normSquared > 0.0f)) {
                    continue; // a zero vector has no direction to project onto
                }
                Matrix op = Matrix.outerProduct(point, point);
                op.scale(1.0f / normSquared);
                outerProdSums.addInPlace(op);
            }
            outerProdSums.scale((1.0f - orthogonalCostMultiplier) / (float) members.size());
            VectorUtil.scale(mean, 1.0f / (float) members.size());
            for (int d = 0; d < dimensions; d++) {
                outerProdSums.addTo(d, d, orthogonalCostMultiplier);
            }
            Matrix invertedMatrix = outerProdSums.invert();
            centroids.copyFrom(invertedMatrix.multiply(mean), 0, c * dimensions, dimensions);
        }
    }

    /**
     * Returns the arithmetic mean of {@code points}.
     *
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public static VectorFloat<?> centroidOf(List<VectorFloat<?>> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Can't compute centroid of empty points list");
        }
        VectorFloat<?> centroid = VectorUtil.sum(points);
        VectorUtil.scale(centroid, 1.0f / (float) points.size());
        return centroid;
    }

    /** Returns the internal packed centroid vector (not a copy); later iterations mutate it. */
    public VectorFloat<?> getCentroids() {
        return centroids;
    }
}

