/*
 * Decompiled with CFR 0.152.
 */
package com.github.jbellis.jvector.pq;

import com.github.jbellis.jvector.vector.VectorUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;

/**
 * Iterative k-means clustering over {@code float[]} vectors, seeded with a
 * k-means++ style weighted random choice of initial centroids.
 *
 * <p>The constructor picks the initial centroids and performs the first assignment of points
 * to clusters; {@link #cluster(int)} / {@link #clusterOnce()} then refine the centroids.
 *
 * <p>Nearest-centroid search is pruned with the triangle inequality using a table of
 * centroid-to-centroid distances. That pruning is only exact when {@code distanceFunction}
 * is a true metric; with a non-metric function (for example a squared distance) it acts as a
 * heuristic and may occasionally skip the closest centroid.
 *
 * <p>Initial centroids, and centroids re-seeded for empty clusters, are references to rows of
 * the caller's {@code points} array rather than copies.
 */
public class KMeansPlusPlusClusterer {
    /** A pass that reassigns at most this fraction of the points is treated as converged. */
    private static final double CONVERGENCE_FRACTION = 0.01;
    /** Tolerance applied when walking the cumulative seeding weights. */
    private static final double SELECTION_EPSILON = 1.0E-6;

    private final int k;
    private final BiFunction<float[], float[], Float> distanceFunction;
    private final Random random;
    /** Points currently assigned to each cluster, indexed by cluster. */
    private final List<float[]>[] clusterPoints;
    /** Symmetric table of distances between the current centroids. */
    private final float[][] centroidDistances;
    private final float[][] points;
    /** Cluster index currently assigned to each point, indexed like {@link #points}. */
    private final int[] assignments;
    private final float[][] centroids;

