/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.mcmc;

import com.google.common.base.Preconditions;
import io.improbable.keanu.algorithms.NetworkSample;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.util.status.PercentageComponent;
import io.improbable.keanu.util.status.RemainingTimeComponent;
import io.improbable.keanu.util.status.StatusBar;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;

/**
 * Drives a {@link SamplingAlgorithm} to produce network samples, either as a fixed-size
 * {@link NetworkSamples} batch ({@link #generate(int)}) or as an unbounded
 * {@link Stream} ({@link #stream()}).
 *
 * <p>Before sampling, the first {@code dropCount} algorithm steps are discarded. While sampling,
 * only every {@code downSampleInterval}-th step is recorded as a sample; the others are plain
 * {@link SamplingAlgorithm#step()} calls. Progress is reported on a {@link StatusBar} obtained
 * from the supplier given at construction time (one status bar per {@code generate}/{@code stream}
 * call).
 */
public class NetworkSamplesGenerator {
    private final SamplingAlgorithm algorithm;
    private final Supplier<StatusBar> statusBarSupplier;
    private int dropCount = 0;
    private int downSampleInterval = 1;

    /**
     * @param algorithm         the algorithm that is stepped and sampled
     * @param statusBarSupplier invoked once per {@link #generate(int)} / {@link #stream()} call to
     *                          obtain the status bar used for progress reporting
     */
    public NetworkSamplesGenerator(SamplingAlgorithm algorithm, Supplier<StatusBar> statusBarSupplier) {
        this.algorithm = algorithm;
        this.statusBarSupplier = statusBarSupplier;
    }

    /** @return the number of initial algorithm steps that are discarded before sampling */
    public int getDropCount() {
        return dropCount;
    }

    /**
     * Sets how many initial algorithm steps to discard before any sample is recorded.
     *
     * @param dropCount number of steps to discard; must be {@code >= 0}
     * @return this generator, for chaining
     * @throws IllegalArgumentException if {@code dropCount} is negative
     */
    public NetworkSamplesGenerator dropCount(int dropCount) {
        Preconditions.checkArgument(dropCount >= 0,
            "Drop count of %s is invalid. Cannot drop negative samples.", dropCount);
        this.dropCount = dropCount;
        return this;
    }

    /** @return the interval N such that every Nth step is recorded as a sample */
    public int getDownSampleInterval() {
        return downSampleInterval;
    }

    /**
     * Sets the down-sample interval: every Nth step is recorded as a sample and the steps in
     * between are discarded. An interval of 1 records every step.
     *
     * @param downSampleInterval the interval N; must be {@code > 0}
     * @return this generator, for chaining
     * @throws IllegalArgumentException if {@code downSampleInterval} is not positive
     */
    public NetworkSamplesGenerator downSampleInterval(int downSampleInterval) {
        Preconditions.checkArgument(downSampleInterval > 0,
            "Down-sample interval of %s is invalid. The down-sample interval means take every Nth sample. A down-sample interval of 1 would be no down-sampling.",
            downSampleInterval);
        this.downSampleInterval = downSampleInterval;
        return this;
    }

