package io.improbable.keanu.algorithms.sampling;

import com.google.common.collect.Sets;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.algorithms.mcmc.NetworkSamplesGenerator;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.LambdaSection;
import io.improbable.keanu.network.TransitiveClosure;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.nd4j.base.Preconditions;

/**
 * {@link PosteriorSamplingAlgorithm} that produces samples with a {@link ForwardSampler}.
 *
 * <p>The sampler is built over the connected graph of the requested vertices and resamples only the
 * vertices that are both downstream of a latent vertex and upstream of a requested vertex, in
 * topological order. Models in which an observed vertex has an unobserved random variable in its
 * upstream lambda section are rejected with an {@link IllegalArgumentException}.
 */
public class Forward implements PosteriorSamplingAlgorithm {
    private final KeanuRandom random;
    private final boolean calculateSampleProbability;

    /**
     * Builder for {@link Forward}. Defaults: {@link KeanuRandom#getDefaultRandom()} as the random
     * source and {@code calculateSampleProbability = false}.
     */
    public static class ForwardBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private boolean calculateSampleProbability = false;

        ForwardBuilder() {
        }

        /**
         * Sets the random source that will be handed to the {@link ForwardSampler}.
         *
         * @param random the random source to use
         * @return this builder
         */
        public ForwardBuilder random(KeanuRandom random) {
            this.random = random;
            return this;
        }

        /**
         * Sets the {@code calculateSampleProbability} flag handed to the {@link ForwardSampler}
         * (presumably controls whether a probability is computed per sample — see ForwardSampler).
         *
         * @param calculateSampleProbability the flag value
         * @return this builder
         */
        public ForwardBuilder calculateSampleProbability(boolean calculateSampleProbability) {
            this.calculateSampleProbability = calculateSampleProbability;
            return this;
        }

        /** @return a new {@link Forward} configured with this builder's settings */
        public Forward build() {
            return new Forward(this.random, this.calculateSampleProbability);
        }

        @Override
        public String toString() {
            return "ForwardBuilder(random=" + this.random + ", calculateSampleProbability=" + this.calculateSampleProbability + ")";
        }
    }

    /** @return a new builder with default settings */
    public static ForwardBuilder builder() {
        return new ForwardBuilder();
    }

    /**
     * @param random                     random source handed to the {@link ForwardSampler}
     * @param calculateSampleProbability flag handed to the {@link ForwardSampler}
     */
    public Forward(KeanuRandom random, boolean calculateSampleProbability) {
        this.random = random;
        this.calculateSampleProbability = calculateSampleProbability;
    }

    /**
     * Draws {@code sampleCount} samples of the given variables.
     *
     * @param model                 the model to sample from; must contain latent variables
     * @param variablesToSampleFrom non-empty list of {@link Vertex} instances from one connected graph
     * @param sampleCount           number of samples to generate
     * @return the generated samples
     * @throws IllegalArgumentException if any precondition described on
     *                                  {@link #generatePosteriorSamples} is violated
     */
    @Override
    public NetworkSamples getPosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom, int sampleCount) {
        return generatePosteriorSamples(model, variablesToSampleFrom).generate(sampleCount);
    }

    /**
     * Creates a lazy sample generator for the given variables, reporting progress via a {@link StatusBar}.
     *
     * @param model                 the model to sample from; must contain latent variables
     * @param variablesToSampleFrom non-empty list of {@link Vertex} instances from one connected graph
     * @return a generator backed by a {@link ForwardSampler}
     * @throws IllegalArgumentException if the model has no latent variables, the list is empty,
     *                                  contains a non-{@link Vertex}, spans several graphs, or an
     *                                  observed vertex depends on an unobserved random variable
     */
    @Override
    public NetworkSamplesGenerator generatePosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom) {
        return new NetworkSamplesGenerator(setupSampler(model, variablesToSampleFrom), StatusBar::new);
    }

    /** Validates the inputs and builds the {@link ForwardSampler} over the vertices that need resampling. */
    private SamplingAlgorithm setupSampler(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom) {
        Preconditions.checkArgument(model.getLatentVariables().size() > 0, "Your model must contain latent variables in order to forward sample.");
        // Without this, an empty list fails later with an IndexOutOfBoundsException from get(0).
        Preconditions.checkArgument(!variablesToSampleFrom.isEmpty(), "At least one variable to sample from must be provided.");

        List<Vertex> sampleFromVertices = toVertices(variablesToSampleFrom);
        BayesianNetwork network = checkSampleFromVariablesComeFromConnectedGraph(sampleFromVertices);
        checkUpstreamOfObservedDoesNotContainProbabilistic(network.getObservedVertices());

        // Only vertices that are fed by a latent vertex AND feed a requested vertex need resampling.
        Set<Vertex> downstreamOfLatents = allDownstreamVertices(network.getLatentVertices());
        Set<Vertex> verticesToResample = Sets.intersection(
            downstreamOfLatents,
            TransitiveClosure.getUpstreamVerticesForCollection(sampleFromVertices, true).getAllVertices()
        );

        return new ForwardSampler(network, sampleFromVertices, TopologicalSort.sort(verticesToResample), this.random, this.calculateSampleProbability);
    }

    /** Checks that every requested variable is a {@link Vertex} and returns them as a typed list. */
    private static List<Vertex> toVertices(List<? extends Variable> variables) {
        List<Vertex> vertices = new ArrayList<>(variables.size());
        for (Variable variable : variables) {
            Preconditions.checkArgument(variable instanceof Vertex, "The Forward Sampler only works for Variables of type Vertex. Received : " + variable);
            vertices.add((Vertex) variable);
        }
        return vertices;
    }

    /**
     * Builds a {@link BayesianNetwork} from the connected graph of the first vertex, after checking
     * that every other requested vertex belongs to that same graph.
     *
     * @param sampleFromVertices non-empty list of requested vertices
     * @throws IllegalArgumentException if any vertex lies outside the first vertex's connected graph
     */
    private BayesianNetwork checkSampleFromVariablesComeFromConnectedGraph(List<Vertex> sampleFromVertices) {
        Set<Vertex> connectedGraph = sampleFromVertices.get(0).getConnectedGraph();
        for (Vertex vertex : sampleFromVertices) {
            if (!connectedGraph.contains(vertex)) {
                throw new IllegalArgumentException("Sample from vertices must be part of the same connected graph.");
            }
        }
        return new BayesianNetwork(connectedGraph);
    }

    /** Returns all vertices (probabilistic or not) in the downstream lambda sections of {@code vertices}. */
    private Set<Vertex> allDownstreamVertices(List<Vertex> vertices) {
        return LambdaSection.getDownstreamLambdaSectionForCollection(vertices, true).getAllVertices();
    }

    /**
     * Rejects the model if any observed vertex has an unobserved random variable in its upstream
     * lambda section.
     *
     * @throws IllegalArgumentException if such a random variable exists
     */
    private void checkUpstreamOfObservedDoesNotContainProbabilistic(List<Vertex> observedVertices) {
        Set<Vertex> upstreamLambdaSection = LambdaSection.getUpstreamLambdaSectionForCollection(observedVertices, false).getAllVertices();
        // The section may contain the observed vertices themselves. Comparing its size against a fixed
        // threshold would therefore reject any model with two or more independent observations, so the
        // observed vertices are removed and only what remains is treated as a random ancestor.
        if (!Sets.difference(upstreamLambdaSection, Sets.newHashSet(observedVertices)).isEmpty()) {
            throw new IllegalArgumentException("Forward sampler cannot be ran if observed variables have a random variable in their upstream lambda section");
        }
    }
}
