package io.improbable.keanu.network;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.improbable.keanu.algorithms.ProbabilisticModel;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/improbable/keanu/network/KeanuProbabilisticModel.class */
/**
 * A {@link ProbabilisticModel} backed by the latent and observed vertices of a {@link BayesianNetwork}.
 *
 * <p>Variables are addressed by their {@link VariableReference} (the vertex id). Methods that accept a
 * map of values write those values into the corresponding vertices and cascade the update through the
 * graph before evaluating log probabilities, so the underlying vertices are mutated by those calls.
 */
public class KeanuProbabilisticModel implements ProbabilisticModel {
    private final Map<VariableReference, Vertex> vertexLookup;
    private final List<Vertex> latentVertices;
    private final List<Vertex> observedVertices;
    private final List<Vertex> latentOrObservedVertices;
    private final LambdaSectionSnapshot lambdaSectionSnapshot;

    /**
     * Builds a model from a collection of vertices by first wrapping them in a {@link BayesianNetwork}.
     *
     * @param collection the vertices that make up the network
     * @throws IllegalArgumentException if the resulting network has no latent or observed vertices
     */
    public KeanuProbabilisticModel(Collection<? extends Vertex> collection) {
        this(new BayesianNetwork(collection));
    }

    /**
     * Builds a model from an existing network.
     *
     * <p>Takes immutable snapshots of the network's latent, observed and latent-or-observed vertex lists,
     * indexes the latent-or-observed vertices by id, and cascades the observed vertices' values through
     * the graph.
     *
     * @param bayesianNetwork the network to model
     * @throws IllegalArgumentException if the network has no latent or observed vertices
     * @throws IllegalStateException if two latent-or-observed vertices share an id (from
     *     {@link Collectors#toMap})
     */
    public KeanuProbabilisticModel(BayesianNetwork bayesianNetwork) {
        this.vertexLookup = bayesianNetwork.getLatentOrObservedVertices().stream()
            .collect(Collectors.toMap(Vertex::getId, vertex -> vertex));
        this.latentVertices = ImmutableList.copyOf(bayesianNetwork.getLatentVertices());
        this.observedVertices = ImmutableList.copyOf(bayesianNetwork.getObservedVertices());
        this.latentOrObservedVertices = ImmutableList.copyOf(bayesianNetwork.getLatentOrObservedVertices());
        this.lambdaSectionSnapshot = new LambdaSectionSnapshot();
        resetModelToObservedState();
        checkBayesNetInHealthyState();
    }

    /**
     * Applies the given values (if any) and returns the log probability of all latent and observed vertices.
     *
     * @param map values to assign, keyed by variable reference; may be empty to evaluate the current state
     * @return the log probability computed by {@link ProbabilityCalculator#calculateLogProbFor}
     * @throws IllegalArgumentException if a key does not refer to a latent or observed vertex of this model
     */
    @Override
    public double logProb(Map<VariableReference, ?> map) {
        if (!map.isEmpty()) {
            cascadeValues(map);
        }
        return ProbabilityCalculator.calculateLogProbFor(this.latentOrObservedVertices);
    }

    /**
     * Applies the given values and returns {@code d} adjusted by the change in log probability of the
     * lambda section around the changed vertices.
     *
     * <p>The lambda-section log prob is evaluated for the vertices named in {@code map} before and after
     * cascading the new values; the difference is added to {@code d}. Presumably {@code d} is the model's
     * log probability before the change — TODO confirm against callers.
     *
     * @param map values to assign, keyed by variable reference
     * @param d the log probability to adjust
     * @return {@code d + (logProbAfterChange - logProbBeforeChange)} for the affected lambda section
     * @throws IllegalArgumentException if a key does not refer to a latent or observed vertex of this model
     */
    @Override
    public double logProbAfter(Map<VariableReference, Object> map, double d) {
        // Resolve every key up front: unknown keys now fail with the same IllegalArgumentException as
        // cascadeValues instead of a NullPointerException from ImmutableSet.Builder.add(null).
        ImmutableSet.Builder<Vertex> changedBuilder = ImmutableSet.builder();
        for (VariableReference reference : map.keySet()) {
            changedBuilder.add(lookupVertex(reference));
        }
        Set<? extends Variable> changedVertices = changedBuilder.build();

        double logProbBefore = this.lambdaSectionSnapshot.logProb(changedVertices);
        cascadeValues(map);
        double logProbAfterChange = this.lambdaSectionSnapshot.logProb(changedVertices);
        return d + (logProbAfterChange - logProbBefore);
    }

