/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import smile.math.MathEx;

/**
 * Balanced box-decomposition (BBD) tree over a fixed data set, used to speed up one
 * k-means style iteration (assign points to centroids, then recompute centroids as means).
 *
 * <p>Each node stores three things:
 *
 * <ul>
 *   <li>the axis-aligned bounding box of its points, as {@code center} plus or minus {@code radius};
 *   <li>the coordinate-wise sum of its points;
 *   <li>the squared deviation of its points from the node mean.
 * </ul>
 *
 * <p>{@link #clustering} walks the tree. Once every candidate centroid but one has been pruned
 * for a node, it assigns that node's points in bulk instead of one at a time.
 *
 * <p>The tree keeps only a permutation of row indices. It reads {@code data} during
 * construction and does not retain a reference to it.
 */
public class BBDTree {
    /**
     * A node whose widest half-extent is below this value is treated as a single repeated
     * point and becomes a leaf.
     */
    private static final double LEAF_RADIUS = 1.0E-10;

    /** Root of the tree. */
    private final Node root;

    /**
     * Permutation of row indices. Every node owns the contiguous slice
     * {@code index[node.index, node.index + node.size)}.
     */
    private final int[] index;

    /**
     * Builds the tree over the rows of {@code data}.
     *
     * @param data the data points, one row per point. All rows are assumed to have the same
     *     length as {@code data[0]}.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public BBDTree(double[][] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("BBDTree requires at least one data point");
        }
        int n = data.length;
        this.index = new int[n];
        for (int i = 0; i < n; ++i) {
            index[i] = i;
        }
        this.root = buildNode(data, 0, n);
    }

    /**
     * Recursively builds the subtree for the points {@code index[begin, end)}. It reorders that
     * slice of {@code index} so that each child owns a contiguous sub-slice.
     *
     * @param data the data points.
     * @param begin first position (inclusive) in {@code index}; callers guarantee {@code begin < end}.
     * @param end last position (exclusive) in {@code index}.
     * @return the root of the subtree.
     */
    private Node buildNode(double[][] data, int begin, int end) {
        int d = data[0].length;
        Node node = new Node(d);
        node.size = end - begin;
        node.index = begin;

        // Axis-aligned bounding box of the node's points.
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        System.arraycopy(data[index[begin]], 0, lowerBound, 0, d);
        System.arraycopy(data[index[begin]], 0, upperBound, 0, d);
        for (int i = begin + 1; i < end; ++i) {
            double[] x = data[index[i]];
            for (int j = 0; j < d; ++j) {
                if (x[j] < lowerBound[j]) {
                    lowerBound[j] = x[j];
                }
                if (x[j] > upperBound[j]) {
                    upperBound[j] = x[j];
                }
            }
        }

        // Box center and half-extent. The split dimension is the one with the largest extent.
        double maxRadius = -1.0;
        int splitIndex = -1;
        for (int j = 0; j < d; ++j) {
            node.center[j] = midpoint(lowerBound[j], upperBound[j]);
            node.radius[j] = halfWidth(lowerBound[j], upperBound[j]);
            if (node.radius[j] > maxRadius) {
                maxRadius = node.radius[j];
                splitIndex = j;
            }
        }

        if (maxRadius < LEAF_RADIUS) {
            // All points coincide up to LEAF_RADIUS. Approximate the sum as size copies of the
            // first point, and the internal cost as zero. The children stay null.
            System.arraycopy(data[index[begin]], 0, node.sum, 0, d);
            if (node.size > 1) {
                for (int j = 0; j < d; ++j) {
                    node.sum[j] *= node.size;
                }
            }
            node.cost = 0.0;
            return node;
        }

        int lowerSize = partition(data, begin, end, splitIndex, node.center[splitIndex]);
        if (lowerSize == 0 || lowerSize == node.size) {
            // Rounding can put the cutoff exactly on the box boundary. For example, the midpoint
            // of two adjacent doubles rounds to the lower one. That leaves one side empty, and
            // recursing on an empty or unchanged range would read a foreign index or never
            // terminate. So split the slice in half instead. Each child recomputes its own
            // bounding box, so the tree stays valid. node.size >= 2 here, because
            // maxRadius >= LEAF_RADIUS requires two distinct values.
            lowerSize = node.size / 2;
        }

        node.lower = buildNode(data, begin, begin + lowerSize);
        node.upper = buildNode(data, begin + lowerSize, end);

        double[] mean = new double[d];
        for (int j = 0; j < d; ++j) {
            node.sum[j] = node.lower.sum[j] + node.upper.sum[j];
            mean[j] = node.sum[j] / node.size;
        }
        // Each child's scatter about this node's mean. Their total is this node's scatter.
        node.cost = getNodeCost(node.lower, mean) + getNodeCost(node.upper, mean);
        return node;
    }