    /**
     * Chooses {@code k} initial centroids from {@code points} and assigns every point to its
     * nearest centroid.
     *
     * @param points           vectors to cluster; the array is retained, not copied
     * @param k                number of clusters, {@code 1 <= k <= points.length}
     * @param distanceFunction distance between two vectors; smaller means closer
     * @throws IllegalArgumentException if {@code k} is not positive or exceeds the number of points
     */
    public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction<float[], float[], Float> distanceFunction) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        if (k > points.length) {
            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
        }
        this.points = points;
        this.k = k;
        this.distanceFunction = distanceFunction;
        this.random = new Random();

        // Generic array creation is unavoidable here; only List<float[]> instances are stored.
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<float[]>[] buckets = new List[k];
        for (int i = 0; i < k; i++) {
            buckets[i] = new ArrayList<>();
        }
        this.clusterPoints = buckets;

        this.centroidDistances = new float[k][k];
        this.centroids = chooseInitialCentroids(points);
        // The distance table must describe the current centroids before any assignment pass.
        updateCentroidDistances();
        this.assignments = new int[points.length];
        assignPointsToClusters();
    }

    /**
     * Runs up to {@code maxIterations} refinement passes, stopping early once a pass changes the
     * assignment of at most 1% of the points.
     *
     * @param maxIterations upper bound on the number of passes; non-positive values run none
     * @return the internal centroid array (not a copy), one row per cluster
     */
    public float[][] cluster(int maxIterations) {
        double convergenceThreshold = CONVERGENCE_FRACTION * (double) points.length;
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            int changedCount = clusterOnce();
            if ((double) changedCount <= convergenceThreshold) {
                break;
            }
        }
        return centroids;
    }

    /**
     * Performs one refinement pass: recomputes every centroid from its assigned points, refreshes
     * the centroid-to-centroid distances, and reassigns all points.
     *
     * <p>A cluster with no assigned points gets a randomly chosen input point as its new centroid.
     *
     * @return the number of points whose cluster assignment changed in this pass
     */
    public int clusterOnce() {
        for (int j = 0; j < k; j++) {
            if (clusterPoints[j].isEmpty()) {
                // Re-seed an empty cluster so that it can attract points again.
                centroids[j] = points[random.nextInt(points.length)];
            } else {
                centroids[j] = centroidOf(clusterPoints[j]);
            }
        }
        // Refresh the pruning table BEFORE reassigning: getNearestCluster relies on it, and the
        // centroids have just moved. Assigning first would prune with stale distances.
        updateCentroidDistances();
        return assignPointsToClusters();
    }

    /** Recomputes the symmetric centroid-to-centroid distance table from the current centroids. */
    private void updateCentroidDistances() {
        for (int m = 0; m < k; m++) {
            for (int n = m + 1; n < k; n++) {
                float distance = distanceFunction.apply(centroids[m], centroids[n]);
                centroidDistances[m][n] = distance;
                centroidDistances[n][m] = distance;
            }
        }
    }

    /**
     * Picks {@code k} seed centroids: the first uniformly at random, each subsequent one with
     * probability proportional to its distance (as reported by the distance function) to the
     * nearest seed chosen so far.
     *
     * @param points candidate vectors
     * @return array of {@code k} seeds, each a reference to a row of {@code points}
     */
    private float[][] chooseInitialCentroids(float[][] points) {
        float[][] seeds = new float[k][];
        // distances[i] = distance from points[i] to the closest seed selected so far.
        float[] distances = new float[points.length];
        Arrays.fill(distances, Float.MAX_VALUE);

        seeds[0] = points[random.nextInt(points.length)];
        relaxDistances(points, seeds[0], distances);

        for (int i = 1; i < k; i++) {
            // Accumulate in double: a float running sum loses precision on large inputs.
            double totalDistance = 0.0;
            for (float distance : distances) {
                totalDistance += distance;
            }

            // Walk the cumulative weights until the random threshold is used up.
            double remaining = random.nextDouble() * totalDistance;
            int selectedIdx = -1;
            for (int j = 0; j < distances.length; j++) {
                remaining -= distances[j];
                if (remaining < SELECTION_EPSILON) {
                    selectedIdx = j;
                    break;
                }
            }
            if (selectedIdx == -1) {
                // Rounding left the threshold unreached; fall back to a uniform pick.
                selectedIdx = random.nextInt(points.length);
            }

            seeds[i] = points[selectedIdx];
            relaxDistances(points, seeds[i], distances);
        }
        return seeds;
    }

    /** Lowers each entry of {@code distances} to the distance from that point to {@code seed} if smaller. */
    private void relaxDistances(float[][] points, float[] seed, float[] distances) {
        for (int i = 0; i < points.length; i++) {
            float distance = distanceFunction.apply(points[i], seed);
            distances[i] = Math.min(distances[i], distance);
        }
    }

    /**
     * Rebuilds the per-cluster point lists by assigning every point to its nearest centroid.
     *
     * @return the number of points whose assignment differs from the previous one
     */
    private int assignPointsToClusters() {
        for (List<float[]> cluster : clusterPoints) {
            cluster.clear();
        }

        int changedCount = 0;
        for (int i = 0; i < points.length; i++) {
            float[] point = points[i];
            int clusterIndex = getNearestCluster(point, centroids);
            if (assignments[i] != clusterIndex) {
                changedCount++;
            }
            clusterPoints[clusterIndex].add(point);
            assignments[i] = clusterIndex;
        }
        return changedCount;
    }

    /**
     * Finds the index of the centroid closest to {@code point}; ties keep the lowest index.
     *
     * <p>Candidates are skipped without a distance computation when the triangle-inequality
     * lower bound {@code |d(point, best) - d(best, candidate)|} already rules them out.
     */
    private int getNearestCluster(float[] point, float[][] centroids) {
        float minDistance = Float.MAX_VALUE;
        int nearestCluster = 0;

        for (int i = 0; i < k; i++) {
            if (i != nearestCluster) {
                float lowerBound = Math.abs(minDistance - centroidDistances[nearestCluster][i]);
                if (lowerBound >= minDistance) {
                    // Candidate cannot beat the current best; skip the distance computation.
                    continue;
                }
            }

            float distance = distanceFunction.apply(point, centroids[i]);
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = i;
            }
        }
        return nearestCluster;
    }

    /**
     * Computes the centroid of the given vectors: their sum divided by their count.
     *
     * @param points non-empty list of vectors
     * @return a newly computed centroid vector
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public static float[] centroidOf(List<float[]> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Can't compute centroid of empty points list");
        }
        float[] centroid = VectorUtil.sum(points);
        VectorUtil.divInPlace(centroid, points.size());
        return centroid;
    }
}

