/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.cluster;

import com.aliasi.cluster.Clusterer;
import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.stats.Statistics;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.SmallSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A k-means clusterer over elements that a {@link FeatureExtractor} maps to
 * sparse real-valued feature vectors.
 *
 * <p>Feature names are interned into dense dimension indexes with a
 * {@link MapSymbolTable}. Centroids are dense {@code double[]} vectors, and the
 * squared Euclidean distance from an element {@code x} to a centroid {@code c}
 * is computed as {@code ||c||^2 + ||x||^2 - 2 c.x}, so each distance evaluation
 * only touches the element's own (non-zero) features.
 *
 * <p>Iteration stops when no element changes cluster, when the relative change
 * in average squared distance to centers drops below
 * {@link #minRelativeImprovement()}, or after {@link #maxEpochs()} epochs.
 *
 * @param <E> the type of element being clustered
 */
public class KMeansClusterer<E>
implements Clusterer<E> {
    // Maps each element to its feature vector.
    final FeatureExtractor<E> mFeatureExtractor;
    // Number of centroids k (clusters may end up empty and are then dropped).
    final int mMaxNumClusters;
    // Upper bound on the number of assignment/update epochs.
    final int mMaxEpochs;
    // true: k-means++ seeding; false: random partition seeding.
    final boolean mKMeansPlusPlus;
    // Convergence threshold on the relative change in average squared distance.
    final double mMinRelativeImprovement;

    /**
     * Package-private convenience constructor using random-partition
     * initialization and a minimum relative improvement of {@code 0.0}.
     */
    KMeansClusterer(FeatureExtractor<E> featureExtractor, int numClusters, int maxEpochs) {
        this(featureExtractor, numClusters, maxEpochs, false, 0.0);
    }

    /**
     * Construct a k-means clusterer.
     *
     * @param featureExtractor maps elements to feature vectors
     * @param numClusters number of centroids; must be positive
     * @param maxEpochs maximum number of epochs; must be non-negative
     * @param kMeansPlusPlus {@code true} for k-means++ seeding, {@code false}
     *        for seeding from a random partition of the elements
     * @param minImprovement stop once the relative change in average squared
     *        distance to centers falls below this; must be non-negative
     * @throws IllegalArgumentException if any numeric argument is out of range
     */
    public KMeansClusterer(FeatureExtractor<E> featureExtractor, int numClusters, int maxEpochs, boolean kMeansPlusPlus, double minImprovement) {
        if (numClusters < 1) {
            String msg = "Number of clusters must be positive. Found numClusters=" + numClusters;
            throw new IllegalArgumentException(msg);
        }
        if (maxEpochs < 0) {
            String msg = "Number of epochs must be non-negative. Found maxEpochs=" + maxEpochs;
            throw new IllegalArgumentException(msg);
        }
        if (minImprovement < 0.0 || Double.isNaN(minImprovement)) {
            String msg = "Minimum improvement must be non-negative. Found minImprovement=" + minImprovement;
            throw new IllegalArgumentException(msg);
        }
        this.mFeatureExtractor = featureExtractor;
        this.mMaxNumClusters = numClusters;
        this.mMaxEpochs = maxEpochs;
        this.mKMeansPlusPlus = kMeansPlusPlus;
        this.mMinRelativeImprovement = minImprovement;
    }

    /** Returns the feature extractor used to vectorize elements. */
    public FeatureExtractor<E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    /** Returns the number of centroids used by {@link #cluster(Set)}. */
    public int numClusters() {
        return this.mMaxNumClusters;
    }

    /** Returns the maximum number of epochs. */
    public int maxEpochs() {
        return this.mMaxEpochs;
    }

    /**
     * Cluster the elements using a fresh {@link Random} and no reporting.
     */
    @Override
    public Set<Set<E>> cluster(Set<? extends E> elementSet) {
        return this.cluster(elementSet, new Random(), null);
    }

    /**
     * Cluster the elements.
     *
     * <p>If there are no more elements than clusters, every element is
     * returned in its own singleton cluster.
     *
     * @param elementSet elements to cluster
     * @param random randomness source for initialization
     * @param reporter progress reporter, or {@code null} for silence
     * @return the clustering; empty clusters are omitted
     */
    public Set<Set<E>> cluster(Set<? extends E> elementSet, Random random, Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        int numElements = elementSet.size();
        int numClusters = this.mMaxNumClusters;
        reporter.report(LogLevel.INFO, "#Elements=" + numElements);
        reporter.report(LogLevel.INFO, "#Clusters=" + numClusters);
        if (numElements <= numClusters) {
            reporter.report(LogLevel.INFO, "Returning trivial clustering due to #elements <= #clusters");
            return this.trivialClustering(elementSet);
        }
        // E erases to Object, so viewing the Object[] as E[] is safe as long
        // as the array never escapes this class.
        @SuppressWarnings("unchecked")
        E[] elements = (E[]) elementSet.toArray();
        reporter.report(LogLevel.DEBUG, "Converting inputs to sparse vectors");
        int[][] featuress = new int[numElements][];
        double[][] valss = new double[numElements][];
        double[] eltSqLengths = new double[numElements];
        MapSymbolTable symTab = this.toVectors(elements, featuress, valss, eltSqLengths);
        int numDims = symTab.numSymbols();
        reporter.report(LogLevel.INFO, "#Dimensions=" + numDims);
        double[][] centroidss = new double[numClusters][numDims];
        int[] closestCenters = new int[numElements];
        // Starts at 0; the first epoch treats every centroid as changed, so
        // these values are overwritten before they are relied upon.
        double[] sqDistToCenters = new double[numElements];
        if (this.mKMeansPlusPlus) {
            reporter.report(LogLevel.INFO, "K-Means++ Initialization");
            this.kmeansPlusPlusInit(featuress, valss, eltSqLengths, closestCenters, centroidss, random);
        } else {
            reporter.report(LogLevel.INFO, "Random Initialization");
            this.randomInit(featuress, valss, closestCenters, centroidss, random);
        }
        return this.kMeansEpochs(elements, eltSqLengths, centroidss, featuress, valss, sqDistToCenters, closestCenters, this.mMaxEpochs, reporter);
    }

    /** Returns the minimum relative improvement used as a convergence threshold. */
    public double minRelativeImprovement() {
        return this.mMinRelativeImprovement;
    }

    /**
     * Re-run k-means starting from an existing clustering, also placing the
     * given unclustered elements.
     *
     * @param initialClustering disjoint initial clusters
     * @param unclusteredElements elements not in any initial cluster
     * @param reporter progress reporter, or {@code null} for silence
     * @return the refined clustering
     * @throws IllegalArgumentException if an element is in two clusters, is
     *         both clustered and unclustered, or if there are elements but no
     *         initial clusters
     */
    public Set<Set<E>> recluster(Set<Set<E>> initialClustering, Set<E> unclusteredElements, Reporter reporter) {
        return this.recluster(initialClustering, unclusteredElements, this.mMaxEpochs, reporter);
    }

    Set<Set<E>> recluster(Set<Set<E>> clustering, int maxEpochs) {
        return this.recluster(clustering, SmallSet.<E>create(), maxEpochs, null);
    }

    private Set<Set<E>> recluster(Set<Set<E>> clustering, Set<E> unclusteredElements, int maxEpochs, Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.report(LogLevel.INFO, "Reclustering");
        int numClusters = clustering.size();
        reporter.report(LogLevel.INFO, "# Clusters=" + numClusters);

        // Validate disjointness of clusters and of clustered vs. unclustered.
        HashSet<E> elementSet = new HashSet<E>();
        for (Set<E> cluster : clustering) {
            for (E e : cluster) {
                if (!elementSet.add(e)) {
                    String msg = "An element must not be in two clusters. Found an element in two clusters. Element=" + e;
                    throw new IllegalArgumentException(msg);
                }
            }
        }
        int numClusteredElements = elementSet.size();
        for (E e : unclusteredElements) {
            if (!elementSet.add(e)) {
                String msg = "An element may not be in a cluster and unclustered. Found unclustered element in a cluster. Element=" + e;
                throw new IllegalArgumentException(msg);
            }
        }
        int numElements = elementSet.size();
        reporter.report(LogLevel.INFO, "# Clustered Elements=" + numClusteredElements);
        reporter.report(LogLevel.INFO, "# Unclustered Elements=" + unclusteredElements.size());
        reporter.report(LogLevel.INFO, "# Elements Total=" + numElements);
        if (numClusters == 0) {
            if (numElements == 0) {
                return new LinkedHashSet<Set<E>>();
            }
            String msg = "Cannot recluster without at least one initial cluster. Found #unclustered elements=" + numElements;
            throw new IllegalArgumentException(msg);
        }

        // Clustered elements first (cluster by cluster), then unclustered ones.
        @SuppressWarnings("unchecked")
        E[] elements = (E[]) new Object[numElements];
        int i = 0;
        for (Set<E> cluster : clustering) {
            for (E e : cluster) {
                elements[i++] = e;
            }
        }
        for (E e : unclusteredElements) {
            elements[i++] = e;
        }

        reporter.report(LogLevel.DEBUG, "Converting to vectors");
        int[][] featuress = new int[numElements][];
        double[][] valss = new double[numElements][];
        double[] eltSqLengths = new double[numElements];
        MapSymbolTable symTab = this.toVectors(elements, featuress, valss, eltSqLengths);
        int numDims = symTab.numSymbols();
        reporter.report(LogLevel.INFO, "#Dimensions=" + numDims);

        // Seed each centroid with the mean of its initial cluster's members.
        double[][] centroidss = new double[numClusters][numDims];
        int[] closestCenters = new int[numElements];
        i = 0;
        int k = 0;
        for (Set<E> cluster : clustering) {
            double[] centroidK = centroidss[k];
            int clusterSize = cluster.size();
            for (int j = 0; j < clusterSize; ++j, ++i) {
                closestCenters[i] = k;
                KMeansClusterer.increment(centroidK, featuress[i], valss[i]);
            }
            if (clusterSize > 0) {
                KMeansClusterer.divide(centroidK, clusterSize);
            }
            ++k;
        }

        // Assign every element (including unclustered ones) to its nearest seed.
        double[] sqDistToCenters = new double[numElements];
        Arrays.fill(sqDistToCenters, Double.POSITIVE_INFINITY);
        for (k = 0; k < numClusters; ++k) {
            double[] centroidK = centroidss[k];
            double centroidSqLength = KMeansClusterer.selfProduct(centroidK);
            for (i = 0; i < numElements; ++i) {
                double sqDist = KMeansClusterer.sqDistance(centroidK, centroidSqLength, featuress[i], valss[i], eltSqLengths[i]);
                if (sqDist < sqDistToCenters[i]) {
                    sqDistToCenters[i] = sqDist;
                    closestCenters[i] = k;
                }
            }
        }
        for (double[] centroid : centroidss) {
            Arrays.fill(centroid, 0.0);
        }
        this.setCentroids(centroidss, featuress, valss, closestCenters);
        return this.kMeansEpochs(elements, eltSqLengths, centroidss, featuress, valss, sqDistToCenters, closestCenters, maxEpochs, reporter);
    }

    /**
     * Run Lloyd-style epochs from the given centroids and assignments.
     *
     * <p>Only centroids that changed in the previous epoch are re-scored
     * against each element. Invariant: every unchanged centroid is at least
     * {@code sqDistToCenters[i]} away from element {@code i}, so unchanged
     * centroids only need checking when the element's best changed-centroid
     * distance got worse than its cached one.
     *
     * @return the clustering, ordered by cluster score (see
     *         {@link #buildClustering})
     */
    private Set<Set<E>> kMeansEpochs(E[] elements, double[] eltSqLengths, double[][] centroidss, int[][] featuress, double[][] valss, double[] sqDistToCenters, int[] closestCenters, int maxEpochs, Reporter reporter) {
        int numClusters = centroidss.length;
        int numElements = elements.length;
        double[] centroidSqLengths = KMeansClusterer.centroidSqLengths(centroidss);
        boolean[] lastCentroidChanges = KMeansClusterer.createBooleanArray(numClusters, true);
        int[] changedClusters = new int[numClusters];
        // Infinite on the first epoch, which makes relativeImprovement NaN,
        // so the threshold test cannot stop the first epoch.
        double lastError = Double.POSITIVE_INFINITY;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            reporter.report(LogLevel.DEBUG, "Epoch=" + epoch);
            int numChangedClusters = KMeansClusterer.setChangedClusters(changedClusters, lastCentroidChanges);
            reporter.report(LogLevel.DEBUG, "    #changed clusters=" + numChangedClusters);
            boolean[] centroidChanges = KMeansClusterer.createBooleanArray(numClusters, false);
            boolean atLeastOneClusterChanged = false;
            for (int i = 0; i < numElements; ++i) {
                int[] featuresI = featuress[i];
                double[] valsI = valss[i];
                double eltSqLengthI = eltSqLengths[i];
                // If this element's own center moved, its cached distance is stale.
                double closestSqDist = lastCentroidChanges[closestCenters[i]]
                    ? Double.POSITIVE_INFINITY
                    : sqDistToCenters[i];
                int bestCenter = -1;
                for (int kk = 0; kk < numChangedClusters; ++kk) {
                    int k = changedClusters[kk];
                    double sqDist = KMeansClusterer.sqDistance(centroidss[k], centroidSqLengths[k], featuresI, valsI, eltSqLengthI);
                    if (sqDist < closestSqDist) {
                        closestSqDist = sqDist;
                        bestCenter = k;
                    }
                }
                if (bestCenter == -1) {
                    continue; // no changed centroid beats the current, still-valid one
                }
                if (closestSqDist > sqDistToCenters[i]) {
                    // Got farther than before: an unchanged centroid may now be closest.
                    for (int kk = numChangedClusters; kk < numClusters; ++kk) {
                        int k = changedClusters[kk];
                        double sqDist = KMeansClusterer.sqDistance(centroidss[k], centroidSqLengths[k], featuresI, valsI, eltSqLengthI);
                        if (sqDist < closestSqDist) {
                            closestSqDist = sqDist;
                            bestCenter = k;
                        }
                    }
                }
                sqDistToCenters[i] = closestSqDist;
                if (bestCenter == closestCenters[i]) {
                    continue;
                }
                atLeastOneClusterChanged = true;
                centroidChanges[bestCenter] = true;
                centroidChanges[closestCenters[i]] = true;
                closestCenters[i] = bestCenter;
            }
            double error = KMeansClusterer.sum(sqDistToCenters) / (double) numElements;
            reporter.report(LogLevel.DEBUG, "    avg dist to center=" + error);
            if (!atLeastOneClusterChanged) {
                reporter.report(LogLevel.INFO, "Converged by no elements changing cluster.");
                break;
            }
            double relImprovement = KMeansClusterer.relativeImprovement(lastError, error);
            if (relImprovement < this.mMinRelativeImprovement) {
                reporter.report(LogLevel.INFO, "Converged by relative improvement < threshold");
                break;
            }
            lastError = error;
            int numChangedElts = KMeansClusterer.recomputeChangedCentroids(centroidss, centroidSqLengths, centroidChanges, featuress, valss, closestCenters);
            reporter.report(LogLevel.DEBUG, "    #changed elts=" + numChangedElts);
            lastCentroidChanges = centroidChanges;
            if (epoch == maxEpochs - 1) {
                reporter.report(LogLevel.INFO, "Reached max epochs. Breaking without convergence.");
            }
        }
        reporter.report(LogLevel.DEBUG, "Constructing Result");
        return this.buildClustering(elements, closestCenters, sqDistToCenters, numClusters);
    }

    /**
     * Recompute the mean (and cached squared length) of every centroid flagged
     * in {@code centroidChanges}. A flagged centroid left with no members is
     * the zero vector, and its cached squared length is set to match.
     *
     * @return the number of elements whose centroid was recomputed
     */
    private static int recomputeChangedCentroids(double[][] centroidss, double[] centroidSqLengths, boolean[] centroidChanges, int[][] featuress, double[][] valss, int[] closestCenters) {
        int numClusters = centroidss.length;
        int[] counts = new int[numClusters];
        for (int k = 0; k < numClusters; ++k) {
            if (centroidChanges[k]) {
                Arrays.fill(centroidss[k], 0.0);
            }
        }
        int numChangedElts = 0;
        for (int i = 0; i < closestCenters.length; ++i) {
            int k = closestCenters[i];
            if (!centroidChanges[k]) {
                continue;
            }
            KMeansClusterer.increment(centroidss[k], featuress[i], valss[i]);
            ++counts[k];
            ++numChangedElts;
        }
        for (int k = 0; k < numClusters; ++k) {
            if (!centroidChanges[k]) {
                continue;
            }
            if (counts[k] > 0) {
                KMeansClusterer.divide(centroidss[k], counts[k]);
            }
            centroidSqLengths[k] = KMeansClusterer.selfProduct(centroidss[k]);
        }
        return numChangedElts;
    }

    /**
     * Build the output clustering. Each element is scored by its negated
     * squared distance to its center, and each cluster by the average of its
     * members' scores; sets are built from
     * {@code ObjectToDoubleMap.keysOrderedByValueList()}. Empty clusters are
     * dropped.
     */
    private Set<Set<E>> buildClustering(E[] elements, int[] closestCenters, double[] sqDistToCenters, int numClusters) {
        ArrayList<ObjectToDoubleMap<E>> scoreMaps = new ArrayList<ObjectToDoubleMap<E>>(numClusters);
        double[] totalScores = new double[numClusters];
        for (int k = 0; k < numClusters; ++k) {
            scoreMaps.add(new ObjectToDoubleMap<E>());
        }
        for (int i = 0; i < elements.length; ++i) {
            int k = closestCenters[i];
            // A score of exactly 0.0 is replaced by -Double.MIN_VALUE;
            // NOTE(review): presumably because ObjectToDoubleMap drops
            // zero-valued entries — confirm.
            double score = sqDistToCenters[i] == 0.0 ? -Double.MIN_VALUE : -sqDistToCenters[i];
            scoreMaps.get(k).set(elements[i], score);
            totalScores[k] -= sqDistToCenters[i];
        }
        ObjectToDoubleMap<Set<E>> clusterScores = new ObjectToDoubleMap<Set<E>>();
        for (int k = 0; k < numClusters; ++k) {
            ObjectToDoubleMap<E> clusterDistances = scoreMaps.get(k);
            if (clusterDistances.isEmpty()) {
                continue;
            }
            LinkedHashSet<E> cluster = new LinkedHashSet<E>(clusterDistances.keysOrderedByValueList());
            double clusterScore = totalScores[k] == 0.0 ? -Double.MIN_VALUE : totalScores[k] / (double) cluster.size();
            clusterScores.set(cluster, clusterScore);
        }
        return new LinkedHashSet<Set<E>>(clusterScores.keysOrderedByValueList());
    }

    /**
     * Symmetric relative difference {@code |2(x - y) / (|x| + |y|)|}. Returns
     * NaN when {@code x} is infinite and {@code y} finite, or both are zero.
     */
    static double relativeImprovement(double x, double y) {
        return Math.abs(2.0 * (x - y) / (Math.abs(x) + Math.abs(y)));
    }

    /**
     * Write the indexes of changed clusters to the front of
     * {@code clusterIndexes} (ascending) and unchanged ones to the back
     * (descending from the end).
     *
     * @return the number of changed clusters
     */
    static int setChangedClusters(int[] clusterIndexes, boolean[] changed) {
        int numChanged = 0;
        int nextUnchanged = clusterIndexes.length - 1;
        for (int i = 0; i < changed.length; ++i) {
            if (changed[i]) {
                clusterIndexes[numChanged++] = i;
            } else {
                clusterIndexes[nextUnchanged--] = i;
            }
        }
        return numChanged;
    }

    /** Returns a new array of the given length with every entry set to {@code fillValue}. */
    static boolean[] createBooleanArray(int length, boolean fillValue) {
        boolean[] result = new boolean[length];
        if (fillValue) {
            Arrays.fill(result, true);
        }
        return result;
    }

    /**
     * Convert elements to parallel sparse vectors (dimension indexes and
     * values), filling {@code featuress}, {@code valss} and the squared vector
     * lengths {@code eltSqLengths}.
     *
     * @return the symbol table mapping feature names to dimensions
     */
    private MapSymbolTable toVectors(E[] elements, int[][] featuress, double[][] valss, double[] eltSqLengths) {
        MapSymbolTable symTab = new MapSymbolTable();
        for (int i = 0; i < elements.length; ++i) {
            Map<String, ? extends Number> featureMap = this.mFeatureExtractor.features(elements[i]);
            int numFeatures = featureMap.size();
            int[] features = new int[numFeatures];
            double[] vals = new double[numFeatures];
            int j = 0;
            for (Map.Entry<String, ? extends Number> entry : featureMap.entrySet()) {
                features[j] = symTab.getOrAddSymbol(entry.getKey());
                vals[j] = entry.getValue().doubleValue();
                ++j;
            }
            featuress[i] = features;
            valss[i] = vals;
            eltSqLengths[i] = KMeansClusterer.selfProduct(vals);
        }
        return symTab;
    }

    /** Returns a clustering with each element in its own singleton cluster. */
    private Set<Set<E>> trivialClustering(Set<? extends E> elementSet) {
        Set<Set<E>> clustering = new HashSet<Set<E>>(3 * elementSet.size() / 2);
        for (E elt : elementSet) {
            clustering.add(SmallSet.<E>create(elt));
        }
        return clustering;
    }

    /**
     * Seed by assigning elements round-robin to clusters in a random order and
     * taking each cluster's mean. Assumes {@code Statistics.permutation}
     * returns a permutation of {@code 0..numElements-1} — TODO confirm.
     */
    private void randomInit(int[][] featuress, double[][] valss, int[] closestCenters, double[][] centroidss, Random random) {
        int numClusters = centroidss.length;
        int numElements = featuress.length;
        int[] permutation = Statistics.permutation(numElements, random);
        for (int i = 0; i < numElements; ++i) {
            closestCenters[permutation[i]] = i % numClusters;
        }
        this.setCentroids(centroidss, featuress, valss, closestCenters);
    }

    /**
     * k-means++ seeding. The first center is a uniformly random element; each
     * later one is sampled with probability proportional to its squared
     * distance to the nearest center chosen so far. Elements are then assigned
     * to their nearest seed and centroids are set to the assignment means.
     */
    private void kmeansPlusPlusInit(int[][] featuress, double[][] valss, double[] eltSqLengths, int[] closestCenters, double[][] centroidss, Random random) {
        int numClusters = centroidss.length;
        int numElements = featuress.length;
        double[] sqDistToCenters = new double[numElements];
        Arrays.fill(sqDistToCenters, Double.POSITIVE_INFINITY);
        for (int k = 0; k < numClusters; ++k) {
            double[] centroidK = centroidss[k];
            int centroidIndex = k == 0 ? random.nextInt(numElements) : KMeansClusterer.sampleNextCenter(sqDistToCenters, random);
            KMeansClusterer.setCentroid(centroidK, featuress[centroidIndex], valss[centroidIndex]);
            // The seed is exactly that element's vector, so its squared length is known.
            double centroidSqLength = eltSqLengths[centroidIndex];
            for (int i = 0; i < numElements; ++i) {
                double sqDist = KMeansClusterer.sqDistance(centroidK, centroidSqLength, featuress[i], valss[i], eltSqLengths[i]);
                if (sqDist < sqDistToCenters[i]) {
                    sqDistToCenters[i] = sqDist;
                    closestCenters[i] = k;
                }
            }
        }
        for (double[] centroid : centroidss) {
            Arrays.fill(centroid, 0.0);
        }
        this.setCentroids(centroidss, featuress, valss, closestCenters);
    }

    /**
     * Add each element into its assigned centroid and divide by the member
     * count. Centroids must be zeroed by the caller. A centroid with no
     * members becomes all-NaN (0/0), which never wins a distance comparison.
     */
    private void setCentroids(double[][] centroidss, int[][] featuress, double[][] valss, int[] closestCenters) {
        int numClusters = centroidss.length;
        int numElements = featuress.length;
        int[] count = new int[numClusters];
        for (int i = 0; i < numElements; ++i) {
            int k = closestCenters[i];
            KMeansClusterer.increment(centroidss[k], featuress[i], valss[i]);
            ++count[k];
        }
        for (int k = 0; k < numClusters; ++k) {
            KMeansClusterer.divide(centroidss[k], count[k]);
        }
    }

    /**
     * Sample an index with probability proportional to its (non-negative)
     * weight. Negative weights, e.g. rounding residue of a zero distance, are
     * treated as 0, and zero-weight indexes are never selected unless all
     * weights are zero, in which case the last index is returned.
     */
    private static int sampleNextCenter(double[] probRatios, Random random) {
        double totalWeight = 0.0;
        for (int i = 0; i < probRatios.length; ++i) {
            totalWeight += Math.max(0.0, probRatios[i]);
        }
        double samplePoint = random.nextDouble() * totalWeight;
        double cumulative = 0.0;
        for (int i = 0; i < probRatios.length; ++i) {
            cumulative += Math.max(0.0, probRatios[i]);
            if (cumulative > samplePoint) {
                return i;
            }
        }
        return probRatios.length - 1;
    }

    /** Returns the squared length of each centroid. */
    private static double[] centroidSqLengths(double[][] centroidss) {
        double[] result = new double[centroidss.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = KMeansClusterer.selfProduct(centroidss[i]);
        }
        return result;
    }

    /**
     * Squared Euclidean distance between a dense centroid and a sparse
     * element, given both squared lengths: {@code ||c||^2 + ||x||^2 - 2 c.x}.
     * May be slightly negative due to floating-point rounding.
     */
    private static double sqDistance(double[] centroid, double centroidSqLength, int[] features, double[] vals, double eltSqLength) {
        return centroidSqLength + eltSqLength - 2.0 * KMeansClusterer.product(centroid, features, vals);
    }

    /** Returns the sum of squares of {@code xs}. */
    private static double selfProduct(double[] xs) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sum += xs[i] * xs[i];
        }
        return sum;
    }

    /** Returns the sum of {@code xs}. */
    private static double sum(double[] xs) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sum += xs[i];
        }
        return sum;
    }

    /** Dot product of a dense centroid with a sparse vector. */
    private static double product(double[] centroid, int[] features, double[] values) {
        double sum = 0.0;
        for (int i = 0; i < features.length; ++i) {
            sum += values[i] * centroid[features[i]];
        }
        return sum;
    }

    /** Divide every entry of {@code xs} by {@code divisor} in place. */
    private static void divide(double[] xs, double divisor) {
        for (int d = 0; d < xs.length; ++d) {
            xs[d] = xs[d] / divisor;
        }
    }

    /** Write a sparse vector's values into a dense centroid (other entries untouched). */
    private static void setCentroid(double[] centroid, int[] indexes, double[] values) {
        for (int i = 0; i < indexes.length; ++i) {
            centroid[indexes[i]] = values[i];
        }
    }

    /** Add a sparse vector into a dense centroid in place. */
    private static void increment(double[] centroid, int[] indexes, double[] values) {
        for (int i = 0; i < indexes.length; ++i) {
            centroid[indexes[i]] += values[i];
        }
    }
}

