/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math4.ml.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.clustering.CentroidCluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.math4.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.util.MathUtils;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;

/**
 * Mini-batch variant of the K-means++ clusterer.
 *
 * <p>Each training step samples at most {@code batchSize} points and assigns them to
 * their nearest centroid. It then lets the inherited {@code adjustClustersCenters}
 * recompute the centroids from that sample only.
 *
 * <p>Training runs for at most {@code maxIterations * ceil(n / batchSize)} steps
 * (unbounded when {@code maxIterations} is negative). It stops earlier when an
 * exponentially weighted average of the per-point batch inertia has not reached a
 * new minimum for {@code maxNoImprovementTimes} consecutive steps.
 *
 * <p>After training, every input point is assigned to its nearest centroid.
 *
 * @param <T> type of the points to cluster
 */
public class MiniBatchKMeansClusterer<T extends Clusterable>
extends KMeansPlusPlusClusterer<T> {
    /** Requested number of points sampled for each training step. */
    private final int batchSize;
    /** Number of candidate initializations tried before training. */
    private final int initIterations;
    /** Number of points sampled to build, and to validate, each candidate initialization. */
    private final int initBatchSize;
    /** Number of consecutive non-improving steps after which training stops early. */
    private final int maxNoImprovementTimes;

    /**
     * Creates a mini-batch K-means clusterer.
     *
     * @param k number of clusters (validated by the parent constructor)
     * @param maxIterations number of passes over the data; each pass is
     *        {@code ceil(n / batchSize)} steps. A negative value means no limit.
     * @param batchSize number of points sampled per training step
     * @param initIterations number of candidate initializations to try
     * @param initBatchSize number of points sampled for each candidate initialization
     * @param maxNoImprovementTimes number of consecutive non-improving steps that
     *        stops training early
     * @param measure distance measure, passed to the parent
     * @param random random generator used for all sampling, passed to the parent
     * @param emptyStrategy empty-cluster strategy, passed to the parent
     * @throws NumberIsTooSmallException if {@code batchSize}, {@code initIterations},
     *         {@code initBatchSize} or {@code maxNoImprovementTimes} is less than 1
     */
    public MiniBatchKMeansClusterer(int k, int maxIterations, int batchSize, int initIterations, int initBatchSize, int maxNoImprovementTimes, DistanceMeasure measure, UniformRandomProvider random, KMeansPlusPlusClusterer.EmptyClusterStrategy emptyStrategy) {
        super(k, maxIterations, measure, random, emptyStrategy);
        if (batchSize < 1) {
            throw new NumberIsTooSmallException(batchSize, 1, true);
        }
        if (initIterations < 1) {
            throw new NumberIsTooSmallException(initIterations, 1, true);
        }
        if (initBatchSize < 1) {
            throw new NumberIsTooSmallException(initBatchSize, 1, true);
        }
        if (maxNoImprovementTimes < 1) {
            throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
        }
        this.batchSize = batchSize;
        this.initIterations = initIterations;
        this.initBatchSize = initBatchSize;
        this.maxNoImprovementTimes = maxNoImprovementTimes;
    }

    /**
     * Clusters the given points.
     *
     * <p>{@code points} must not be {@code null}; this is checked with
     * {@link MathUtils#checkNotNull}.
     *
     * @param points points to cluster
     * @return the clusters; each input point is assigned to its nearest centroid
     * @throws NumberIsTooSmallException if there are fewer points than clusters
     */
    @Override
    public List<CentroidCluster<T>> cluster(Collection<T> points) {
        MathUtils.checkNotNull(points);
        if (points.size() < getNumberOfClusters()) {
            throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
        }

        final int pointSize = points.size();
        // ListSampler.sample rejects sample sizes larger than the list, so a batch
        // never exceeds the data set.
        final int effectiveBatchSize = Math.min(batchSize, pointSize);
        final int batchCount = pointSize / effectiveBatchSize
            + (pointSize % effectiveBatchSize > 0 ? 1 : 0);
        final int maxIterations = getMaxIterations();
        // Multiply in long arithmetic: maxIterations * batchCount can exceed
        // Integer.MAX_VALUE, and a negative result would skip training entirely.
        final int maxSteps = maxIterations < 0
            ? Integer.MAX_VALUE
            : (int) Math.min((long) maxIterations * batchCount, Integer.MAX_VALUE);

        final List<T> pointList = new ArrayList<T>(points);
        List<CentroidCluster<T>> clusters = initialCenters(pointList);
        final ImprovementEvaluator evaluator =
            new ImprovementEvaluator(effectiveBatchSize, maxNoImprovementTimes);
        for (int i = 0; i < maxSteps; i++) {
            clearClustersPoints(clusters);
            final List<T> batchPoints =
                ListSampler.sample(getRandomGenerator(), pointList, effectiveBatchSize);
            final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
            clusters = pair.getSecond();
            if (evaluator.converge(pair.getFirst(), pointSize)) {
                break;
            }
        }

        // Final assignment of the full data set to the trained centers.
        clearClustersPoints(clusters);
        for (final T point : points) {
            addToNearestCentroidCluster(point, clusters);
        }
        return clusters;
    }

    /**
     * Removes all assigned points from every cluster, keeping the centers.
     *
     * @param clusters clusters to empty
     */
    private void clearClustersPoints(List<CentroidCluster<T>> clusters) {
        for (final CentroidCluster<T> cluster : clusters) {
            cluster.getPoints().clear();
        }
    }

    /**
     * Performs one training step on a batch.
     *
     * <p>The batch is assigned to {@code clusters}, and the centers are then recomputed
     * by {@code adjustClustersCenters}. Finally the batch is reassigned to the new
     * centers, which measures the result.
     *
     * @param batchPoints points of the batch
     * @param clusters current clusters; the batch points are added to them
     * @return the sum of squared distances from each batch point to its nearest new
     *         center, paired with the new clusters (which hold the batch points)
     */
    private Pair<Double, List<CentroidCluster<T>>> step(List<T> batchPoints, List<CentroidCluster<T>> clusters) {
        for (final T point : batchPoints) {
            addToNearestCentroidCluster(point, clusters);
        }
        final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
        double squareDistance = 0.0;
        for (final T point : batchPoints) {
            final double d = addToNearestCentroidCluster(point, newClusters);
            squareDistance += d * d;
        }
        return new Pair<Double, List<CentroidCluster<T>>>(squareDistance, newClusters);
    }

    /**
     * Chooses the starting centers from {@code initIterations} candidates.
     *
     * <p>Each candidate is produced as follows:
     * <ul>
     *   <li>{@code chooseInitialCenters} is run on a fresh random sample of
     *       {@code initBatchSize} points (or on all points, if there are not more than
     *       that).</li>
     *   <li>The candidate is refined with one {@link #step} on a validation sample,
     *       which is drawn once and shared by all candidates.</li>
     * </ul>
     * The candidate with the smallest validation sum of squared distances is kept.
     *
     * <p>NOTE(review): if {@code initBatchSize < k}, {@code chooseInitialCenters}
     * receives fewer points than clusters. Confirm the parent handles that case.
     *
     * @param points all data points
     * @return the best candidate clusters; never {@code null}
     */
    private List<CentroidCluster<T>> initialCenters(List<T> points) {
        final List<T> validPoints = sampleInitBatch(points);
        double nearestSquareDistance = Double.POSITIVE_INFINITY;
        List<CentroidCluster<T>> bestCenters = null;
        for (int i = 0; i < initIterations; i++) {
            final List<CentroidCluster<T>> clusters = chooseInitialCenters(sampleInitBatch(points));
            final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
            final double squareDistance = pair.getFirst();
            // Accept the first candidate unconditionally. If the summed squares overflow
            // to +Infinity, "inf < inf" is false, and the result would otherwise be null.
            if (bestCenters == null || squareDistance < nearestSquareDistance) {
                nearestSquareDistance = squareDistance;
                bestCenters = pair.getSecond();
            }
        }
        return bestCenters;
    }

    /**
     * Draws the sample used during initialization.
     *
     * @param points all data points
     * @return {@code initBatchSize} randomly sampled points if the data set is larger
     *         than that, otherwise a copy of all points
     */
    private List<T> sampleInitBatch(List<T> points) {
        if (initBatchSize < points.size()) {
            return ListSampler.sample(getRandomGenerator(), points, initBatchSize);
        }
        return new ArrayList<T>(points);
    }

    /**
     * Adds {@code point} to the cluster whose center is closest to it.
     *
     * @param point point to assign
     * @param clusters candidate clusters
     * @return the distance from {@code point} to the chosen center
     */
    private double addToNearestCentroidCluster(T point, List<CentroidCluster<T>> clusters) {
        double minDistance = Double.POSITIVE_INFINITY;
        CentroidCluster<T> closestCentroidCluster = null;
        for (final CentroidCluster<T> centroidCluster : clusters) {
            final double distance = distance(point, centroidCluster.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                closestCentroidCluster = centroidCluster;
            }
        }
        // Null only when clusters is empty or no distance compares below +Infinity (e.g. NaN).
        MathUtils.checkNotNull(closestCentroidCluster);
        closestCentroidCluster.addPoint(point);
        return minDistance;
    }

    /**
     * Early-stopping criterion.
     *
     * <p>It tracks an exponentially weighted average (EWA) of the per-point batch
     * inertia. Training is reported as converged once that average has failed to
     * reach a new minimum for {@code maxNoImprovementTimes} consecutive calls.
     */
    private static class ImprovementEvaluator {
        /** Number of points actually in each batch. */
        private final int batchSize;
        /** Number of consecutive non-improving calls that signals convergence. */
        private final int maxNoImprovementTimes;
        /** Current EWA of the batch inertia; NaN until the first call. */
        private double ewaInertia = Double.NaN;
        /** Smallest EWA seen so far. */
        private double ewaInertiaMin = Double.POSITIVE_INFINITY;
        /** Number of consecutive calls without a new minimum. */
        private int noImprovementTimes;

        /**
         * @param batchSize number of points in each batch
         * @param maxNoImprovementTimes number of non-improving calls before convergence
         */
        private ImprovementEvaluator(int batchSize, int maxNoImprovementTimes) {
            this.batchSize = batchSize;
            this.maxNoImprovementTimes = maxNoImprovementTimes;
        }

        /**
         * Records one batch result and reports whether training should stop.
         *
         * @param squareDistance sum of squared distances of the batch
         * @param pointSize size of the whole data set
         * @return {@code true} if the EWA inertia has not improved for
         *         {@code maxNoImprovementTimes} consecutive calls
         */
        public boolean converge(double squareDistance, int pointSize) {
            final double batchInertia = squareDistance / batchSize;
            if (Double.isNaN(ewaInertia)) {
                ewaInertia = batchInertia;
            } else {
                // Must be floating-point: with integer division alpha was 0 whenever
                // 2 * batchSize < pointSize + 1, which froze the average at its first value.
                final double alpha = Math.min(2.0 * batchSize / (pointSize + 1.0), 1.0);
                ewaInertia = ewaInertia * (1.0 - alpha) + batchInertia * alpha;
            }
            if (ewaInertia < ewaInertiaMin) {
                noImprovementTimes = 0;
                ewaInertiaMin = ewaInertia;
            } else {
                ++noImprovementTimes;
            }
            return noImprovementTimes >= maxNoImprovementTimes;
        }
    }
}