    /**
     * Reorders {@code index[begin, end)} so that points whose coordinate {@code dim} is below
     * {@code cutoff} come first. This holds for non-NaN coordinates.
     *
     * @return the number of points placed in the leading (lower) part.
     */
    private int partition(double[][] data, int begin, int end, int dim, double cutoff) {
        int i1 = begin;
        int i2 = end - 1;
        while (i1 <= i2) {
            boolean i1Good = data[index[i1]][dim] < cutoff;
            boolean i2Good = data[index[i2]][dim] >= cutoff;
            if (!i1Good && !i2Good) {
                // Both ends are on the wrong side: swap them.
                int temp = index[i1];
                index[i1] = index[i2];
                index[i2] = temp;
                i1Good = true;
                i2Good = true;
            }
            if (i1Good) {
                ++i1;
            }
            if (i2Good) {
                --i2;
            }
        }
        return i1 - begin;
    }

    /**
     * Returns {@code (lo + hi) / 2}. If the direct sum overflows to infinity, it uses
     * {@code lo / 2 + hi / 2} instead.
     */
    private static double midpoint(double lo, double hi) {
        double mid = (lo + hi) / 2.0;
        return Double.isInfinite(mid) ? lo / 2.0 + hi / 2.0 : mid;
    }

    /**
     * Returns {@code (hi - lo) / 2}. If the direct difference overflows to infinity, it uses
     * {@code hi / 2 - lo / 2} instead.
     */
    private static double halfWidth(double lo, double hi) {
        double half = (hi - lo) / 2.0;
        return Double.isInfinite(half) ? hi / 2.0 - lo / 2.0 : half;
    }

    /**
     * Returns the sum of squared distances from the points of {@code node} to {@code center}.
     * It is computed from the node's cached cost, size and sum, without visiting the points:
     * {@code cost + size * |mean - center|^2}.
     */
    private double getNodeCost(Node node, double[] center) {
        int d = center.length;
        double scatter = 0.0;
        for (int i = 0; i < d; ++i) {
            double x = node.sum[i] / node.size - center[i];
            scatter += x * x;
        }
        return node.cost + node.size * scatter;
    }

    /**
     * Assigns every data point to its nearest centroid, up to ties and the leaf tolerance. It
     * then recomputes each centroid as the mean of its assigned points.
     *
     * @param centroids the k current centroids, updated in place. A centroid with no assigned
     *     points is left unchanged.
     * @param sum output: per-cluster coordinate sums of the assigned points (cleared first).
     * @param size output: per-cluster point counts (cleared first).
     * @param y output: the cluster label of each data point, indexed by data row.
     * @return the sum of squared distances from each point to the centroid it was assigned to,
     *     measured against the centroids as they were on entry.
     * @throws IllegalArgumentException if {@code centroids} is empty.
     */
    public double clustering(double[][] centroids, double[][] sum, int[] size, int[] y) {
        int k = centroids.length;
        if (k == 0) {
            throw new IllegalArgumentException("At least one centroid is required");
        }
        Arrays.fill(size, 0);
        int[] candidates = new int[k];
        for (int i = 0; i < k; ++i) {
            candidates[i] = i;
            Arrays.fill(sum[i], 0.0);
        }
        double wcss = filter(root, centroids, candidates, k, sum, size, y);

        int d = centroids[0].length;
        for (int i = 0; i < k; ++i) {
            if (size[i] > 0) {
                for (int j = 0; j < d; ++j) {
                    centroids[i][j] = sum[i][j] / size[i];
                }
            }
        }
        return wcss;
    }

