/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.distancefunction.DistanceFunction;
import org.nd4j.linalg.distancefunction.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Online k-means clusterer.
 *
 * <p>Feature vectors passed to {@link #update(INDArray)} are buffered until
 * {@code 10 * nbCluster} of them have been collected. The centroids are then
 * seeded from that buffer: the first centroid is drawn uniformly at random and
 * every further centroid is drawn with probability proportional to the squared
 * distance to the nearest centroid chosen so far. Once seeded, each call to
 * {@link #update(INDArray)} moves the nearest centroid towards the new vector by
 * a running-mean step.
 *
 * <p>Instances are not thread-safe; the internal executor is only used to
 * parallelise the distance computation during seeding and is shut down as soon
 * as seeding finishes.
 */
public class KMeansClustering
implements Serializable {
    private static final long serialVersionUID = 338231277453149972L;
    private static final Logger log = LoggerFactory.getLogger(KMeansClustering.class);
    /** Number of buffered feature vectors required per cluster before seeding starts. */
    private static final int INIT_SAMPLES_PER_CLUSTER = 10;
    /** Per-centroid number of vectors folded in by {@link #update(INDArray)}; {@code null} until seeded. */
    private List<Long> counts = null;
    /** One centroid per row; {@code null} until seeding has started. */
    private INDArray centroids;
    /** Feature vectors buffered while waiting for enough samples to seed the centroids. */
    private List<INDArray> initFeatures = new ArrayList<INDArray>();
    private final Class<? extends DistanceFunction> clazz;
    /** Worker pool used only while seeding; created lazily and shut down afterwards. */
    private transient ExecutorService exec;
    /** Number of centroid rows that already hold a real centroid while seeding is in progress. */
    private transient int seededCentroids;
    private final Integer nbCluster;

    /**
     * Creates a clusterer with a custom distance function.
     *
     * @param nbCluster number of clusters, must be at least 1
     * @param clazz distance function class; it is instantiated reflectively through a
     *              public constructor taking a single {@code INDArray}
     * @throws IllegalArgumentException if {@code nbCluster} is {@code null} or smaller
     *                                  than 1, or if {@code clazz} is {@code null}
     */
    public KMeansClustering(Integer nbCluster, Class<? extends DistanceFunction> clazz) {
        if (nbCluster == null || nbCluster < 1) {
            throw new IllegalArgumentException("nbCluster must be at least 1 but was " + nbCluster);
        }
        if (clazz == null) {
            throw new IllegalArgumentException("Distance function class must not be null");
        }
        this.nbCluster = nbCluster;
        this.clazz = clazz;
    }

    /**
     * Creates a clusterer that uses {@link EuclideanDistance}.
     *
     * @param nbCluster number of clusters, must be at least 1
     */
    public KMeansClustering(Integer nbCluster) {
        this(nbCluster, EuclideanDistance.class);
    }

    /**
     * Returns the index of the centroid closest to the given feature vector.
     *
     * @param features the vector to classify
     * @return row index of the nearest centroid
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public Integer classify(INDArray features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        return this.nearestCentroid(features);
    }

    /**
     * Feeds one feature vector into the model.
     *
     * <p>While the model is not ready the vector is only buffered (possibly
     * triggering seeding) and {@code null} is returned. Afterwards the nearest
     * centroid is moved towards the vector by {@code (features - centroid) / count}.
     *
     * @param features the vector to learn from
     * @return index of the centroid that was updated, or {@code null} if the vector
     *         was only buffered for seeding
     */
    public Integer update(INDArray features) {
        if (!this.isReady()) {
            this.initIfPossible(features);
            log.info("Initializing feature vector with length of " + features.length());
            return null;
        }
        Integer nearestCentroid = this.classify(features);
        long newCount = this.counts.get(nearestCentroid) + 1L;
        this.counts.set(nearestCentroid, newCount);
        // getRow is relied upon to return a view, so the in-place add updates the centroid matrix.
        INDArray centroidRow = this.centroids.getRow(nearestCentroid.intValue());
        INDArray update = features.sub(centroidRow).mul((Number)(1.0 / (double)newCount));
        centroidRow.addi(update);
        return nearestCentroid;
    }

    /**
     * Computes the distance from the given vector to every centroid.
     *
     * @param features the vector to measure
     * @return a {@code 1 x nbCluster} array whose i-th entry is the distance to centroid i
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public INDArray distribution(INDArray features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        INDArray distribution = Nd4j.create((int)1, (int)this.nbCluster);
        for (int i = 0; i < this.nbCluster; ++i) {
            INDArray currentCentroid = this.centroids.getRow(i);
            distribution.putScalar(i, this.getDistance(currentCentroid, features));
        }
        return distribution;
    }

    /**
     * Measures the distance between two vectors with the configured distance function.
     *
     * @param m1 vector the distance function is constructed around
     * @param m2 vector the distance function is applied to
     * @return the distance reported by the function
     * @throws RuntimeException if the distance function cannot be instantiated
     */
    private double getDistance(INDArray m1, INDArray m2) {
        DistanceFunction function;
        try {
            function = this.clazz.getConstructor(INDArray.class).newInstance(m1);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return ((Float)function.apply((Object)m2)).floatValue();
    }

    /**
     * @return the centroid matrix (one centroid per row), or {@code null} if seeding
     *         has not started yet
     */
    public INDArray getCentroids() {
        return this.centroids;
    }

    /**
     * Finds the centroid row closest to the given vector, considering every row of
     * the centroid matrix.
     *
     * @param features the vector to look up
     * @return index of the nearest row, or 0 if no row is strictly closer than infinity
     */
    protected Integer nearestCentroid(INDArray features) {
        return this.nearestCentroid(features, this.centroids.rows());
    }

    /**
     * Finds the nearest centroid among the first {@code candidateCount} rows only, so
     * that rows which have not been seeded yet are not mistaken for real centroids.
     */
    private int nearestCentroid(INDArray features, int candidateCount) {
        int nearestIndex = 0;
        double minDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < candidateCount; ++i) {
            INDArray currentCentroid = this.centroids.getRow(i);
            if (currentCentroid == null) {
                continue;
            }
            double currentDistance = this.getDistance(currentCentroid, features);
            if (currentDistance < minDistance) {
                minDistance = currentDistance;
                nearestIndex = i;
            }
        }
        return nearestIndex;
    }

    /**
     * @return {@code true} once both the centroids and their counters exist
     */
    protected boolean isReady() {
        return this.counts != null && this.centroids != null;
    }

    /**
     * Buffers a feature vector and seeds the centroids once
     * {@code 10 * nbCluster} vectors have been collected.
     *
     * @param features the vector to buffer
     */
    protected void initIfPossible(INDArray features) {
        this.initFeatures.add(features);
        log.info("Added feature vector of length " + features.length());
        if (this.initFeatures.size() >= INIT_SAMPLES_PER_CLUSTER * this.nbCluster) {
            this.initCentroids();
        }
    }

    /**
     * Seeds all centroids from the buffered feature vectors and clears the buffer.
     *
     * <p>The first centroid is a uniformly random buffered vector. Each further
     * centroid is sampled from the remaining vectors with probability proportional
     * to the squared distance to the nearest centroid chosen so far.
     *
     * @throws IllegalStateException if fewer than {@code nbCluster} vectors are buffered
     */
    protected void initCentroids() {
        if (this.initFeatures.size() < this.nbCluster) {
            throw new IllegalStateException("Need at least " + this.nbCluster
                    + " feature vectors to seed the centroids but only "
                    + this.initFeatures.size() + " are available");
        }
        List<Long> freshCounts = new ArrayList<Long>(this.nbCluster);
        for (int i = 0; i < this.nbCluster; ++i) {
            freshCounts.add(0L);
        }
        Random random = new Random();
        try {
            INDArray firstCentroid = this.initFeatures.remove(random.nextInt(this.initFeatures.size())).linearView();
            this.centroids = Nd4j.create((int)this.nbCluster, (int)firstCentroid.columns());
            this.centroids.putRow(0, firstCentroid);
            this.seededCentroids = 1;
            log.info("Added initial centroid");
            for (int j = 1; j < this.nbCluster; ++j) {
                INDArray dxs = this.computeDxs();
                int chosen = KMeansClustering.pickWeightedIndex(dxs, this.initFeatures.size(), random);
                this.centroids.putRow(j, this.initFeatures.remove(chosen));
                this.seededCentroids = j + 1;
            }
        }
        finally {
            // The pool is only needed for seeding; release its threads right away.
            this.shutdownExecutor();
        }
        // Publish the counters last so the model only reports ready after seeding succeeded.
        this.counts = freshCounts;
        this.initFeatures.clear();
    }

    /**
     * Samples an index from a cumulative weight vector.
     *
     * @param dxs cumulative weights; entry i is the sum of weights 0..i
     * @param count number of valid entries in {@code dxs}, at least 1
     * @param random source of randomness
     * @return the first index whose cumulative weight exceeds a uniform draw in
     *         {@code [0, total)}, or a uniformly random index if no entry qualifies
     *         (for example when every weight is zero)
     */
    private static int pickWeightedIndex(INDArray dxs, int count, Random random) {
        double total = dxs.getDouble(count - 1);
        double threshold = random.nextDouble() * total;
        for (int i = 0; i < count; ++i) {
            if (dxs.getDouble(i) > threshold) {
                return i;
            }
        }
        return random.nextInt(count);
    }

    /**
     * Computes, for every buffered feature vector, the cumulative sum of squared
     * distances to the nearest already-seeded centroid.
     *
     * <p>The individual distances are computed in parallel; the running sum is then
     * accumulated sequentially in index order so that entry i really is the sum of
     * the squared distances of vectors 0..i.
     *
     * @return a {@code 1 x initFeatures.size()} array of cumulative squared distances
     * @throws IllegalStateException if the computation is interrupted or a distance
     *                               task fails
     */
    protected INDArray computeDxs() {
        final int featureCount = this.initFeatures.size();
        // Outside of seeding (counter unset) fall back to scanning every centroid row.
        final int candidates = this.seededCentroids > 0
                ? Math.min(this.seededCentroids, this.centroids.rows())
                : this.centroids.rows();
        final double[] squaredDistances = new double[featureCount];
        final AtomicInteger failures = new AtomicInteger(0);
        final CountDownLatch latch = new CountDownLatch(featureCount);
        ExecutorService pool = this.executor();
        for (int i = 0; i < featureCount; ++i) {
            final int index = i;
            pool.execute(new Runnable(){

                @Override
                public void run() {
                    try {
                        INDArray features = (INDArray)KMeansClustering.this.initFeatures.get(index);
                        int nearestIndex = KMeansClustering.this.nearestCentroid(features, candidates);
                        INDArray nearest = KMeansClustering.this.centroids.getRow(nearestIndex);
                        double distance = KMeansClustering.this.getDistance(features, nearest);
                        // Each task writes only its own slot, so no further synchronisation is needed.
                        squaredDistances[index] = distance * distance;
                    }
                    catch (RuntimeException e) {
                        failures.incrementAndGet();
                        log.error("Failed to compute distance for feature vector " + index, e);
                    }
                    finally {
                        // Always release the latch, otherwise a failing task would block the caller forever.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing centroid seeding distances", e);
        }
        if (failures.get() > 0) {
            throw new IllegalStateException(failures.get() + " distance computations failed during centroid seeding");
        }
        INDArray dxs = Nd4j.create((int)1, (int)featureCount);
        double runningSum = 0.0;
        for (int i = 0; i < featureCount; ++i) {
            runningSum += squaredDistances[i];
            dxs.putScalar(i, runningSum);
        }
        return dxs;
    }

    /** Returns the seeding worker pool, creating it if it does not exist or was shut down. */
    private ExecutorService executor() {
        if (this.exec == null || this.exec.isShutdown()) {
            this.exec = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        }
        return this.exec;
    }

    /** Shuts the seeding worker pool down, if one exists, and forgets it. */
    private void shutdownExecutor() {
        if (this.exec != null) {
            this.exec.shutdown();
            this.exec = null;
        }
    }

    /**
     * Discards the centroids, counters and buffered vectors so that the model starts
     * collecting seeding samples again, and releases the worker pool if one is alive.
     */
    public void reset() {
        this.counts = null;
        this.centroids = null;
        this.initFeatures = new ArrayList<INDArray>();
        this.seededCentroids = 0;
        this.shutdownExecutor();
    }
}

