/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.clustering;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.clustering.Cluster;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A mutable collection of {@link Cluster}s that can assign points to the cluster whose center is
 * nearest to them.
 *
 * <p>Nearest-cluster lookups always use Euclidean distance (an {@link EuclideanDistance} op run
 * through the Nd4j executioner). The {@code distanceFunction} held by this set, and the one passed
 * to {@link #classify(INDArray, Class)}, are stored/accepted but are not consulted by the lookup.
 * NOTE(review): confirm whether the configured distance function was meant to be honoured.
 *
 * <p>This class performs no synchronization.
 */
public class ClusterSet {
    /** Distance op class configured for this set; currently not used by {@link #nearestCluster}. */
    private Class<? extends Accumulation> distanceFunction;
    /** The clusters in this set, in insertion order. */
    private List<Cluster> clusters = new ArrayList<Cluster>();

    /** Creates an empty cluster set with no distance function configured. */
    public ClusterSet() {
    }

    /**
     * Creates one cluster per row of {@code centers}, using each row as that cluster's center.
     *
     * @param centers matrix whose rows are the initial cluster centers
     */
    public ClusterSet(INDArray centers) {
        int count = centers.rows();
        for (int row = 0; row < count; row++) {
            this.clusters.add(new Cluster(centers.getRow(row)));
        }
    }

    /**
     * Creates an empty cluster set that records the given distance function.
     *
     * @param distanceFunction distance op class to store (see the class note: not used for lookups)
     */
    public ClusterSet(Class<? extends Accumulation> distanceFunction) {
        this.distanceFunction = distanceFunction;
    }

    /**
     * Appends a new cluster centered at {@code center}.
     *
     * @param center the new cluster's center
     */
    public void addNewClusterWithCenter(INDArray center) {
        this.getClusters().add(new Cluster(center));
    }

    /**
     * Stacks every cluster's center into a matrix, one row per cluster in list order.
     *
     * <p>The column count is taken from the first cluster's center.
     *
     * @return a {@code clusterCount x centerColumns} matrix of centers
     * @throws IllegalStateException if the set contains no clusters
     */
    public INDArray getCenters() {
        if (this.clusters.isEmpty()) {
            throw new IllegalStateException("Cannot build a centers matrix: the cluster set is empty");
        }
        int count = this.clusters.size();
        INDArray centers = Nd4j.create(count, (int) this.clusters.get(0).getCenter().columns());
        for (int row = 0; row < count; row++) {
            centers.putRow(row, this.clusters.get(row).getCenter());
        }
        return centers;
    }

    /**
     * Adds {@code point} to its nearest cluster, passing {@code true} for the cluster's
     * move-center flag.
     *
     * @param point the point to add
     * @throws IllegalStateException if no cluster has a center
     */
    public void addPoint(INDArray point) {
        this.requireNearestCluster(point).addPoint(point, true);
    }

    /**
     * Adds {@code point} to its nearest cluster.
     *
     * @param point the point to add
     * @param moveClusterCenter flag forwarded to {@code Cluster.addPoint}; presumably controls
     *     whether the cluster recomputes its center (TODO confirm against {@code Cluster})
     * @throws IllegalStateException if no cluster has a center
     */
    public void addPoint(INDArray point, boolean moveClusterCenter) {
        this.requireNearestCluster(point).addPoint(point, moveClusterCenter);
    }

    /**
     * Adds each point to its nearest cluster with the move-center flag set to {@code true}.
     *
     * @param points the points to add, processed in iteration order
     */
    public void addPoints(List<INDArray> points) {
        this.addPoints(points, true);
    }

    /**
     * Adds each point to its nearest cluster.
     *
     * @param points the points to add, processed in iteration order
     * @param moveClusterCenter flag forwarded to each {@code Cluster.addPoint} call
     */
    public void addPoints(List<INDArray> points, boolean moveClusterCenter) {
        for (INDArray point : points) {
            this.addPoint(point, moveClusterCenter);
        }
    }

    /**
     * Returns the cluster nearest to {@code point}.
     *
     * @param point the point to classify
     * @return the nearest cluster, or {@code null} if no cluster has a center
     */
    public Cluster classify(INDArray point) {
        return this.classify(point, this.distanceFunction);
    }

    /**
     * Returns the cluster nearest to {@code point}.
     *
     * @param point the point to classify
     * @param distanceFunction currently ignored; the lookup always uses Euclidean distance
     *     (NOTE(review): confirm intent)
     * @return the nearest cluster, or {@code null} if no cluster has a center
     */
    public Cluster classify(INDArray point, Class<? extends Accumulation> distanceFunction) {
        return this.nearestCluster(point);
    }

    /**
     * Finds the cluster whose center has the smallest Euclidean distance to {@code point}.
     *
     * <p>Clusters with a {@code null} center are skipped. On ties, the earliest cluster in list
     * order wins. A NaN distance never wins.
     *
     * @param point the point to measure from
     * @return the nearest cluster, or {@code null} if no cluster has a center
     */
    protected Cluster nearestCluster(INDArray point) {
        Cluster nearest = null;
        // Start above every finite double so that any real distance can be selected.
        double minDistance = Double.POSITIVE_INFINITY;
        for (Cluster cluster : this.getClusters()) {
            INDArray center = cluster.getCenter();
            if (center == null) {
                continue;
            }
            double distance = this.getDistance(center, point);
            if (distance < minDistance) {
                minDistance = distance;
                nearest = cluster;
            }
        }
        return nearest;
    }

    /**
     * Like {@link #nearestCluster(INDArray)}, but fails loudly instead of returning {@code null}.
     */
    private Cluster requireNearestCluster(INDArray point) {
        Cluster nearest = this.nearestCluster(point);
        if (nearest == null) {
            throw new IllegalStateException("No cluster with a center is available for the point");
        }
        return nearest;
    }

    /** Computes the Euclidean distance between {@code m1} and {@code m2} via the Nd4j executioner. */
    private double getDistance(INDArray m1, INDArray m2) {
        return Nd4j.getExecutioner().execAndReturn((Accumulation) new EuclideanDistance(m1, m2)).currentResult().doubleValue();
    }

    /**
     * Returns the Euclidean distance from {@code point} to the center of its nearest cluster.
     *
     * @param point the point to measure from
     * @return distance to the nearest cluster center
     * @throws IllegalStateException if no cluster has a center
     */
    public double getDistanceFromNearestCluster(INDArray point) {
        Cluster nearest = this.requireNearestCluster(point);
        return this.getDistance(nearest.getCenter(), point);
    }

    /**
     * Returns the number of clusters.
     *
     * @return the cluster count, or 0 if the cluster list has been set to {@code null}
     */
    public int getClusterCount() {
        return this.getClusters() == null ? 0 : this.getClusters().size();
    }

    /** Calls {@code removePoints()} on every cluster in the set. */
    public void removePoints() {
        for (Cluster cluster : this.getClusters()) {
            cluster.removePoints();
        }
    }

    /**
     * Returns the live backing list of clusters (not a copy).
     *
     * @return the cluster list
     */
    public List<Cluster> getClusters() {
        return this.clusters;
    }

    /**
     * Replaces the backing cluster list with {@code clusters} (stored by reference, not copied).
     *
     * @param clusters the new cluster list
     */
    public void setClusters(List<Cluster> clusters) {
        this.clusters = clusters;
    }

    /**
     * Returns the stored distance function class.
     *
     * @return the configured distance op class, possibly {@code null}
     */
    public Class<? extends Accumulation> getDistanceFunction() {
        return this.distanceFunction;
    }

    /**
     * Sets the stored distance function class (see the class note: not used for lookups).
     *
     * @param distanceFunction the distance op class to store
     */
    public void setDistanceFunction(Class<? extends Accumulation> distanceFunction) {
        this.distanceFunction = distanceFunction;
    }
}