    /**
     * Recursively assigns the points under {@code node} to centroids. It accumulates into
     * {@code sum}, {@code size} and {@code y}.
     *
     * @param candidates the first {@code k} entries are the centroid indices still eligible
     *     for this subtree.
     * @return the sum of squared distances from the subtree's points to their assigned
     *     centroids.
     */
    private double filter(Node node, double[][] centroids, int[] candidates, int k,
                          double[][] sum, int[] size, int[] y) {
        int closest = candidates[0];
        double minDist = MathEx.squaredDistance(node.center, centroids[closest]);
        for (int i = 1; i < k; ++i) {
            double dist = MathEx.squaredDistance(node.center, centroids[candidates[i]]);
            if (dist < minDist) {
                minDist = dist;
                closest = candidates[i];
            }
        }

        if (node.lower != null) {
            // Drop candidates that cannot be nearest to any point in this box.
            // `closest` always survives, because prune() never discards it.
            int[] newCandidates = new int[k];
            int k2 = 0;
            for (int i = 0; i < k; ++i) {
                if (!prune(node.center, node.radius, centroids, closest, candidates[i])) {
                    newCandidates[k2++] = candidates[i];
                }
            }
            if (k2 > 1) {
                return filter(node.lower, centroids, newCandidates, k2, sum, size, y)
                        + filter(node.upper, centroids, newCandidates, k2, sum, size, y);
            }
        }

        // This is either a leaf, or only `closest` survived pruning: assign the whole node to it.
        int d = centroids[0].length;
        for (int j = 0; j < d; ++j) {
            sum[closest][j] += node.sum[j];
        }
        size[closest] += node.size;
        int last = node.index + node.size;
        for (int i = node.index; i < last; ++i) {
            y[index[i]] = closest;
        }
        return getNodeCost(node, centroids[closest]);
    }

    /**
     * Returns whether centroid {@code testIndex} can be discarded for the box
     * {@code center +/- radius}, because it is no closer than centroid {@code bestIndex} to
     * any point of the box.
     *
     * <p>Let {@code u = test - best}. The box corner {@code v} that maximizes {@code v . u} is
     * the point most favorable to {@code test}. Then {@code test} can be pruned iff
     * {@code |v - test|^2 >= |v - best|^2}, which is equivalent to
     * {@code |u|^2 >= 2 (v - best) . u}.
     */
    private boolean prune(double[] center, double[] radius, double[][] centroids, int bestIndex, int testIndex) {
        if (bestIndex == testIndex) {
            return false;
        }
        int d = centroids[0].length;
        double[] best = centroids[bestIndex];
        double[] test = centroids[testIndex];
        double lhs = 0.0;
        double rhs = 0.0;
        for (int i = 0; i < d; ++i) {
            double diff = test[i] - best[i];
            lhs += diff * diff;
            if (diff > 0.0) {
                rhs += (center[i] + radius[i] - best[i]) * diff;
            } else {
                rhs += (center[i] - radius[i] - best[i]) * diff;
            }
        }
        return lhs >= 2.0 * rhs;
    }

    /**
     * A tree node covering the points {@code index[index, index + size)} of the enclosing
     * tree.
     */
    static class Node {
        /** Number of data points in this node. */
        int size;
        /** Position of this node's first point in the tree's index permutation. */
        int index;
        /** Center of the node's axis-aligned bounding box. */
        double[] center;
        /** Half-extent of the bounding box along each dimension. */
        double[] radius;
        /** Coordinate-wise sum of the node's points. For leaves this is size copies of one point. */
        double[] sum;
        /** Sum of squared distances of the node's points to their mean. Zero for leaves. */
        double cost;
        /** Child owning the front part of this node's slice; null for leaves. */
        Node lower;
        /** Child owning the back part of this node's slice; null for leaves. */
        Node upper;

        Node(int d) {
            this.center = new double[d];
            this.radius = new double[d];
            this.sum = new double[d];
        }
    }
}

