/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import smile.clustering.PartitionClustering;

/**
 * Base class for partition clustering results in which every cluster is
 * represented by a centroid.
 *
 * <p>Results are ordered by their {@link #distortion}: a clustering with a
 * smaller distortion compares as smaller.
 *
 * @param <T> the type of the centroids.
 * @param <U> the type of the observations passed to {@link #predict(Object)}.
 */
public abstract class CentroidClustering<T, U>
extends PartitionClustering
implements Comparable<CentroidClustering<T, U>> {
    private static final long serialVersionUID = 2L;

    /** The distortion of this clustering; instances are ordered by this value. */
    public final double distortion;

    /** The cluster centroids; the index of a centroid is its cluster label. */
    public final T[] centroids;

    /**
     * Constructor.
     *
     * @param distortion the distortion of the clustering.
     * @param centroids  the cluster centroids; the number of clusters is
     *                   {@code centroids.length}.
     * @param y          the cluster label of each training observation.
     */
    public CentroidClustering(double distortion, T[] centroids, int[] y) {
        super(centroids.length, y);
        this.distortion = distortion;
        this.centroids = centroids;
    }

    /**
     * Compares two clusterings by distortion.
     *
     * @param o the clustering to compare against.
     * @return a negative, zero or positive value as this distortion is less
     *         than, equal to or greater than {@code o.distortion}.
     */
    @Override
    public int compareTo(CentroidClustering<T, U> o) {
        return Double.compare(this.distortion, o.distortion);
    }

    /**
     * Returns the distance between a centroid and an observation.
     *
     * @param centroid a cluster centroid.
     * @param x        an observation.
     * @return the distance used by {@link #predict(Object)} to find the
     *         nearest centroid.
     */
    protected abstract double distance(T centroid, U x);

    /**
     * Classifies an observation by its nearest centroid.
     *
     * <p>Ties go to the lowest cluster index. A centroid whose distance is NaN
     * is never selected. If no centroid is closer than {@link Double#MAX_VALUE},
     * label 0 is returned.
     *
     * @param x the observation.
     * @return the index of the nearest centroid.
     */
    public int predict(U x) {
        double nearest = Double.MAX_VALUE;
        int label = 0;
        for (int i = 0; i < k; i++) {
            double dist = distance(centroids[i], x);
            // Strict '<' keeps the first minimum and rejects NaN distances.
            if (dist < nearest) {
                nearest = dist;
                label = i;
            }
        }
        return label;
    }

    @Override
    public String toString() {
        return String.format("Cluster distortion: %.5f%n", this.distortion) + super.toString();
    }

    /**
     * Assigns every observation to its nearest centroid, in parallel.
     *
     * <p>Each parallel task writes only its own {@code y[i]}, so no
     * synchronization is needed. If no centroid is closer than
     * {@link Double#MAX_VALUE} for an observation (e.g. every distance is NaN),
     * its {@code y[i]} is left unchanged and {@code Double.MAX_VALUE} is added
     * to the result.
     *
     * @param y         output: {@code y[i]} receives the index of the centroid
     *                  nearest to {@code data[i]}.
     * @param data      the observations.
     * @param centroids the current centroids.
     * @param distance  the distance function.
     * @param <T>       the type of observations and centroids.
     * @return the sum over all observations of the distance to the nearest
     *         centroid (the within-cluster sum of squares when
     *         {@code distance} is the squared Euclidean distance).
     */
    static <T> double assign(int[] y, T[] data, T[] centroids, ToDoubleBiFunction<T, T> distance) {
        int k = centroids.length;
        return IntStream.range(0, data.length).parallel().mapToDouble(i -> {
            double nearest = Double.MAX_VALUE;
            for (int j = 0; j < k; j++) {
                double dist = distance.applyAsDouble(data[i], centroids[j]);
                if (dist < nearest) {
                    nearest = dist;
                    y[i] = j;
                }
            }
            return nearest;
        }).sum();
    }

    /**
     * Recomputes each centroid as the mean of the observations labeled with
     * its cluster, and records the cluster sizes.
     *
     * <p>Clusters are processed in parallel. Each task writes only its own
     * centroid row and its own {@code size} slot. A cluster with no members
     * gets a NaN centroid (0/0). {@link #assign} never selects a NaN-distance
     * centroid, provided the distance propagates NaN.
     *
     * @param centroids output: the new centroids, one row per cluster.
     * @param data      the observations.
     * @param y         the cluster label of each observation.
     * @param size      output: the number of observations in each cluster.
     */
    static void updateCentroids(double[][] centroids, double[][] data, int[] y, int[] size) {
        int n = data.length;
        Arrays.fill(size, 0);
        IntStream.range(0, centroids.length).parallel().forEach(cluster -> {
            double[] centroid = centroids[cluster];
            int d = centroid.length;
            Arrays.fill(centroid, 0.0);
            for (int i = 0; i < n; i++) {
                if (y[i] == cluster) {
                    size[cluster]++;
                    double[] xi = data[i];
                    for (int j = 0; j < d; j++) {
                        centroid[j] += xi[j];
                    }
                }
            }
            for (int j = 0; j < d; j++) {
                centroid[j] /= size[cluster];
            }
        });
    }

    /**
     * Recomputes each centroid as the per-attribute mean of the non-NaN values
     * of the observations labeled with its cluster, and records the cluster
     * sizes.
     *
     * <p>Clusters are processed in parallel. Each task writes only its own
     * centroid row, {@code notNaN} row and {@code size} slot. A centroid
     * attribute for which no member has a non-NaN value becomes NaN (0/0).
     *
     * @param centroids output: the new centroids, one row per cluster.
     * @param data      the observations; NaN marks a missing value.
     * @param y         the cluster label of each observation.
     * @param size      output: the number of observations in each cluster.
     * @param notNaN    workspace/output: {@code notNaN[c][j]} receives the
     *                  number of non-NaN values of attribute {@code j} among
     *                  the members of cluster {@code c}.
     */
    static void updateCentroidsWithMissingValues(double[][] centroids, double[][] data, int[] y, int[] size, int[][] notNaN) {
        int n = data.length;
        // Reset the counts, as updateCentroids does. Without this, the sizes
        // would accumulate across successive calls.
        Arrays.fill(size, 0);
        IntStream.range(0, centroids.length).parallel().forEach(cluster -> {
            double[] centroid = centroids[cluster];
            int[] count = notNaN[cluster];
            int d = centroid.length;
            Arrays.fill(centroid, 0.0);
            Arrays.fill(count, 0);
            for (int i = 0; i < n; i++) {
                if (y[i] == cluster) {
                    size[cluster]++;
                    double[] xi = data[i];
                    for (int j = 0; j < d; j++) {
                        if (!Double.isNaN(xi[j])) {
                            centroid[j] += xi[j];
                            count[j]++;
                        }
                    }
                }
            }
            for (int j = 0; j < d; j++) {
                centroid[j] /= count[j];
            }
        });
    }
}

