package io.improbable.keanu.algorithms.sampling;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSample;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.mcmc.SamplingUtil;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import java.util.List;
import java.util.Map;

/**
 * A {@link SamplingAlgorithm} that produces samples by forward-simulating a {@link BayesianNetwork}.
 *
 * <p>Each {@link #step()} walks the supplied vertex list in order. It draws a fresh value for every
 * {@link Probabilistic} vertex and recomputes every {@link NonProbabilistic} vertex. The list is
 * expected to be topologically sorted (parents before children), as its field name indicates; that
 * ordering is not verified here.
 */
public class ForwardSampler implements SamplingAlgorithm {

    /** Log probability recorded for a sample when {@code calculateSampleProbability} is false. */
    private static final double LOG_PROB_OF_PRIOR = 0.0d;

    private final BayesianNetwork network;
    private final List<? extends Variable> variablesToSampleFrom;
    private final List<Vertex> topologicallySortedVertices;
    private final KeanuRandom random;
    private final boolean calculateSampleProbability;

    /**
     * Creates a forward sampler.
     *
     * @param network                     network queried via {@code getLogOfMasterP()} when
     *                                    {@code calculateSampleProbability} is true
     * @param variablesToSampleFrom       variables whose values are recorded for each sample
     * @param topologicallySortedVertices vertices updated on every step, in the order given
     * @param random                      source of randomness passed to each probabilistic vertex
     * @param calculateSampleProbability  if true, each sample records the network's log master
     *                                    probability; otherwise it records {@code 0.0}
     */
    public ForwardSampler(BayesianNetwork network,
                          List<? extends Variable> variablesToSampleFrom,
                          List<Vertex> topologicallySortedVertices,
                          KeanuRandom random,
                          boolean calculateSampleProbability) {
        this.network = network;
        this.variablesToSampleFrom = variablesToSampleFrom;
        this.topologicallySortedVertices = topologicallySortedVertices;
        this.random = random;
        this.calculateSampleProbability = calculateSampleProbability;
    }

    /**
     * Performs one forward pass over the vertices and overwrites each vertex's value.
     *
     * <p>Vertices are processed in list order. If an unsupported vertex is encountered, every vertex
     * before it in the list has already been updated.
     *
     * @throws IllegalArgumentException if a vertex is neither {@link Probabilistic} nor
     *                                  {@link NonProbabilistic}
     */
    @Override
    @SuppressWarnings("unchecked") // vertices are held as raw Vertex, so setValue is an unchecked call
    public void step() {
        for (Vertex vertex : topologicallySortedVertices) {
            if (vertex instanceof Probabilistic) {
                vertex.setValue(((Probabilistic) vertex).sample(random));
            } else if (vertex instanceof NonProbabilistic) {
                vertex.setValue(((NonProbabilistic) vertex).calculate());
            } else {
                throw new IllegalArgumentException(
                    "Forward sampler can only operate on Probabilistic or NonProbabilistic vertices. Invalid Vertex: [" + vertex + "]");
            }
        }
    }

    /**
     * Takes one step and appends the resulting values and log probability to the given collections.
     *
     * @param samples                   per-variable sample lists that the new values are appended to
     * @param logOfMasterPForEachSample receives this sample's log probability (see
     *                                  {@code calculateSampleProbability})
     */
    @Override
    public void sample(Map<VariableReference, List<?>> samples, List<Double> logOfMasterPForEachSample) {
        step();
        SamplingUtil.takeSamples(samples, variablesToSampleFrom);
        logOfMasterPForEachSample.add(currentSampleLogProbability());
    }

    /**
     * Takes one step and returns the resulting sample.
     *
     * <p>The log probability honours {@code calculateSampleProbability}, matching
     * {@link #sample(Map, List)}. Previously this method always reported {@code 0.0}.
     *
     * @return the sampled variable values and their log probability
     */
    @Override
    public NetworkSample sample() {
        step();
        return new NetworkSample(SamplingAlgorithm.takeSample(variablesToSampleFrom), currentSampleLogProbability());
    }

    /** Log probability to record for the network's current state, per {@code calculateSampleProbability}. */
    private double currentSampleLogProbability() {
        return calculateSampleProbability ? network.getLogOfMasterP() : LOG_PROB_OF_PRIOR;
    }
}
