/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.network;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.TopologicalSort;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.ProbabilityCalculator;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexLabel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A Bayesian network defined by a fixed set of {@link Vertex} instances.
 * <p>
 * The vertices are copied into an immutable list at construction time. Each vertex reports an
 * "indentation" level via {@code getIndentation()}. Only labelled vertices whose indentation equals
 * this network's indentation at construction time are indexed by label. The {@code getTopLevel...}
 * queries restrict their results to vertices with indentation {@code <= TOP_LEVEL_INDENTATION}.
 */
public class BayesianNetwork {
    private static final int TOP_LEVEL_INDENTATION = 1;

    /** Immutable snapshot of every vertex in the network. */
    private final List<? extends Vertex> vertices;
    /** Labelled vertices at the construction-time indentation, keyed by their label. */
    private final Map<VertexLabel, Vertex> vertexLabels;
    /** This network's indentation; starts at the top level and is only changed by {@link #incrementIndentation()}. */
    private int indentation = TOP_LEVEL_INDENTATION;

    /**
     * Creates a network over the given vertices.
     *
     * @param vertices the vertices of the network; must not be empty
     * @throws IllegalArgumentException if {@code vertices} is empty, or if two labelled vertices at the
     *                                  construction-time indentation share the same label
     */
    public BayesianNetwork(Set<? extends Vertex> vertices) {
        Preconditions.checkArgument(!vertices.isEmpty(), "A bayesian network must contain at least one vertex");
        this.vertices = ImmutableList.copyOf(vertices);
        this.vertexLabels = this.buildLabelMap(vertices);
    }

    /**
     * Creates a network over the given vertices, ignoring duplicates.
     *
     * @param vertices the vertices of the network; must not be empty
     * @throws IllegalArgumentException under the same conditions as {@link #BayesianNetwork(Set)}
     */
    public BayesianNetwork(Collection<? extends Vertex> vertices) {
        // A LinkedHashSet removes duplicates while keeping the caller's iteration order, so the network's
        // vertex order is deterministic instead of depending on hash codes (as it would with a HashSet).
        this(new LinkedHashSet<Vertex>(vertices));
    }

    /**
     * Looks up a vertex by its label.
     *
     * @param label the label to look for
     * @return the vertex carrying {@code label}
     * @throws IllegalArgumentException if no indexed vertex carries {@code label}
     */
    public Vertex getVertexByLabel(VertexLabel label) {
        Preconditions.checkArgument(this.vertexLabels.containsKey(label),
            "Vertex with label %s was not found in BayesianNetwork.", label);
        return this.vertexLabels.get(label);
    }

    /**
     * Returns every labelled vertex whose label is in the given namespace.
     *
     * @param namespace the namespace path, as understood by {@code VertexLabel.isInNamespace}
     * @return the matching vertices, in network order
     */
    public List<Vertex> getVerticesInNamespace(String ... namespace) {
        return this.vertices.stream()
            .filter(v -> v.getLabel() != null && v.getLabel().isInNamespace(namespace))
            .collect(Collectors.toList());
    }

    /**
     * Returns every labelled vertex whose unqualified label name equals {@code innerNamespace}.
     *
     * @param innerNamespace the unqualified name to match exactly
     * @return the matching vertices, in network order
     */
    public List<Vertex> getVerticesIgnoringNamespace(String innerNamespace) {
        return this.vertices.stream()
            .filter(v -> v.getLabel() != null && v.getLabel().getUnqualifiedName().equals(innerNamespace))
            .collect(Collectors.toList());
    }

    /**
     * Indexes labelled vertices whose indentation equals this network's current indentation.
     * <p>
     * This is called from the constructor, so the indentation it compares against is always
     * {@link #TOP_LEVEL_INDENTATION}.
     *
     * @throws IllegalArgumentException if two indexed vertices share a label
     */
    private Map<VertexLabel, Vertex> buildLabelMap(Set<? extends Vertex> vertices) {
        Map<VertexLabel, Vertex> labelMap = new HashMap<VertexLabel, Vertex>();
        for (Vertex vertex : vertices) {
            VertexLabel label = vertex.getLabel();
            if (label == null || vertex.getIndentation() != this.indentation) {
                continue;
            }
            if (labelMap.putIfAbsent(label, vertex) != null) {
                throw new IllegalArgumentException("Vertex Label Repeated: " + label);
            }
        }
        return labelMap;
    }

    /** @return the number of vertices in the network */
    public int getVertexCount() {
        return this.getVertices().size();
    }

    /**
     * @return the mean of {@code getDegree()} over all vertices (the network is never empty, so this
     *         always has a value)
     */
    public double getAverageVertexDegree() {
        return this.getVertices().stream().mapToDouble(Vertex::getDegree).average().getAsDouble();
    }

    /**
     * Sets the value of each vertex whose id matches a reference in {@code state} to the value
     * {@code state} holds for that reference.
     *
     * @param state the values to apply
     */
    public void setState(NetworkState state) {
        for (VariableReference reference : state.getVariableReferences()) {
            // Compare with equals rather than ==, so equal-but-distinct id objects still match.
            this.vertices.stream()
                .filter(v -> reference.equals(v.getId()))
                .forEach(v -> v.setValue(state.get(reference)));
        }
    }

