/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.network;

import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.vertices.ProxyVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.VertexLabel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class ModelComposition {
    private ModelComposition() {
    }

    /**
     * Embeds {@code bayesianNetwork} as a sub-model: the requested output vertices are exposed to
     * the caller and the named proxy input vertices are wired to caller-supplied parents.
     *
     * <p>The network is mutated in place:
     * <ul>
     *   <li>its indentation is incremented;</li>
     *   <li>every non-output vertex id receives one freshly created {@link VertexId} prefix;</li>
     *   <li>every output vertex id is reset;</li>
     *   <li>each proxy vertex named in {@code inputVertices} gets its parent set;</li>
     *   <li>the labels of the output vertices are cleared.</li>
     * </ul>
     *
     * <p>All output and input labels are resolved and validated before any of the mutations above
     * are performed. An {@link IllegalArgumentException} therefore does not leave the network
     * half-rewritten.
     *
     * @param bayesianNetwork the network to embed; modified in place
     * @param inputVertices   maps the label of a {@link ProxyVertex} in the network to the vertex
     *                        that should become that proxy's parent
     * @param desiredOutputs  labels of the vertices in the network to expose; must not be empty
     * @return a new map from each requested output label to its vertex. The keys are the original
     *         labels even though the vertices' own labels have been cleared.
     * @throws IllegalArgumentException if {@code desiredOutputs} is empty, an output or input
     *                                  label is not found, an input label does not name a
     *                                  {@link ProxyVertex}, or that proxy already has a parent
     */
    public static Map<VertexLabel, Vertex> composeModel(BayesianNetwork bayesianNetwork, Map<VertexLabel, Vertex> inputVertices, List<VertexLabel> desiredOutputs) {
        Map<VertexLabel, Vertex> outputMap = extractOutputs(bayesianNetwork, desiredOutputs);
        // Validate inputs before touching the network. Otherwise a bad input label would throw
        // after ids had already been prefixed/reset and some proxies already linked.
        List<PendingLink> pendingLinks = resolveInputProxies(bayesianNetwork, inputVertices);
        increaseDepth(bayesianNetwork, outputMap);
        linkInputs(pendingLinks);
        cleanOutputLabels(outputMap);
        return outputMap;
    }

    /**
     * Looks up each requested output label in the network.
     *
     * @return a new map from label to the vertex found for it
     * @throws IllegalArgumentException if {@code desiredOutputs} is empty or a label is not found
     */
    private static Map<VertexLabel, Vertex> extractOutputs(BayesianNetwork bayesianNetwork, List<VertexLabel> desiredOutputs) {
        if (desiredOutputs.isEmpty()) {
            throw new IllegalArgumentException("At least one output must be specified");
        }
        Map<VertexLabel, Vertex> outputMap = new HashMap<>();
        for (VertexLabel label : desiredOutputs) {
            Vertex v = bayesianNetwork.getVertexByLabel(label);
            if (v == null) {
                throw new IllegalArgumentException("Unable to find Output Vertex: " + label);
            }
            outputMap.put(label, v);
        }
        return outputMap;
    }

    /**
     * Pushes the network one level deeper. Indentation is incremented, non-output vertex ids get a
     * shared new prefix, and output vertex ids are reset. Output vertices are recognised by label.
     */
    private static void increaseDepth(BayesianNetwork bayesianNetwork, Map<VertexLabel, Vertex> outputVertices) {
        VertexId newPrefix = new VertexId();
        bayesianNetwork.incrementIndentation();
        // Two passes kept deliberately: all prefixing happens before any id reset, as in the
        // original ordering.
        bayesianNetwork.getVertices().stream()
            .filter(v -> !outputVertices.containsKey(v.getLabel()))
            .forEach(v -> v.getId().addPrefix(newPrefix));
        bayesianNetwork.getVertices().stream()
            .filter(v -> outputVertices.containsKey(v.getLabel()))
            .forEach(v -> v.getId().resetID());
    }

    /**
     * Resolves every entry of {@code inputs} to the proxy vertex its label names, without
     * mutating anything.
     *
     * @return the validated proxy/parent pairs, in the iteration order of {@code inputs}
     * @throws IllegalArgumentException if a label is not found, does not name a
     *                                  {@link ProxyVertex}, or names a proxy that already has a
     *                                  parent
     */
    private static List<PendingLink> resolveInputProxies(BayesianNetwork bayesianNetwork, Map<VertexLabel, Vertex> inputs) {
        List<PendingLink> links = new ArrayList<>(inputs.size());
        for (Map.Entry<VertexLabel, Vertex> entry : inputs.entrySet()) {
            Vertex v = bayesianNetwork.getVertexByLabel(entry.getKey());
            if (v == null) {
                throw new IllegalArgumentException("No node labelled \"" + entry.getKey() + "\" found");
            }
            if (!(v instanceof ProxyVertex)) {
                throw new IllegalArgumentException("Input node \"" + entry.getKey() + "\" is not a Proxy Vertex");
            }
            ProxyVertex proxyVertex = (ProxyVertex) v;
            if (proxyVertex.hasParent()) {
                throw new IllegalArgumentException("Proxy Vertex for \"" + v.getLabel() + "\" already has Parent");
            }
            links.add(new PendingLink(proxyVertex, entry.getValue()));
        }
        return links;
    }

    /** Applies previously validated links by setting each proxy's parent, in list order. */
    private static void linkInputs(List<PendingLink> links) {
        for (PendingLink link : links) {
            link.proxy.setParent(link.parent);
        }
    }

    /** Clears the label of every output vertex; the map's keys are left untouched. */
    private static void cleanOutputLabels(Map<VertexLabel, Vertex> outputMap) {
        outputMap.values().forEach(v -> v.setLabel((VertexLabel) null));
    }

    /** A validated input proxy paired with the caller-supplied vertex that becomes its parent. */
    private static final class PendingLink {
        private final ProxyVertex proxy;
        private final Vertex parent;

        private PendingLink(ProxyVertex proxy, Vertex parent) {
            this.proxy = proxy;
            this.parent = parent;
        }
    }
}

