/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.vector.VectorUtil;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Partitions a set of vectors into {@code k} clusters with k-means (Lloyd's iterations), optionally
 * seeding the initial centroids with k-means++.
 *
 * <p>Per-cluster running sums ({@code centroidNums}) and member counts ({@code centroidDenoms}) are
 * maintained incrementally as points move between clusters, so recomputing a centroid is a single
 * copy-and-scale of its running sum rather than a pass over all points.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class KMeansPlusPlusClusterer {
    /** A round that moves at most this fraction of the points ends {@link #cluster(int)} early. */
    private static final double CONVERGENCE_FRACTION = 0.01;

    private final int k;
    private final float[][] points;
    /** {@code assignments[i]} is the index of the cluster {@code points[i]} currently belongs to. */
    private final int[] assignments;
    private final float[][] centroids;
    /** Number of points currently assigned to each cluster. */
    private final int[] centroidDenoms;
    /** Component-wise sum of the points currently assigned to each cluster. */
    private final float[][] centroidNums;

    /**
     * Creates a clusterer whose initial centroids are chosen from {@code points} by k-means++
     * seeding.
     *
     * @param points the vectors to cluster; the outer array is referenced, not copied
     * @param k the number of clusters
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > points.length}
     */
    public KMeansPlusPlusClusterer(float[][] points, int k) {
        this(points, chooseInitialCentroids(points, k));
    }

    /**
     * Creates a clusterer starting from the given centroids and assigns every point to its nearest
     * centroid.
     *
     * @param points the vectors to cluster; the outer array is referenced, not copied. Every vector
     *     is expected to have the same length as {@code points[0]}.
     * @param centroids the initial centroids; each is copied, so later changes by the caller do not
     *     affect this clusterer
     * @throws IllegalArgumentException if {@code points} or {@code centroids} is empty
     */
    public KMeansPlusPlusClusterer(float[][] points, float[][] centroids) {
        if (points.length == 0) {
            throw new IllegalArgumentException("Cannot cluster an empty set of points");
        }
        if (centroids.length == 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        this.points = points;
        this.k = centroids.length;
        this.centroids = Arrays.stream(centroids).map(float[]::clone).toArray(float[][]::new);
        this.centroidDenoms = new int[k];
        this.centroidNums = new float[k][points[0].length];
        this.assignments = new int[points.length];
        initializeAssignedPoints();
    }

    /**
     * Runs up to {@code maxIterations} rounds of {@link #clusterOnce()}, stopping early after a
     * round that moves no more than 1% of the points to a different cluster.
     *
     * @param maxIterations upper bound on the number of rounds; values {@code <= 0} run none
     * @return the internal centroid array (not a copy)
     */
    public float[][] cluster(int maxIterations) {
        for (int i = 0; i < maxIterations; i++) {
            int changedCount = clusterOnce();
            if (changedCount <= CONVERGENCE_FRACTION * points.length) {
                break;
            }
        }
        return centroids;
    }

    /**
     * Performs one Lloyd iteration. It recomputes every centroid from its current members, then
     * reassigns each point to its nearest centroid.
     *
     * @return the number of points whose cluster assignment changed
     */
    public int clusterOnce() {
        updateCentroids();
        return updateAssignedPoints();
    }

    /**
     * Chooses {@code k} initial centroids by k-means++ seeding. The first is picked uniformly at
     * random. Each later one is picked with probability proportional to its
     * {@code VectorUtil.squareDistance} to the nearest centroid chosen so far.
     *
     * <p>The returned arrays are the point arrays themselves, not copies.
     *
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > points.length}
     */
    private static float[][] chooseInitialCentroids(float[][] points, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        if (k > points.length) {
            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
        }
        ThreadLocalRandom random = ThreadLocalRandom.current();
        float[][] centroids = new float[k][];
        // distances[j] = distance from points[j] to its nearest already-chosen centroid
        float[] distances = new float[points.length];
        Arrays.fill(distances, Float.MAX_VALUE);

        centroids[0] = points[random.nextInt(points.length)];
        updateMinDistances(points, centroids[0], distances);

        for (int i = 1; i < k; i++) {
            centroids[i] = points[sampleProportionalToWeight(distances, random)];
            updateMinDistances(points, centroids[i], distances);
        }
        for (float[] centroid : centroids) {
            assertFinite(centroid);
        }
        return centroids;
    }

    /** Lowers each {@code distances[j]} to the distance between {@code points[j]} and {@code centroid} if closer. */
    private static void updateMinDistances(float[][] points, float[] centroid, float[] distances) {
        for (int j = 0; j < points.length; j++) {
            distances[j] = Math.min(distances[j], VectorUtil.squareDistance(points[j], centroid));
        }
    }

    /**
     * Returns an index {@code j} drawn with probability proportional to {@code weights[j]}.
     * Entries that are not strictly positive are never chosen, so a point already coinciding with a
     * centroid is not picked again. The exception is when no entry is positive; then any index is
     * equally good and one is chosen uniformly.
     */
    private static int sampleProportionalToWeight(float[] weights, ThreadLocalRandom random) {
        // Accumulate in double: a float sum of many (or Float.MAX_VALUE) weights overflows or drifts.
        double total = 0.0;
        for (float w : weights) {
            total += w;
        }
        double r = random.nextDouble() * total;
        int lastPositive = -1;
        for (int j = 0; j < weights.length; j++) {
            if (!(weights[j] > 0)) {
                continue; // zero weight (already a centroid, or a duplicate of one) or NaN
            }
            lastPositive = j;
            r -= weights[j];
            if (r <= 0) {
                return j;
            }
        }
        if (lastPositive >= 0) {
            // Rounding left a tiny positive remainder; the last positive-weight entry owns it.
            return lastPositive;
        }
        return random.nextInt(weights.length);
    }

    /** Assigns every point to its nearest centroid and seeds the per-cluster sums and counts. */
    private void initializeAssignedPoints() {
        for (int i = 0; i < points.length; i++) {
            float[] point = points[i];
            int cluster = getNearestCluster(point);
            centroidDenoms[cluster]++;
            VectorUtil.addInPlace(centroidNums[cluster], point);
            assignments[i] = cluster;
        }
    }

    /**
     * Reassigns each point to its nearest centroid. For every point that moves, the running sums
     * and counts of both the old and new cluster are updated.
     *
     * @return the number of points whose assignment changed
     */
    private int updateAssignedPoints() {
        int changedCount = 0;
        for (int i = 0; i < points.length; i++) {
            float[] point = points[i];
            int oldAssignment = assignments[i];
            int newAssignment = getNearestCluster(point);
            if (newAssignment == oldAssignment) {
                continue;
            }
            centroidDenoms[oldAssignment]--;
            VectorUtil.subInPlace(centroidNums[oldAssignment], point);
            centroidDenoms[newAssignment]++;
            VectorUtil.addInPlace(centroidNums[newAssignment], point);
            assignments[i] = newAssignment;
            changedCount++;
        }
        return changedCount;
    }

    /**
     * Returns the index of the centroid with the smallest {@code VectorUtil.squareDistance} to
     * {@code point}. Ties go to the lower index. Returns 0 if no distance is below
     * {@code Float.MAX_VALUE} (for example, when every distance is NaN).
     */
    private int getNearestCluster(float[] point) {
        float minDistance = Float.MAX_VALUE;
        int nearestCluster = 0;
        for (int i = 0; i < k; i++) {
            float distance = VectorUtil.squareDistance(point, centroids[i]);
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = i;
            }
        }
        return nearestCluster;
    }

    /** When assertions are enabled ({@code -ea}), asserts that every component of {@code vector} is finite. */
    private static void assertFinite(float[] vector) {
        boolean assertsEnabled = false;
        // Intentional side effect: this assignment only executes when assertions are enabled,
        // letting us skip the scan entirely otherwise.
        assert assertsEnabled = true;
        if (assertsEnabled) {
            for (float v : vector) {
                assert Float.isFinite(v) : "vector " + Arrays.toString(vector) + " contains non-finite value";
            }
        }
    }

    /**
     * Recomputes each centroid as the mean of its assigned points. A cluster with no members is
     * re-seeded with a copy of a uniformly random point.
     */
    private void updateCentroids() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        for (int i = 0; i < k; i++) {
            int denom = centroidDenoms[i];
            if (denom == 0) {
                // Copy rather than alias. The centroid arrays are exposed via cluster()/getCentroids(),
                // so aliasing would let callers modify the input points through them.
                centroids[i] = points[random.nextInt(points.length)].clone();
                continue;
            }
            centroids[i] = Arrays.copyOf(centroidNums[i], centroidNums[i].length);
            VectorUtil.scale(centroids[i], 1.0f / denom);
        }
    }

    /**
     * Returns the component-wise mean of {@code points} as a newly computed vector.
     *
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public static float[] centroidOf(List<float[]> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Can't compute centroid of empty points list");
        }
        float[] centroid = VectorUtil.sum(points);
        VectorUtil.scale(centroid, 1.0f / (float) points.size());
        return centroid;
    }

    /** Returns the internal centroid array (not a copy); its rows are replaced as clustering proceeds. */
    public float[][] getCentroids() {
        return centroids;
    }
}