    /** @return an unmodifiable view of every vertex in the network */
    public List<Vertex> getAllVertices() {
        return Collections.unmodifiableList(this.vertices);
    }

    /** @return the network's immutable vertex list */
    public List<? extends Vertex> getVertices() {
        return this.vertices;
    }

    /** Returns the vertices accepted by {@code filter}, in network order. */
    private List<Vertex> getFilteredVertexList(VertexFilter filter) {
        return this.vertices.stream()
            .filter(v -> filter.filter(v.isProbabilistic(), v.isObserved(), v.getIndentation()))
            .collect(Collectors.toList());
    }

    /** @return every vertex that is probabilistic or observed, at any indentation */
    public List<Vertex> getLatentOrObservedVertices() {
        return this.getLatentOrObservedVertices(Integer.MAX_VALUE);
    }

    /** @return every probabilistic-or-observed vertex with indentation {@code <= TOP_LEVEL_INDENTATION} */
    public List<Vertex> getTopLevelLatentOrObservedVertices() {
        return this.getLatentOrObservedVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getLatentOrObservedVertices(int maxIndentation) {
        return this.getFilteredVertexList((isProbabilistic, isObserved, indentation) ->
            (isProbabilistic || isObserved) && indentation <= maxIndentation);
    }

    /** @return every probabilistic, unobserved vertex, at any indentation */
    public List<Vertex> getLatentVertices() {
        return this.getLatentVertices(Integer.MAX_VALUE);
    }

    /** @return every probabilistic, unobserved vertex with indentation {@code <= TOP_LEVEL_INDENTATION} */
    public List<Vertex> getTopLevelLatentVertices() {
        return this.getLatentVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getLatentVertices(int maxIndentation) {
        return this.getFilteredVertexList((isProbabilistic, isObserved, indentation) ->
            isProbabilistic && !isObserved && indentation <= maxIndentation);
    }

    /** @return every observed vertex, at any indentation */
    public List<Vertex> getObservedVertices() {
        return this.getObservedVertices(Integer.MAX_VALUE);
    }

    /** @return every observed vertex with indentation {@code <= TOP_LEVEL_INDENTATION} */
    public List<Vertex> getTopLevelObservedVertices() {
        return this.getObservedVertices(TOP_LEVEL_INDENTATION);
    }

    private List<Vertex> getObservedVertices(int maxIndentation) {
        return this.getFilteredVertexList((isProbabilistic, isObserved, indentation) ->
            isObserved && indentation <= maxIndentation);
    }

    /**
     * @return the log probability computed by {@code ProbabilityCalculator.calculateLogProbFor} over all
     *         latent-or-observed vertices
     */
    public double getLogOfMasterP() {
        return ProbabilityCalculator.calculateLogProbFor(this.getLatentOrObservedVertices());
    }

    /** Propagates value updates from all observed vertices via {@code VertexValuePropagation.cascadeUpdate}. */
    public void cascadeObservations() {
        VertexValuePropagation.cascadeUpdate(this.getObservedVertices());
    }

    /**
     * Same as {@link #probeForNonZeroProbability(int, KeanuRandom)} using the default random source.
     */
    public void probeForNonZeroProbability(int attempts) {
        this.probeForNonZeroProbability(attempts, KeanuRandom.getDefaultRandom());
    }

    /**
     * If the network is currently in an impossible state, resamples all latent vertices (in topological
     * order) until it is not.
     *
     * @param attempts the number of extra resampling rounds allowed after the initial sample
     * @param random   the random source used for sampling
     * @throws IllegalStateException if the network is still in an impossible state after all attempts
     */
    public void probeForNonZeroProbability(int attempts, KeanuRandom random) {
        if (this.isInImpossibleState()) {
            List<Vertex> sortedByDependency = TopologicalSort.sort(this.getLatentVertices());
            BayesianNetwork.setFromSampleAndCascade(sortedByDependency, random);
            this.probeForNonZeroProbability(sortedByDependency, attempts, random);
        }
    }

    /**
     * Resamples {@code latentVertices} until the network's state is possible, allowing at most
     * {@code attempts + 1} resamples.
     */
    private void probeForNonZeroProbability(List<? extends Vertex> latentVertices, int attempts, KeanuRandom random) {
        int resamples = 0;
        while (this.isInImpossibleState()) {
            // Check the budget before sampling so that the result of the final sample is always
            // inspected by the loop condition rather than being discarded by an unconditional throw.
            if (resamples > attempts) {
                throw new IllegalStateException("Failed to find non-zero probability state");
            }
            BayesianNetwork.setFromSampleAndCascade(latentVertices, random);
            resamples++;
        }
    }

    /** @return whether {@link #getLogOfMasterP()} is judged impossible by {@code ProbabilityCalculator} */
    public boolean isInImpossibleState() {
        return ProbabilityCalculator.isImpossibleLogProb(this.getLogOfMasterP());
    }

    /**
     * Same as {@link #setFromSampleAndCascade(List, KeanuRandom)} using the default random source.
     */
    public static void setFromSampleAndCascade(List<? extends Vertex> vertices) {
        BayesianNetwork.setFromSampleAndCascade(vertices, KeanuRandom.getDefaultRandom());
    }

    /**
     * Sets each vertex's value from a fresh sample, in list order, then cascades the updates.
     *
     * @param vertices the vertices to sample; each must implement {@link Probabilistic}
     * @param random   the random source used for sampling
     * @throws IllegalArgumentException on the first non-probabilistic vertex (vertices before it will
     *                                  already have been resampled, and no cascade takes place)
     */
    public static void setFromSampleAndCascade(List<? extends Vertex> vertices, KeanuRandom random) {
        for (Vertex vertex : vertices) {
            if (!(vertex instanceof Probabilistic)) {
                throw new IllegalArgumentException("Cannot sample from a non-probabilistic vertex. Vertex is: " + vertex);
            }
            BayesianNetwork.setValueFromSample(vertex, random);
        }
        VertexValuePropagation.cascadeUpdate(vertices);
    }

    /** Samples {@code vertex} (known to be {@link Probabilistic}) and stores the sample as its value. */
    @SuppressWarnings("unchecked")
    private static <T> void setValueFromSample(Vertex<T> vertex, KeanuRandom random) {
        vertex.setValue(((Probabilistic<T>) vertex).sample(random));
    }

    /** @return latent vertices whose current value is a {@link DoubleTensor} */
    @SuppressWarnings("unchecked")
    public List<Vertex<DoubleTensor>> getContinuousLatentVertices() {
        return this.getLatentVertices().stream()
            .filter(v -> v.getValue() instanceof DoubleTensor)
            .map(v -> (Vertex<DoubleTensor>) v)
            .collect(Collectors.toList());
    }

    /** @return latent vertices whose current value is not a {@link DoubleTensor} (including null values) */
    public List<Vertex> getDiscreteLatentVertices() {
        return this.getLatentVertices().stream()
            .filter(v -> !(v.getValue() instanceof DoubleTensor))
            .collect(Collectors.toList());
    }

    /** @return this network's current indentation */
    public int getIndentation() {
        return this.indentation;
    }

    /** Increases this network's indentation by one. Does not rebuild the label index. */
    public void incrementIndentation() {
        ++this.indentation;
    }

    /**
     * Saves every vertex, in topological order, through {@code networkSaver}.
     *
     * @throws IllegalArgumentException if any vertex is a {@link NonSaveableVertex}
     */
    public void save(NetworkSaver networkSaver) {
        if (!this.isSaveable()) {
            throw new IllegalArgumentException("Trying to save a BayesianNetwork that isn't Saveable");
        }
        for (Vertex vertex : TopologicalSort.sort(this.vertices)) {
            vertex.save(networkSaver);
        }
    }

    /** @return whether no vertex is a {@link NonSaveableVertex} */
    private boolean isSaveable() {
        return this.vertices.stream().noneMatch(v -> v instanceof NonSaveableVertex);
    }

    /** Saves the value of every vertex, in topological order, through {@code networkSaver}. */
    public void saveValues(NetworkSaver networkSaver) {
        for (Vertex vertex : TopologicalSort.sort(this.vertices)) {
            vertex.saveValue(networkSaver);
        }
    }

    /**
     * Collects every vertex reachable from {@code vertex} within {@code degree} parent/child hops,
     * expanding breadth-first.
     *
     * @param vertex the starting vertex (always included in the result)
     * @param degree the maximum number of hops; values {@code <= 0} yield only {@code vertex}
     * @return the reached vertices
     */
    public Set<Vertex> getSubgraph(Vertex vertex, int degree) {
        Set<Vertex> subgraphVertices = new HashSet<Vertex>();
        subgraphVertices.add(vertex);
        List<Vertex> frontier = new ArrayList<Vertex>();
        frontier.add(vertex);
        for (int distance = 0; distance < degree && !frontier.isEmpty(); ++distance) {
            List<Vertex> nextFrontier = new ArrayList<Vertex>();
            for (Vertex<?> current : frontier) {
                addUnvisited(current.getParents(), subgraphVertices, nextFrontier);
                addUnvisited(current.getChildren(), subgraphVertices, nextFrontier);
            }
            frontier = nextFrontier;
        }
        return subgraphVertices;
    }

    /** Adds each candidate not already in {@code visited} to both {@code visited} and {@code nextFrontier}. */
    private static void addUnvisited(Collection<? extends Vertex> candidates, Set<Vertex> visited, List<Vertex> nextFrontier) {
        for (Vertex candidate : candidates) {
            if (visited.add(candidate)) {
                nextFrontier.add(candidate);
            }
        }
    }

    /** Predicate over a vertex's probabilistic flag, observed flag and indentation. */
    @FunctionalInterface
    private static interface VertexFilter {
        boolean filter(boolean isProbabilistic, boolean isObserved, int indentation);
    }
}

