package io.improbable.keanu.algorithms.mcmc;

import com.google.common.base.Preconditions;
import io.improbable.keanu.algorithms.NetworkSample;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.util.status.PercentageComponent;
import io.improbable.keanu.util.status.RemainingTimeComponent;
import io.improbable.keanu.util.status.StatusBar;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/NetworkSamplesGenerator.class */
/**
 * Drives a {@link SamplingAlgorithm} to produce network samples, optionally discarding a number of
 * initial steps and down-sampling (keeping only every Nth step), while reporting progress on a
 * {@link StatusBar} obtained from the supplied factory.
 *
 * <p>Instances are mutable (the builder-style setters modify them) and perform no synchronization.
 */
public class NetworkSamplesGenerator {
    private final SamplingAlgorithm algorithm;
    private int dropCount = 0;
    private int downSampleInterval = 1;
    private final Supplier<StatusBar> statusBarSupplier;

    /**
     * @param samplingAlgorithm the algorithm that is stepped/sampled to produce samples
     * @param supplier          factory invoked once per {@link #generate(int)} or {@link #stream()}
     *                          call to obtain the status bar used for progress reporting
     */
    public NetworkSamplesGenerator(SamplingAlgorithm samplingAlgorithm, Supplier<StatusBar> supplier) {
        this.algorithm = samplingAlgorithm;
        this.statusBarSupplier = supplier;
    }

    /**
     * @return the number of initial algorithm steps discarded before any sample is recorded
     */
    public int getDropCount() {
        return this.dropCount;
    }

    /**
     * Sets how many initial algorithm steps are discarded before sampling starts.
     *
     * @param dropCount number of steps to discard; must be non-negative
     * @return this generator, for chaining
     * @throws IllegalArgumentException if {@code dropCount} is negative
     */
    public NetworkSamplesGenerator dropCount(int dropCount) {
        Preconditions.checkArgument(dropCount >= 0, "Drop count of %s is invalid. Cannot drop negative samples.", dropCount);
        this.dropCount = dropCount;
        return this;
    }

    /**
     * @return the down-sample interval; only every Nth step is recorded as a sample
     */
    public int getDownSampleInterval() {
        return this.downSampleInterval;
    }

    /**
     * Sets the down-sample interval: only every Nth step is recorded as a sample.
     *
     * @param downSampleInterval the interval N; 1 means no down-sampling
     * @return this generator, for chaining
     * @throws IllegalArgumentException if {@code downSampleInterval} is not positive
     */
    public NetworkSamplesGenerator downSampleInterval(int downSampleInterval) {
        Preconditions.checkArgument(downSampleInterval > 0, "Down-sample interval of %s is invalid. The down-sample interval means take every Nth sample. A down-sample interval of 1 would be no down-sampling.", downSampleInterval);
        this.downSampleInterval = downSampleInterval;
        return this;
    }

    /**
     * Runs the algorithm and collects samples eagerly.
     *
     * <p>First {@code dropCount} steps are discarded. The remaining
     * {@code totalSampleCount - dropCount} iterations each either record a sample (when the iteration
     * index is a multiple of the down-sample interval) or just advance the algorithm.
     *
     * @param totalSampleCount total number of algorithm iterations, including the dropped ones
     * @return the recorded samples together with how many were recorded
     * @throws IllegalArgumentException if {@code dropCount >= totalSampleCount}
     */
    public NetworkSamples generate(int totalSampleCount) {
        Preconditions.checkArgument(this.dropCount < totalSampleCount, "Cannot drop more samples than requested or all of the samples. Samples requested %s and dropping %s", totalSampleCount, this.dropCount);
        StatusBar statusBar = this.statusBarSupplier.get();
        try {
            // Containers filled in place by SamplingAlgorithm.sample(...); their exact generic types
            // are defined by NetworkSamples (not visible here), so they are left raw.
            HashMap sampleMap = new HashMap();
            ArrayList sampleList = new ArrayList();

            dropSamples(this.dropCount, statusBar);

            PercentageComponent percentage = newPercentageComponentAndAddToStatusBar(statusBar);
            final int iterationCount = totalSampleCount - this.dropCount;
            // Sized to the number of step() calls it receives below (dropped steps are not tracked
            // by this component).
            RemainingTimeComponent remainingTime = new RemainingTimeComponent(iterationCount);
            statusBar.addComponent(remainingTime);
            statusBar.setMessage("Sampling...");

            int recordedSampleCount = 0;
            for (int iteration = 0; iteration < iterationCount; iteration++) {
                if (iteration % this.downSampleInterval == 0) {
                    this.algorithm.sample(sampleMap, sampleList);
                    recordedSampleCount++;
                } else {
                    this.algorithm.step();
                }
                remainingTime.step();
                // Floating-point division: integer division would report 0 until the very end.
                percentage.progress((iteration + 1) / (double) iterationCount);
            }
            return new NetworkSamples(sampleMap, sampleList, recordedSampleCount);
        } finally {
            // Always release the status bar, even if the algorithm throws.
            statusBar.finish();
        }
    }

    /**
     * Creates a new percentage component and registers it on the given status bar.
     */
    private PercentageComponent newPercentageComponentAndAddToStatusBar(StatusBar statusBar) {
        PercentageComponent percentageComponent = new PercentageComponent();
        statusBar.addComponent(percentageComponent);
        return percentageComponent;
    }

    /**
     * Returns an unbounded, lazily evaluated stream of samples.
     *
     * <p>The {@code dropCount} initial steps are discarded eagerly, before this method returns. Each
     * stream element then advances the algorithm {@code downSampleInterval - 1} times and takes one
     * sample. The status bar is finished when the stream is closed, so callers should close it (e.g.
     * with try-with-resources).
     *
     * @return an infinite stream of samples
     */
    public Stream<NetworkSample> stream() {
        StatusBar statusBar = this.statusBarSupplier.get();
        try {
            dropSamples(this.dropCount, statusBar);
        } catch (RuntimeException | Error e) {
            // No stream will be returned to close, so release the status bar here.
            statusBar.finish();
            throw e;
        }

        AtomicInteger samplesTaken = new AtomicInteger(0);
        Stream<NetworkSample> samples = Stream.generate(() -> {
            int sampleNumber = samplesTaken.incrementAndGet();
            for (int skipped = 0; skipped < this.downSampleInterval - 1; skipped++) {
                this.algorithm.step();
            }
            NetworkSample sample = this.algorithm.sample();
            statusBar.setMessage(String.format("Sample #%,d completed", sampleNumber));
            return sample;
        });
        return samples.onClose(statusBar::finish);
    }

    /**
     * Advances the algorithm {@code count} times without recording anything, showing progress on
     * the status bar. Does nothing when {@code count} is 0.
     */
    private void dropSamples(int count, StatusBar statusBar) {
        if (count == 0) {
            return;
        }
        statusBar.setMessage("Dropping samples...");
        PercentageComponent percentage = newPercentageComponentAndAddToStatusBar(statusBar);
        for (int dropped = 0; dropped < count; dropped++) {
            this.algorithm.step();
            // Floating-point division: integer division would report 0 until the very end.
            percentage.progress((dropped + 1) / (double) count);
        }
        statusBar.removeComponent(percentage);
    }
}