    /**
     * Runs the algorithm for {@code totalSampleCount} steps in total and collects the results.
     *
     * <p>The first {@code dropCount} steps are discarded. Of the remaining
     * {@code totalSampleCount - dropCount} steps, those at positions {@code 0, N, 2N, ...}
     * (N = down-sample interval) are recorded as samples. That gives
     * {@code ceil((totalSampleCount - dropCount) / N)} samples.
     *
     * @param totalSampleCount total number of algorithm steps, including the dropped ones
     * @return the recorded samples
     * @throws IllegalArgumentException if {@code dropCount >= totalSampleCount}
     */
    public NetworkSamples generate(int totalSampleCount) {
        Preconditions.checkArgument(dropCount < totalSampleCount,
            "Cannot drop more samples than requested or all of the samples. Samples requested %s and dropping %s",
            totalSampleCount, dropCount);

        StatusBar statusBar = statusBarSupplier.get();
        try {
            // Raw type kept as in the original; its generic type is whatever
            // SamplingAlgorithm.sample(...) expects (declared outside this file).
            HashMap samplesByVariable = new HashMap();
            ArrayList<Double> logOfMasterPForEachSample = new ArrayList<Double>();

            dropSamples(dropCount, statusBar);

            int samplesLeft = totalSampleCount - dropCount;
            PercentageComponent statusPercentage = newPercentageComponentAndAddToStatusBar(statusBar);
            // Size the time estimate by the steps this loop actually performs. The dropped steps
            // already ran in dropSamples() and never step this component.
            RemainingTimeComponent remainingTimeComponent = new RemainingTimeComponent(samplesLeft);
            statusBar.addComponent(remainingTimeComponent);
            statusBar.setMessage("Sampling...");

            int sampleCount = 0;
            for (int i = 0; i < samplesLeft; ++i) {
                if (i % downSampleInterval == 0) {
                    algorithm.sample(samplesByVariable, logOfMasterPForEachSample);
                    ++sampleCount;
                } else {
                    algorithm.step();
                }
                remainingTimeComponent.step();
                statusPercentage.progress((double) (i + 1) / (double) samplesLeft);
            }
            return new NetworkSamples(samplesByVariable, logOfMasterPForEachSample, sampleCount);
        } finally {
            // Finish the status bar even if sampling throws, so it is never left unfinished.
            statusBar.finish();
        }
    }

    /**
     * Creates a new {@link PercentageComponent} and registers it on {@code statusBar}.
     *
     * @return the newly registered component
     */
    private PercentageComponent newPercentageComponentAndAddToStatusBar(StatusBar statusBar) {
        PercentageComponent percentageComponent = new PercentageComponent();
        statusBar.addComponent(percentageComponent);
        return percentageComponent;
    }

    /**
     * Returns an unbounded stream of samples.
     *
     * <p>The {@code dropCount} initial steps are discarded eagerly, when this method is called.
     * Each stream element is then produced lazily: {@code downSampleInterval - 1} plain steps,
     * followed by one {@link SamplingAlgorithm#sample()}. The down-sample interval is read each
     * time an element is produced.
     *
     * <p>The status bar is finished when the returned stream is closed. Callers should use
     * try-with-resources or call {@link Stream#close()} explicitly.
     *
     * @return an infinite stream of samples
     */
    public Stream<NetworkSample> stream() {
        StatusBar statusBar = statusBarSupplier.get();
        try {
            dropSamples(dropCount, statusBar);
        } catch (RuntimeException | Error e) {
            // The onClose hook below is not registered yet, so finish the bar here before propagating.
            statusBar.finish();
            throw e;
        }

        AtomicInteger sampleNumber = new AtomicInteger(0);
        return Stream.generate(() -> {
            int currentSample = sampleNumber.incrementAndGet();
            for (int i = 0; i < downSampleInterval - 1; ++i) {
                algorithm.step();
            }
            NetworkSample sample = algorithm.sample();
            statusBar.setMessage(String.format("Sample #%,d completed", currentSample));
            return sample;
        }).onClose(statusBar::finish);
    }

    /**
     * Advances the algorithm {@code dropCount} times without recording anything. While doing so,
     * shows progress via a temporary percentage component that is removed afterwards.
     * Does nothing when {@code dropCount} is 0.
     */
    private void dropSamples(int dropCount, StatusBar statusBar) {
        if (dropCount == 0) {
            return;
        }
        statusBar.setMessage("Dropping samples...");
        PercentageComponent statusPercent = newPercentageComponentAndAddToStatusBar(statusBar);
        for (int i = 0; i < dropCount; ++i) {
            algorithm.step();
            statusPercent.progress((double) (i + 1) / (double) dropCount);
        }
        statusBar.removeComponent(statusPercent);
    }
}

