/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.sampling;

import com.google.common.collect.Sets;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.algorithms.mcmc.NetworkSamplesGenerator;
import io.improbable.keanu.algorithms.mcmc.SamplingAlgorithm;
import io.improbable.keanu.algorithms.sampling.ForwardSampler;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.LambdaSection;
import io.improbable.keanu.network.TransitiveClosure;
import io.improbable.keanu.util.status.StatusBar;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.nd4j.base.Preconditions;

/**
 * A {@link PosteriorSamplingAlgorithm} that builds a {@link ForwardSampler} over the part of a
 * Bayesian network that is both downstream of the network's latent vertices and upstream of
 * the vertices being sampled.
 *
 * <p>Instances are created either through the public constructor or through {@link #builder()}.
 */
public class Forward
implements PosteriorSamplingAlgorithm {
    private final KeanuRandom random;
    private final boolean calculateSampleProbability;

    /**
     * @return a new builder whose defaults are {@link KeanuRandom#getDefaultRandom()} and
     *         {@code calculateSampleProbability = false}
     */
    public static ForwardBuilder builder() {
        return new ForwardBuilder();
    }

    /**
     * @param random                     source of randomness handed to the {@link ForwardSampler}
     * @param calculateSampleProbability flag passed through unchanged to the {@link ForwardSampler}
     */
    public Forward(KeanuRandom random, boolean calculateSampleProbability) {
        this.random = random;
        this.calculateSampleProbability = calculateSampleProbability;
    }

    /**
     * Generates {@code sampleCount} samples by delegating to
     * {@link #generatePosteriorSamples(ProbabilisticModel, List)}.
     *
     * @param model                 the model to sample; must contain at least one latent variable
     * @param variablesToSampleFrom the variables to sample; non-empty, all of type {@link Vertex}
     * @param sampleCount           number of samples requested from the generator
     * @return the generated samples
     * @throws IllegalArgumentException if the arguments fail any of the validation checks
     */
    @Override
    public NetworkSamples getPosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom, int sampleCount) {
        return this.generatePosteriorSamples(model, variablesToSampleFrom).generate(sampleCount);
    }

    /**
     * Validates the arguments, builds the sampler and wraps it in a {@link NetworkSamplesGenerator}
     * that uses a fresh {@link StatusBar}.
     *
     * @param model                 the model to sample; must contain at least one latent variable
     * @param variablesToSampleFrom the variables to sample; non-empty, all of type {@link Vertex}
     * @return a generator backed by a {@link ForwardSampler}
     * @throws IllegalArgumentException if the arguments fail any of the validation checks
     */
    @Override
    public NetworkSamplesGenerator generatePosteriorSamples(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom) {
        return new NetworkSamplesGenerator(this.setupSampler(model, variablesToSampleFrom), StatusBar::new);
    }

    /**
     * Validates the inputs and assembles the {@link ForwardSampler}.
     *
     * @throws IllegalArgumentException if the model has no latent variables, if
     *                                  {@code variablesToSampleFrom} is empty or contains a
     *                                  non-{@link Vertex}, if the vertices do not share one
     *                                  connected graph, or if the observed-vertex check fails
     */
    private SamplingAlgorithm setupSampler(ProbabilisticModel model, List<? extends Variable> variablesToSampleFrom) {
        Preconditions.checkArgument(!model.getLatentVariables().isEmpty(), "Your model must contain latent variables in order to forward sample.");
        // Fail with a descriptive message instead of an IndexOutOfBoundsException further down,
        // where the first variable is used to discover the connected graph.
        Preconditions.checkArgument(!variablesToSampleFrom.isEmpty(), "You must provide at least one variable to sample from in order to forward sample.");

        List<Vertex> verticesToSampleFrom = new ArrayList<>(variablesToSampleFrom.size());
        for (Variable variable : variablesToSampleFrom) {
            Preconditions.checkArgument(variable instanceof Vertex, "The Forward Sampler only works for Variables of type Vertex. Received : " + variable);
            verticesToSampleFrom.add((Vertex) variable);
        }

        BayesianNetwork network = this.checkSampleFromVariablesComeFromConnectedGraph(verticesToSampleFrom);
        this.checkUpstreamOfObservedDoesNotContainProbabilistic(network.getObservedVertices());

        // Only vertices that are both downstream of a latent vertex and upstream of a requested
        // vertex are handed to the sampler, in topological order.
        Set<Vertex> downstreamOfLatents = this.allDownstreamVertices(network.getLatentVertices());
        Set<Vertex> upstreamOfSampled = TransitiveClosure.getUpstreamVerticesForCollection(verticesToSampleFrom, true).getAllVertices();
        Set<Vertex> verticesToProcess = Sets.intersection(downstreamOfLatents, upstreamOfSampled);
        List<Vertex> sortedVertices = TopologicalSort.sort(verticesToProcess);

        return new ForwardSampler(network, verticesToSampleFrom, sortedVertices, this.random, this.calculateSampleProbability);
    }

    /**
     * Checks that every vertex belongs to the connected graph of the first vertex and wraps that
     * graph in a {@link BayesianNetwork}.
     *
     * @param verticesToSampleFrom the already type-checked vertices; must not be empty
     * @return a network built from the connected graph of the first vertex
     * @throws IllegalArgumentException if any vertex is not part of that connected graph
     */
    private BayesianNetwork checkSampleFromVariablesComeFromConnectedGraph(List<Vertex> verticesToSampleFrom) {
        Set<Vertex> connectedGraph = verticesToSampleFrom.get(0).getConnectedGraph();
        for (Vertex vertex : verticesToSampleFrom) {
            if (!connectedGraph.contains(vertex)) {
                throw new IllegalArgumentException("Sample from vertices must be part of the same connected graph.");
            }
        }
        return new BayesianNetwork(connectedGraph);
    }

    /**
     * @param randomVertices the vertices whose downstream lambda section is wanted
     * @return all vertices of the downstream lambda section computed for {@code randomVertices}
     */
    private Set<Vertex> allDownstreamVertices(List<Vertex> randomVertices) {
        return LambdaSection.getDownstreamLambdaSectionForCollection(randomVertices, true).getAllVertices();
    }

    /**
     * Rejects networks in which the upstream lambda section of the observed vertices holds more
     * than one vertex.
     *
     * @param observedVertices the observed vertices of the network
     * @throws IllegalArgumentException if the upstream lambda section has more than one vertex
     */
    private void checkUpstreamOfObservedDoesNotContainProbabilistic(List<Vertex> observedVertices) {
        LambdaSection upstreamLambdaSection = LambdaSection.getUpstreamLambdaSectionForCollection(observedVertices, false);
        Set<Vertex> upstreamRandomVariables = upstreamLambdaSection.getAllVertices();
        // NOTE(review): a section of exactly one vertex is tolerated; confirm against the
        // LambdaSection semantics that this threshold is intended when several vertices are observed.
        if (upstreamRandomVariables.size() > 1) {
            throw new IllegalArgumentException("Forward sampler cannot be ran if observed variables have a random variable in their upstream lambda section");
        }
    }

    /**
     * Fluent builder for {@link Forward}. Obtain an instance through {@link Forward#builder()}.
     */
    public static class ForwardBuilder {
        private KeanuRandom random = KeanuRandom.getDefaultRandom();
        private boolean calculateSampleProbability = false;

        ForwardBuilder() {
        }

        /**
         * @param random the source of randomness for the built {@link Forward}
         * @return this builder
         */
        public ForwardBuilder random(KeanuRandom random) {
            this.random = random;
            return this;
        }

        /**
         * @param calculateSampleProbability the flag value for the built {@link Forward}
         * @return this builder
         */
        public ForwardBuilder calculateSampleProbability(boolean calculateSampleProbability) {
            this.calculateSampleProbability = calculateSampleProbability;
            return this;
        }

        /**
         * @return a new {@link Forward} configured with this builder's current values
         */
        public Forward build() {
            return new Forward(this.random, this.calculateSampleProbability);
        }

        @Override
        public String toString() {
            return "ForwardBuilder(random=" + this.random + ", calculateSampleProbability=" + this.calculateSampleProbability + ")";
        }
    }
}

