package io.improbable.keanu.network;

import io.improbable.keanu.vertices.ProxyVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.VertexLabel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Static helpers for embedding one {@link BayesianNetwork} inside another: the inner model's
 * input {@link ProxyVertex proxy vertices} are wired to vertices supplied by the caller, and a
 * chosen set of labelled inner vertices is handed back as the model's outputs.
 */
public final class ModelComposition {

    private ModelComposition() {
        // Static utility class; not instantiable.
    }

    /**
     * Composes {@code innerModel} into an enclosing model.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>looks up every label in {@code desiredOutputs} in {@code innerModel};</li>
     *   <li>calls {@code incrementIndentation()} on the network, adds one freshly created
     *       {@link VertexId} as a prefix to the id of every non-output vertex, and then calls
     *       {@code resetID()} on the id of every output vertex;</li>
     *   <li>for each entry of {@code inputs}, sets the entry's vertex as the parent of the inner
     *       {@link ProxyVertex} carrying the entry's label;</li>
     *   <li>sets the label of every output vertex to {@code null}.</li>
     * </ol>
     *
     * <p>NOTE(review): {@code inputs} are only validated after step 2 has already mutated the
     * network, and linking stops at the first invalid input. A thrown exception can therefore leave
     * {@code innerModel} partially modified.
     *
     * @param innerModel     the network being embedded; mutated in place
     * @param inputs         outer vertices, keyed by the label of the inner proxy vertex each one
     *                       becomes the parent of
     * @param desiredOutputs labels of the inner vertices to expose; must be non-empty
     * @return the output vertices keyed by their original labels (the vertices' own labels are
     *     cleared)
     * @throws IllegalArgumentException if {@code desiredOutputs} is empty, an output or input label
     *     is not found in {@code innerModel}, an input label does not name a {@link ProxyVertex},
     *     or that proxy vertex already has a parent
     */
    public static Map<VertexLabel, Vertex> composeModel(BayesianNetwork innerModel,
                                                        Map<VertexLabel, Vertex> inputs,
                                                        List<VertexLabel> desiredOutputs) {
        Map<VertexLabel, Vertex> outputs = extractOutputs(innerModel, desiredOutputs);
        increaseDepth(innerModel, outputs);
        checkAndLinkInputs(innerModel, inputs);
        cleanOutputLabels(outputs);
        return outputs;
    }

    /**
     * Resolves each requested output label to its vertex in {@code network}.
     *
     * @throws IllegalArgumentException if {@code outputLabels} is empty or any label is not found
     */
    private static Map<VertexLabel, Vertex> extractOutputs(BayesianNetwork network,
                                                           List<VertexLabel> outputLabels) {
        if (outputLabels.isEmpty()) {
            throw new IllegalArgumentException("At least one output must be specified");
        }
        Map<VertexLabel, Vertex> outputs = new HashMap<>();
        for (VertexLabel label : outputLabels) {
            Vertex vertex = network.getVertexByLabel(label);
            if (vertex == null) {
                throw new IllegalArgumentException("Unable to find Output Vertex: " + label);
            }
            outputs.put(label, vertex);
        }
        return outputs;
    }

    /**
     * Increments the network's indentation. Then it prefixes the id of every vertex that is not an
     * output with one shared, freshly created {@link VertexId}, and resets the id of every output
     * vertex. A vertex counts as an output when its label is a key of {@code outputs}.
     */
    private static void increaseDepth(BayesianNetwork network, Map<VertexLabel, Vertex> outputs) {
        VertexId prefix = new VertexId();
        network.incrementIndentation();
        // Kept as two separate passes so every addPrefix happens before any resetID,
        // exactly as in the original ordering.
        for (Vertex vertex : network.getVertices()) {
            if (!outputs.containsKey(vertex.getLabel())) {
                vertex.getId().addPrefix(prefix);
            }
        }
        for (Vertex vertex : network.getVertices()) {
            if (outputs.containsKey(vertex.getLabel())) {
                vertex.getId().resetID();
            }
        }
    }

    /**
     * For each input, finds the inner vertex with the same label. It checks that the vertex is an
     * unparented {@link ProxyVertex} and then sets the input vertex as its parent.
     *
     * @throws IllegalArgumentException if a label is not found, does not name a proxy vertex, or the
     *     proxy vertex already has a parent
     */
    private static void checkAndLinkInputs(BayesianNetwork network, Map<VertexLabel, Vertex> inputs) {
        for (Map.Entry<VertexLabel, Vertex> input : inputs.entrySet()) {
            VertexLabel label = input.getKey();
            Vertex target = network.getVertexByLabel(label);
            // A missing label is reported as null (reference comparison; not against 0).
            if (target == null) {
                throw new IllegalArgumentException("No node labelled \"" + label + "\" found");
            }
            if (!(target instanceof ProxyVertex)) {
                throw new IllegalArgumentException("Input node \"" + label + "\" is not a Proxy Vertex");
            }
            ProxyVertex proxy = (ProxyVertex) target;
            if (proxy.hasParent()) {
                throw new IllegalArgumentException("Proxy Vertex for \"" + target.getLabel() + "\" already has Parent");
            }
            proxy.setParent(input.getValue());
        }
    }

    /** Sets the label of every output vertex to {@code null}. */
    private static void cleanOutputLabels(Map<VertexLabel, Vertex> outputs) {
        // NOTE(review): the explicit cast is kept from the original; presumably it selects a
        // specific setLabel overload. Confirm before removing it.
        outputs.values().forEach(vertex -> vertex.setLabel((VertexLabel) null));
    }
}
