package io.improbable.keanu.network;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexLabel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A network made up of a fixed set of {@link Vertex} instances that can be queried, resampled and saved.
 *
 * <p>The vertex set is captured (as an immutable list) when the network is constructed. Labelled vertices
 * whose indentation equals the network's indentation at construction time ({@value #TOP_LEVEL_INDENTATION})
 * can be looked up with {@link #getVertexByLabel(VertexLabel)}.
 */
public class BayesianNetwork {

    /** Indentation depth of a top-level vertex; also the network's initial indentation. */
    private static final int TOP_LEVEL_INDENTATION = 1;

    /** Immutable snapshot of the vertices supplied to the constructor. */
    private final List<? extends Vertex> vertices;

    /** Labelled vertices at the construction-time indentation, keyed by label. Never contains null values. */
    private final Map<VertexLabel, Vertex> vertexLabels;

    private int indentation;

    /**
     * Decides whether a vertex should be selected from its probabilistic / observed flags and its
     * indentation. Used by this class's private vertex-filtering helpers.
     */
    @FunctionalInterface
    public interface VertexFilter {
        boolean filter(boolean isProbabilistic, boolean isObserved, int indentation);
    }

    /**
     * Creates a network over the given vertices.
     *
     * @param vertices the vertices that make up the network
     * @throws IllegalArgumentException if {@code vertices} is empty, or if two vertices at the top-level
     *     indentation share the same label
     */
    public BayesianNetwork(Set<? extends Vertex> vertices) {
        this.indentation = TOP_LEVEL_INDENTATION;
        Preconditions.checkArgument(!vertices.isEmpty(), "A bayesian network must contain at least one vertex");
        this.vertices = ImmutableList.copyOf(vertices);
        this.vertexLabels = buildLabelMap(vertices);
    }

    /**
     * Creates a network over the distinct elements of {@code vertices}. Duplicates are dropped by copying
     * the collection into a {@link HashSet} before delegating to {@link #BayesianNetwork(Set)}.
     */
    public BayesianNetwork(Collection<? extends Vertex> vertices) {
        // A HashSet argument resolves to the more specific Set overload.
        this(new HashSet<Vertex>(vertices));
    }

    /**
     * Returns the top-level vertex registered under {@code vertexLabel}.
     *
     * @throws IllegalArgumentException if no vertex with that label was registered at construction time
     */
    public Vertex getVertexByLabel(VertexLabel vertexLabel) {
        Vertex vertex = vertexLabels.get(vertexLabel);
        // The label map never stores null values, so a null lookup means the label is absent.
        Preconditions.checkArgument(vertex != null, "Vertex with label %s was not found in BayesianNetwork.", vertexLabel);
        return vertex;
    }

    /** Returns all labelled vertices whose label is in the given namespace. */
    public List<Vertex> getVerticesInNamespace(String... namespace) {
        return selectVertices(vertex -> vertex.getLabel() != null && vertex.getLabel().isInNamespace(namespace));
    }

    /** Returns all labelled vertices whose unqualified (namespace-free) label name equals {@code name}. */
    public List<Vertex> getVerticesIgnoringNamespace(String name) {
        return selectVertices(vertex -> vertex.getLabel() != null && vertex.getLabel().getUnqualifiedName().equals(name));
    }

    /**
     * Indexes the labelled vertices whose indentation equals the network's current indentation.
     *
     * @throws IllegalArgumentException if two such vertices share a label
     */
    private Map<VertexLabel, Vertex> buildLabelMap(Set<? extends Vertex> candidates) {
        Map<VertexLabel, Vertex> labelMap = new HashMap<>();
        for (Vertex vertex : candidates) {
            VertexLabel label = vertex.getLabel();
            if (vertex.getIndentation() != this.indentation || label == null) {
                continue;
            }
            if (labelMap.containsKey(label)) {
                throw new IllegalArgumentException("Vertex Label Repeated: " + label);
            }
            labelMap.put(label, vertex);
        }
        return labelMap;
    }

    /** Returns the number of vertices in the network. */
    public int getVertexCount() {
        return getVertices().size();
    }

    /** Returns the mean of {@code getDegree()} over all vertices in the network. */
    public double getAverageVertexDegree() {
        // The constructor rejects empty networks, so the average is always present.
        return getVertices().stream().mapToDouble(Vertex::getDegree).average().getAsDouble();
    }

    /**
     * Copies every value in {@code state} onto the vertex whose id equals the value's variable reference.
     * References that match no vertex in this network are ignored.
     */
    public void setState(NetworkState state) {
        for (VariableReference reference : state.getVariableReferences()) {
            for (Vertex vertex : vertices) {
                // Value equality rather than identity, so equal ids from another source still match.
                if (vertex.getId().equals(reference)) {
                    vertex.setValue(state.get(reference));
                }
            }
        }
    }

    /** Returns an unmodifiable view of every vertex in the network. */
    public List<Vertex> getAllVertices() {
        return Collections.unmodifiableList(this.vertices);
    }

    /** Returns the (immutable) list of vertices in the network. */
    public List<? extends Vertex> getVertices() {
        return this.vertices;
    }

    /** Returns, in network order, the vertices accepted by {@code predicate}. */
    private List<Vertex> selectVertices(Predicate<Vertex> predicate) {
        return vertices.stream().filter(predicate).collect(Collectors.<Vertex>toList());
    }

    /** Returns, in network order, the vertices whose flags and indentation are accepted by {@code vertexFilter}. */
    private List<Vertex> getFilteredVertexList(VertexFilter vertexFilter) {
        return selectVertices(vertex ->
            vertexFilter.filter(vertex.isProbabilistic(), vertex.isObserved(), vertex.getIndentation()));
    }

    /** Returns every vertex that is probabilistic or observed, at any indentation. */
    public List<Vertex> getLatentOrObservedVertices() {
        return getLatentOrObservedVertices(Integer.MAX_VALUE);
    }

    /** Returns the probabilistic or observed vertices at the top-level indentation or shallower. */
    public List<Vertex> getTopLevelLatentOrObservedVertices() {
        return getLatentOrObservedVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getLatentOrObservedVertices(int maxIndentation) {
        return getFilteredVertexList((probabilistic, observed, vertexIndentation) ->
            (probabilistic || observed) && maxIndentation >= vertexIndentation);
    }

    /** Returns every probabilistic, unobserved vertex, at any indentation. */
    public List<Vertex> getLatentVertices() {
        return getLatentVertices(Integer.MAX_VALUE);
    }

    /** Returns the probabilistic, unobserved vertices at the top-level indentation or shallower. */
    public List<Vertex> getTopLevelLatentVertices() {
        return getLatentVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getLatentVertices(int maxIndentation) {
        return getFilteredVertexList((probabilistic, observed, vertexIndentation) ->
            probabilistic && !observed && maxIndentation >= vertexIndentation);
    }

    /** Returns every observed vertex, at any indentation. */
    public List<Vertex> getObservedVertices() {
        return getObservedVertices(Integer.MAX_VALUE);
    }

    /** Returns the observed vertices at the top-level indentation or shallower. */
    public List<Vertex> getTopLevelObservedVertices() {
        return getObservedVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getObservedVertices(int maxIndentation) {
        return getFilteredVertexList((probabilistic, observed, vertexIndentation) ->
            observed && maxIndentation >= vertexIndentation);
    }

    /** Returns the log probability computed by {@link ProbabilityCalculator} over all latent and observed vertices. */
    public double getLogOfMasterP() {
        return ProbabilityCalculator.calculateLogProbFor(getLatentOrObservedVertices());
    }

    /** Propagates the values of all observed vertices through the network. */
    public void cascadeObservations() {
        VertexValuePropagation.cascadeUpdate(getObservedVertices());
    }

    /** Same as {@link #probeForNonZeroProbability(int, KeanuRandom)} using the default random source. */
    public void probeForNonZeroProbability(int attempts) {
        probeForNonZeroProbability(attempts, KeanuRandom.getDefaultRandom());
    }

    /**
     * If the network is in an impossible state, resamples all latent vertices (in topological order) and
     * cascades their values until a non-zero probability state is found. One initial draw is made,
     * followed by at most {@code attempts + 1} further draws.
     *
     * @throws IllegalStateException if the network is still in an impossible state after the last draw
     */
    public void probeForNonZeroProbability(int attempts, KeanuRandom random) {
        if (isInImpossibleState()) {
            List<Vertex> sortedLatents = TopologicalSort.sort(getLatentVertices());
            setFromSampleAndCascade(sortedLatents, random);
            probeForNonZeroProbability(sortedLatents, attempts, random);
        }
    }

    private void probeForNonZeroProbability(List<? extends Vertex> latentVertices, int attempts, KeanuRandom random) {
        // The state is re-checked after every draw, so a successful final draw is never reported as a failure.
        for (int attempt = 0; isInImpossibleState(); attempt++) {
            if (attempt > attempts) {
                throw new IllegalStateException("Failed to find non-zero probability state");
            }
            setFromSampleAndCascade(latentVertices, random);
        }
    }

    /** Returns true when {@link ProbabilityCalculator} reports the master log probability as impossible. */
    public boolean isInImpossibleState() {
        return ProbabilityCalculator.isImpossibleLogProb(getLogOfMasterP());
    }

    /** Same as {@link #setFromSampleAndCascade(List, KeanuRandom)} using the default random source. */
    public static void setFromSampleAndCascade(List<? extends Vertex> toSample) {
        setFromSampleAndCascade(toSample, KeanuRandom.getDefaultRandom());
    }

    /**
     * Sets each vertex's value from its own sample, in list order, then cascades the new values.
     *
     * @throws IllegalArgumentException if any vertex is not {@link Probabilistic}; in that case no vertex
     *     has been modified
     */
    public static void setFromSampleAndCascade(List<? extends Vertex> toSample, KeanuRandom random) {
        // Validate up front so a bad entry cannot leave earlier vertices already resampled.
        for (Vertex vertex : toSample) {
            if (!(vertex instanceof Probabilistic)) {
                throw new IllegalArgumentException("Cannot sample from a non-probabilistic vertex. Vertex is: " + vertex);
            }
        }
        for (Vertex vertex : toSample) {
            setValueFromSample(vertex, random);
        }
        VertexValuePropagation.cascadeUpdate(toSample);
    }

    @SuppressWarnings("unchecked")
    private static <T> void setValueFromSample(Vertex<T> vertex, KeanuRandom random) {
        // The caller has checked `instanceof Probabilistic`; the sample is assumed to share the vertex's value type.
        T sample = (T) ((Probabilistic) vertex).sample(random);
        vertex.setValue(sample);
    }

    /** Returns the latent vertices whose current value is a {@link DoubleTensor}. */
    @SuppressWarnings("unchecked")
    public List<Vertex<DoubleTensor>> getContinuousLatentVertices() {
        return getLatentVertices().stream()
            .filter(vertex -> vertex.getValue() instanceof DoubleTensor)
            .map(vertex -> (Vertex<DoubleTensor>) vertex)
            .collect(Collectors.toList());
    }

    /** Returns the latent vertices whose current value is not a {@link DoubleTensor} (including null values). */
    public List<Vertex> getDiscreteLatentVertices() {
        return getLatentVertices().stream()
            .filter(vertex -> !(vertex.getValue() instanceof DoubleTensor))
            .collect(Collectors.toList());
    }

    /** Returns the network's current indentation (initially {@value #TOP_LEVEL_INDENTATION}). */
    public int getIndentation() {
        return this.indentation;
    }

    /**
     * Increases the network's indentation by one.
     *
     * <p>NOTE(review): the label map is built once in the constructor and is not rebuilt here, so
     * {@link #getVertexByLabel(VertexLabel)} keeps reflecting construction-time labels. Confirm this is intended.
     */
    public void incrementIndentation() {
        this.indentation++;
    }

    /**
     * Saves every vertex, in topological order, with the given saver.
     *
     * @throws IllegalArgumentException if any vertex is a {@link NonSaveableVertex}
     */
    public void save(NetworkSaver networkSaver) {
        if (!isSaveable()) {
            throw new IllegalArgumentException("Trying to save a BayesianNetwork that isn't Saveable");
        }
        for (Vertex vertex : TopologicalSort.sort(this.vertices)) {
            vertex.save(networkSaver);
        }
    }

    private boolean isSaveable() {
        return this.vertices.stream().noneMatch(vertex -> vertex instanceof NonSaveableVertex);
    }

    /** Saves the value of every vertex, in topological order, with the given saver. */
    public void saveValues(NetworkSaver networkSaver) {
        for (Vertex vertex : TopologicalSort.sort(this.vertices)) {
            vertex.saveValue(networkSaver);
        }
    }

    /**
     * Returns {@code vertex} together with every vertex reachable from it in at most {@code degrees}
     * breadth-first steps, where each step follows a parent or a child edge.
     *
     * @param vertex  the starting vertex; always included in the result
     * @param degrees the maximum number of edges to traverse; values {@code <= 0} yield only {@code vertex}
     */
    public Set<Vertex> getSubgraph(Vertex vertex, int degrees) {
        Set<Vertex> visited = new HashSet<>();
        visited.add(vertex);
        List<Vertex> frontier = new ArrayList<>();
        frontier.add(vertex);
        for (int depth = 0; depth < degrees && !frontier.isEmpty(); depth++) {
            List<Vertex> nextFrontier = new ArrayList<>();
            for (Vertex current : frontier) {
                addUnvisited(current.getParents(), visited, nextFrontier);
                addUnvisited(current.getChildren(), visited, nextFrontier);
            }
            frontier = nextFrontier;
        }
        return visited;
    }

    /** Adds each neighbour not yet in {@code visited} to both {@code visited} and {@code frontier}, in order. */
    private static void addUnvisited(Collection<?> neighbours, Set<Vertex> visited, List<Vertex> frontier) {
        for (Object neighbour : neighbours) {
            Vertex candidate = (Vertex) neighbour;
            if (visited.add(candidate)) {
                frontier.add(candidate);
            }
        }
    }
}
