/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.ClusteringDistance;
import smile.clustering.PartitionClustering;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/**
 * K-Means partition clustering of {@code double[]} samples.
 *
 * <p>Two fitting strategies are offered:
 * <ul>
 *   <li>the constructors, which drive the iteration through a {@link BBDTree}
 *       built over the data, and</li>
 *   <li>the static {@code lloyd(...)} factories, which run a plain
 *       assign/update loop and skip {@code NaN} coordinates when accumulating
 *       centroid sums.</li>
 * </ul>
 *
 * <p>The fields {@code k}, {@code y} and {@code size} used below are not declared
 * in this file; they are inherited from {@link PartitionClustering}.
 *
 * <p>Note that {@code Math} in this file resolves to {@code smile.math.Math}
 * (explicit import), not {@code java.lang.Math}.
 */
public class KMeans
extends PartitionClustering<double[]>
implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /** Minimum sample count before {@link #lloyd(double[][], int, int)} tries to parallelise assignment. */
    private static final int PARALLEL_THRESHOLD = 1000;
    /** Minimum number of samples handed to a single parallel assignment task. */
    private static final int MIN_CHUNK = 100;

    /** Sum of squared distances of the final fit, as reported by the fitting loop. */
    double distortion;
    /** Cluster centres, one row per cluster. */
    double[][] centroids;

    /** Creates an empty instance whose fields are filled in by the caller. */
    KMeans() {
    }

    /**
     * Returns the distortion recorded for this clustering.
     *
     * @return the stored distortion value
     */
    public double distortion() {
        return this.distortion;
    }

    /**
     * Returns the centroid matrix. The internal array is returned directly, not a copy.
     *
     * @return the centroids, one row per cluster
     */
    public double[][] centroids() {
        return this.centroids;
    }

    /**
     * Assigns a sample to the cluster whose centroid has the smallest squared distance.
     * Ties keep the lowest cluster index because only a strictly smaller distance replaces
     * the current best.
     *
     * @param x the sample to classify
     * @return the index of the nearest centroid (0 if no centroid compares closer than
     *         {@code Double.MAX_VALUE})
     */
    @Override
    public int predict(double[] x) {
        double minDist = Double.MAX_VALUE;
        int bestCluster = 0;
        for (int i = 0; i < this.k; ++i) {
            double dist = Math.squaredDistance(x, this.centroids[i]);
            if (dist < minDist) {
                minDist = dist;
                bestCluster = i;
            }
        }
        return bestCluster;
    }

    /**
     * Clusters {@code data} into {@code k} groups with at most 100 iterations.
     *
     * @param data the samples, one row per sample
     * @param k the number of clusters, at least 2
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public KMeans(double[][] data, int k) {
        this(data, k, 100);
    }

    /**
     * Clusters {@code data} into {@code k} groups using a freshly built {@link BBDTree}.
     *
     * @param data the samples, one row per sample
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations, positive
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}
     */
    public KMeans(double[][] data, int k, int maxIter) {
        this(new BBDTree(data), data, k, maxIter);
    }

    /**
     * Clusters {@code data} using an existing {@link BBDTree}.
     *
     * <p>Initial labels come from {@code seed(data, k, ClusteringDistance.EUCLIDEAN)};
     * initial centroids are the per-cluster means of those labels. Each iteration then
     * delegates to {@code bbd.clustering(...)} (defined outside this file; presumably it
     * refreshes {@code sums}, {@code size} and {@code y} and returns the distortion) and
     * recomputes centroids of non-empty clusters from {@code sums}. Iteration stops when
     * the returned value no longer decreases or after {@code maxIter} rounds.
     *
     * @param bbd the tree built over {@code data}
     * @param data the samples, one row per sample
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations, positive
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}
     */
    KMeans(BBDTree bbd, double[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        int n = data.length;
        int d = data[0].length;
        this.k = k;
        this.distortion = Double.MAX_VALUE;
        this.y = KMeans.seed(data, k, ClusteringDistance.EUCLIDEAN);
        this.size = new int[k];
        this.centroids = new double[k][d];

        // Accumulate per-cluster counts and coordinate sums from the seed labels.
        for (int i = 0; i < n; ++i) {
            int cluster = this.y[i];
            this.size[cluster]++;
            double[] centroid = this.centroids[cluster];
            double[] row = data[i];
            for (int j = 0; j < d; ++j) {
                centroid[j] += row[j];
            }
        }
        // Turn sums into means.
        for (int c = 0; c < k; ++c) {
            double[] centroid = this.centroids[c];
            for (int j = 0; j < d; ++j) {
                centroid[j] /= (double) this.size[c];
            }
        }

        double[][] sums = new double[k][d];
        for (int iter = 1; iter <= maxIter; ++iter) {
            double dist = bbd.clustering(this.centroids, sums, this.size, this.y);
            for (int c = 0; c < k; ++c) {
                // An empty cluster keeps its previous centroid.
                if (this.size[c] > 0) {
                    for (int j = 0; j < d; ++j) {
                        this.centroids[c][j] = sums[c][j] / (double) this.size[c];
                    }
                }
            }
            if (this.distortion <= dist) {
                break;
            }
            this.distortion = dist;
        }
    }

    /**
     * Runs the tree-based clustering {@code runs} times and keeps the result with the
     * smallest distortion.
     *
     * <p>The runs are submitted to {@link MulticoreExecutor}. If that call throws, the
     * failure is logged and {@code runs} sequential {@link #lloyd(double[][], int, int)}
     * fits are used instead.
     *
     * @param data the samples, one row per sample
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations per run, positive
     * @param runs the number of independent runs, positive
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0} or {@code runs <= 0}
     */
    public KMeans(double[][] data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }
        BBDTree bbd = new BBDTree(data);
        List<KMeansThread> tasks = new ArrayList<KMeansThread>();
        for (int i = 0; i < runs; ++i) {
            tasks.add(new KMeansThread(bbd, data, k, maxIter));
        }

        // Start from "no result" and always accept the first candidate, so a real fit is
        // published even when no distortion compares strictly below Double.MAX_VALUE.
        KMeans best = null;
        try {
            for (KMeans candidate : MulticoreExecutor.run(tasks)) {
                if (best == null || candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }
        catch (Exception ex) {
            logger.error("Failed to run K-Means on multi-core", ex);
            for (int i = 0; i < runs; ++i) {
                KMeans candidate = KMeans.lloyd(data, k, maxIter);
                if (best == null || candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }
        this.k = best.k;
        this.distortion = best.distortion;
        this.centroids = best.centroids;
        this.y = best.y;
        this.size = best.size;
    }

    /**
     * Runs {@link #lloyd(double[][], int, int)} with at most 100 iterations.
     *
     * @param data the samples, one row per sample; {@code NaN} marks a missing coordinate
     * @param k the number of clusters, at least 2
     * @return the fitted model
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public static KMeans lloyd(double[][] data, int k) {
        return KMeans.lloyd(data, k, 100);
    }

    /**
     * Fits K-Means with an alternating update/assign loop that ignores {@code NaN}
     * coordinates when averaging.
     *
     * <p>Each iteration recomputes the centroids from the current labels, then reassigns
     * every sample to its nearest centroid and sums the nearest squared distances. The
     * loop ends when that sum stops decreasing or after {@code maxIter} iterations; the
     * centroids are then recomputed once more from the final labels.
     *
     * <p>When there are at least 1000 samples and the executor reports two or more
     * threads, the assignment step is split into index ranges processed by
     * {@link LloydThread} tasks. If the parallel call throws, the error is logged and that
     * iteration is redone sequentially.
     *
     * @param data the samples, one row per sample; {@code NaN} marks a missing coordinate
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations, positive
     * @return the fitted model
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        int n = data.length;
        int d = data[0].length;
        int[][] nd = new int[k][d];
        double distortion = Double.MAX_VALUE;
        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        int[] y = KMeans.seed(data, k, ClusteringDistance.EUCLIDEAN_MISSING_VALUES);

        int np = MulticoreExecutor.getThreadPoolSize();
        List<LloydThread> tasks = null;
        if (n >= PARALLEL_THRESHOLD && np >= 2) {
            tasks = new ArrayList<LloydThread>(np + 1);
            int step = n / np;
            if (step < MIN_CHUNK) {
                step = MIN_CHUNK;
            }
            // Emit at most np - 1 full chunks, but stop as soon as the next chunk would
            // reach the end of the data: because step may have been raised to MIN_CHUNK,
            // np - 1 full chunks can otherwise run past n and index out of bounds.
            int start = 0;
            for (int t = 0; t < np - 1 && start + step < n; ++t) {
                tasks.add(new LloydThread(data, centroids, y, start, start + step));
                start += step;
            }
            // The last task absorbs whatever remains.
            tasks.add(new LloydThread(data, centroids, y, start, n));
        }

        for (int iter = 0; iter < maxIter; ++iter) {
            updateCentroids(data, y, centroids, size, nd);

            // NaN means "assignment step not done yet".
            double wcss = Double.NaN;
            if (tasks != null) {
                try {
                    wcss = 0.0;
                    for (double ss : MulticoreExecutor.run(tasks)) {
                        wcss += ss;
                    }
                }
                catch (Exception ex) {
                    logger.error("Failed to run K-Means on multi-core", ex);
                    wcss = Double.NaN;
                }
            }
            if (Double.isNaN(wcss)) {
                // Sequential assignment: also the fallback after a failed parallel pass.
                wcss = 0.0;
                for (int i = 0; i < n; ++i) {
                    double nearest = Double.MAX_VALUE;
                    for (int c = 0; c < k; ++c) {
                        double dist = KMeans.squaredDistance(data[i], centroids[c]);
                        if (nearest > dist) {
                            y[i] = c;
                            nearest = dist;
                        }
                    }
                    wcss += nearest;
                }
            }
            if (distortion <= wcss) {
                break;
            }
            distortion = wcss;
        }

        // Bring size and centroids in line with the final labels.
        updateCentroids(data, y, centroids, size, nd);

        KMeans kmeans = new KMeans();
        kmeans.k = k;
        kmeans.distortion = distortion;
        kmeans.size = size;
        kmeans.centroids = centroids;
        kmeans.y = y;
        return kmeans;
    }

    /**
     * Recomputes cluster sizes and centroids in place from the labels {@code y}.
     *
     * <p>{@code NaN} coordinates are skipped; {@code nd[c][j]} counts the non-NaN values
     * that contributed to coordinate {@code j} of cluster {@code c}. A coordinate with no
     * contribution ends up as {@code 0.0 / 0}, i.e. {@code NaN}.
     *
     * @param data the samples, one row per sample
     * @param y the current cluster label of each sample
     * @param centroids output: overwritten with the per-cluster means
     * @param size output: overwritten with the number of samples per cluster
     * @param nd output: overwritten with the per-cluster, per-coordinate non-NaN counts
     */
    private static void updateCentroids(double[][] data, int[] y, double[][] centroids, int[] size, int[][] nd) {
        int k = centroids.length;
        Arrays.fill(size, 0);
        for (int c = 0; c < k; ++c) {
            Arrays.fill(centroids[c], 0.0);
            Arrays.fill(nd[c], 0);
        }
        for (int i = 0; i < data.length; ++i) {
            int cluster = y[i];
            size[cluster]++;
            double[] row = data[i];
            double[] centroid = centroids[cluster];
            int[] counts = nd[cluster];
            for (int j = 0; j < centroid.length; ++j) {
                if (!Double.isNaN(row[j])) {
                    centroid[j] += row[j];
                    counts[j]++;
                }
            }
        }
        for (int c = 0; c < k; ++c) {
            double[] centroid = centroids[c];
            for (int j = 0; j < centroid.length; ++j) {
                centroid[j] /= (double) nd[c][j];
            }
        }
    }

    /**
     * Runs {@link #lloyd(double[][], int, int)} {@code runs} times sequentially and returns
     * the fit with the smallest distortion. The first run is always kept as the initial best.
     *
     * @param data the samples, one row per sample; {@code NaN} marks a missing coordinate
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations per run, positive
     * @param runs the number of independent runs, positive
     * @return the best fitted model
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0} or {@code runs <= 0}
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }
        KMeans best = KMeans.lloyd(data, k, maxIter);
        for (int i = 1; i < runs; ++i) {
            KMeans candidate = KMeans.lloyd(data, k, maxIter);
            if (candidate.distortion < best.distortion) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Formats the distortion, the data dimensions, and one line per cluster with its size
     * and its share of the samples as a percentage with one decimal digit.
     *
     * @return a multi-line summary of this clustering
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("K-Means distortion: %.5f%n", this.distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:%n", this.y.length, this.centroids[0].length));
        for (int i = 0; i < this.k; ++i) {
            // Share in tenths of a percent, so r / 10 and r % 10 give "NN.N".
            int r = (int) Math.round(1000.0 * (double) this.size[i] / (double) this.y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", i, this.size[i], r / 10, r % 10));
        }
        return sb.toString();
    }

    /**
     * Assignment task for {@link KMeans#lloyd(double[][], int, int)}: relabels the samples
     * in {@code [start, end)} to their nearest centroid and returns the sum of the nearest
     * squared distances for that range. Despite the name it is a {@link Callable}, not a thread.
     */
    static class LloydThread
    implements Callable<Double> {
        /** First sample index handled by this task (inclusive). */
        final int start;
        /** End of the handled index range (exclusive). */
        final int end;
        final double[][] data;
        /** Number of centroids, taken from {@code centroids.length} at construction. */
        final int k;
        final double[][] centroids;
        /** Label array shared with the caller; this task writes entries in its own range only. */
        int[] y;

        LloydThread(double[][] data, double[][] centroids, int[] y, int start, int end) {
            this.data = data;
            this.k = centroids.length;
            this.y = y;
            this.centroids = centroids;
            this.start = start;
            this.end = end;
        }

        /**
         * Relabels the samples of this task's range.
         *
         * @return the sum of the nearest squared distances over the range
         */
        @Override
        public Double call() {
            double wcss = 0.0;
            for (int i = this.start; i < this.end; ++i) {
                double nearest = Double.MAX_VALUE;
                for (int c = 0; c < this.k; ++c) {
                    double dist = PartitionClustering.squaredDistance(this.data[i], this.centroids[c]);
                    if (nearest > dist) {
                        this.y[i] = c;
                        nearest = dist;
                    }
                }
                wcss += nearest;
            }
            return wcss;
        }
    }

    /**
     * One independent tree-based clustering run, packaged as a {@link Callable} so several
     * runs can be submitted to {@link MulticoreExecutor}.
     */
    static class KMeansThread
    implements Callable<KMeans> {
        final BBDTree bbd;
        final double[][] data;
        final int k;
        final int maxIter;

        KMeansThread(BBDTree bbd, double[][] data, int k, int maxIter) {
            this.bbd = bbd;
            this.data = data;
            this.k = k;
            this.maxIter = maxIter;
        }

        /**
         * Performs the run.
         *
         * @return a new model fitted with the stored tree, data and parameters
         */
        @Override
        public KMeans call() {
            return new KMeans(this.bbd, this.data, this.k, this.maxIter);
        }
    }
}