    /**
     * Applies the given values (if any) and returns the log probability of the observed vertices only.
     *
     * @param map values to assign, keyed by variable reference; may be empty to evaluate the current state
     * @return the log likelihood computed by {@link ProbabilityCalculator#calculateLogProbFor}
     * @throws IllegalArgumentException if a key does not refer to a latent or observed vertex of this model
     */
    @Override
    public double logLikelihood(Map<VariableReference, ?> map) {
        if (!map.isEmpty()) {
            cascadeValues(map);
        }
        return ProbabilityCalculator.calculateLogProbFor(this.observedVertices);
    }

    /**
     * Returns the latent vertices as variables.
     *
     * @return an immutable list of the latent vertices; {@code ImmutableList.copyOf} hands back the
     *     existing immutable snapshot rather than copying it
     */
    @Override
    public List<Variable> getLatentVariables() {
        return ImmutableList.copyOf(this.latentVertices);
    }

    /**
     * Returns the latent vertices.
     *
     * @return an immutable snapshot of the network's latent vertices
     */
    public List<Vertex> getLatentVertices() {
        return this.latentVertices;
    }

    /**
     * Returns all latent and observed vertices.
     *
     * @return an immutable snapshot of the network's latent-or-observed vertices
     */
    public List<Vertex> getLatentOrObservedVertices() {
        return this.latentOrObservedVertices;
    }

    /**
     * Returns the latent variables whose current value is a {@link DoubleTensor}.
     *
     * <p>The check is on the current value, so the result reflects the state at the time of the call.
     *
     * @return a new list of the continuous latent variables
     */
    @Override
    @SuppressWarnings("unchecked") // safe: each element was just checked to hold a DoubleTensor value
    public List<Variable<DoubleTensor, ?>> getContinuousLatentVariables() {
        return getLatentVariables().stream()
            .filter(variable -> variable.getValue() instanceof DoubleTensor)
            .map(variable -> (Variable<DoubleTensor, ?>) variable)
            .collect(Collectors.toList());
    }

    /**
     * Rejects networks with nothing to infer over.
     *
     * @throws IllegalArgumentException if there are no latent or observed vertices
     */
    private void checkBayesNetInHealthyState() {
        if (this.latentOrObservedVertices.isEmpty()) {
            throw new IllegalArgumentException("Cannot create model without latent or observed variables");
        }
    }

    /** Propagates the observed vertices' values through the graph. */
    private void resetModelToObservedState() {
        VertexValuePropagation.cascadeUpdate(this.observedVertices);
    }

    /**
     * Finds the latent-or-observed vertex with the given reference.
     *
     * @param reference the variable reference to resolve
     * @return the matching vertex, never {@code null}
     * @throws IllegalArgumentException if no latent or observed vertex has that reference
     */
    private Vertex lookupVertex(VariableReference reference) {
        Vertex vertex = this.vertexLookup.get(reference);
        if (vertex == null) {
            throw new IllegalArgumentException("Cannot cascade update for input: " + reference);
        }
        return vertex;
    }

    /**
     * Writes each value into its vertex, then cascades the update from all written vertices at once.
     *
     * <p>NOTE: the decompiler reports this method was originally {@code protected}; it is kept
     * {@code public} so existing callers continue to compile.
     *
     * @param map values to assign, keyed by variable reference
     * @throws IllegalArgumentException if a key does not refer to a latent or observed vertex of this model
     */
    public void cascadeValues(Map<VariableReference, ?> map) {
        List<Vertex> updatedVertices = new ArrayList<>(map.size());
        for (Map.Entry<VariableReference, ?> entry : map.entrySet()) {
            Vertex vertex = lookupVertex(entry.getKey());
            vertex.setValue(entry.getValue());
            updatedVertices.add(vertex);
        }
        VertexValuePropagation.cascadeUpdate(updatedVertices);
    }
}
