/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.util.io;

import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.util.io.EdgeDotLabel;
import io.improbable.keanu.util.io.GraphEdge;
import io.improbable.keanu.util.io.VertexDotLabel;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * {@link NetworkSaver} that renders a set of vertices as a Graphviz DOT {@code digraph}.
 *
 * <p>A call to {@link #save(OutputStream, boolean, Map)} visits every vertex. Each vertex calls
 * back into one of the {@code save}/{@code saveValue} overloads below, which record a node label
 * and the vertex's incoming parent edges. The collected metadata, edges and labels are then
 * written out in that order.
 *
 * <p>Each save call replaces this instance's label and edge collections. An instance should
 * therefore not be used for concurrent saves.
 */
public class DotSaver implements NetworkSaver {
    private static final String DOT_HEADER = "digraph BayesianNetwork {\n";
    private static final String DOT_ENDING = "}";
    private static final String DOT_COMMENT_APPENDIX = "// ";

    /** Node labels collected during the current save; insertion-ordered so the output is stable. */
    private Set<VertexDotLabel> dotLabels = new LinkedHashSet<>();
    /** Parent-to-child edges collected during the current save; insertion-ordered. */
    private Set<GraphEdge> graphEdges = new LinkedHashSet<>();
    /** Vertices to render; edges with an endpoint outside this set are not written. */
    private final Set<Vertex> vertices;

    /**
     * Creates a saver for every vertex of {@code network}.
     *
     * @param network the network whose vertices (from {@code getAllVertices()}) are rendered
     */
    public DotSaver(BayesianNetwork network) {
        // LinkedHashSet keeps the network's vertex order, so repeated saves produce the same file.
        this(new LinkedHashSet<Vertex>(network.getAllVertices()));
    }

    /**
     * Creates a saver for an explicit set of vertices.
     *
     * @param vertices the vertices to render; the set is retained by reference, not copied
     */
    public DotSaver(Set<Vertex> vertices) {
        this.vertices = vertices;
    }

    /**
     * Writes the vertices as DOT without any metadata comment section.
     *
     * @param output     destination stream; it is closed when writing finishes
     * @param saveValues whether node labels should include scalar vertex values
     * @throws IOException if writing to {@code output} fails
     */
    @Override
    public void save(OutputStream output, boolean saveValues) throws IOException {
        save(output, saveValues, null);
    }

    /**
     * Writes the vertices as UTF-8 encoded DOT to {@code output} and closes it.
     *
     * @param output     destination stream; it is closed when writing finishes, including on failure
     * @param saveValues whether node labels should include scalar vertex values
     * @param metadata   optional key/value pairs emitted as DOT comments; may be {@code null}
     * @throws IOException if writing to {@code output} fails
     */
    @Override
    public void save(OutputStream output, boolean saveValues, Map<String, String> metadata) throws IOException {
        dotLabels = new LinkedHashSet<>();
        graphEdges = new LinkedHashSet<>();
        // Double dispatch: each vertex calls back into the matching save/saveValue overload.
        for (Vertex vertex : vertices) {
            if (saveValues) {
                vertex.saveValue(this);
            } else {
                vertex.save(this);
            }
        }
        // Explicit UTF-8 rather than the platform default charset; try-with-resources also closes
        // the underlying stream if a write fails part-way.
        try (Writer writer = new OutputStreamWriter(output, StandardCharsets.UTF_8)) {
            writer.write(DOT_HEADER);
            outputMetadata(metadata, writer);
            outputEdges(graphEdges, writer, vertices);
            outputLabels(dotLabels, writer);
            writer.write(DOT_ENDING);
        }
    }

    /** Writes each metadata entry as a {@code // key=value} line; writes nothing for null/empty input. */
    private static void outputMetadata(Map<String, String> metadata, Writer writer) throws IOException {
        if (metadata == null || metadata.isEmpty()) {
            return;
        }
        writer.write("// Model metadata:\n");
        for (Map.Entry<String, String> entry : metadata.entrySet()) {
            // Spelled out rather than relying on the Map.Entry implementation's toString().
            writer.write(DOT_COMMENT_APPENDIX + entry.getKey() + "=" + entry.getValue() + "\n");
        }
    }

    /** Writes one DOT node line per collected label. */
    private static void outputLabels(Collection<VertexDotLabel> labels, Writer writer) throws IOException {
        for (VertexDotLabel label : labels) {
            writer.write(label.inDotFormat() + "\n");
        }
    }

    /** Writes one DOT edge line per edge whose parent and child are both in {@code verticesToOutput}. */
    private static void outputEdges(Collection<GraphEdge> edges, Writer writer, Set<Vertex> verticesToOutput) throws IOException {
        for (GraphEdge edge : edges) {
            boolean bothEndsIncluded = verticesToOutput.contains(edge.getParentVertex())
                && verticesToOutput.contains(edge.getChildVertex());
            if (bothEndsIncluded) {
                writer.write(EdgeDotLabel.inDotFormat(edge) + "\n");
            }
        }
    }

    /** Records a value-less label for {@code vertex} plus its incoming parent edges. */
    @Override
    public void save(Vertex vertex) {
        dotLabels.add(new VertexDotLabel(vertex));
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Records a constant vertex the same way {@link #saveValue(Vertex)} does. */
    @Override
    public void save(ConstantVertex vertex) {
        saveValue((Vertex) vertex);
    }

    /**
     * Records a label for {@code vertex}, including its value when that value is a scalar
     * {@link Tensor}, plus its incoming parent edges.
     */
    @Override
    public void saveValue(Vertex vertex) {
        if (vertex.hasValue() && vertex.getValue() instanceof Tensor) {
            setDotLabelWithValue(vertex);
        } else {
            dotLabels.add(new VertexDotLabel(vertex));
        }
        // Parent edges are recorded here too, as in save(Vertex) and the typed saveValue overloads.
        // Without this, generic vertices lose their incoming edges in value mode.
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Records a label (with scalar value, if any) and the incoming parent edges. */
    @Override
    public void saveValue(DoubleVertex vertex) {
        setDotLabelWithValue(vertex);
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Records a label (with scalar value, if any) and the incoming parent edges. */
    @Override
    public void saveValue(IntegerVertex vertex) {
        setDotLabelWithValue(vertex);
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Records a label (with scalar value, if any) and the incoming parent edges. */
    @Override
    public void saveValue(BooleanVertex vertex) {
        setDotLabelWithValue(vertex);
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Adds a label for {@code vertex}; the value is attached only when the vertex holds a scalar tensor. */
    private void setDotLabelWithValue(Vertex<? extends Tensor> vertex) {
        VertexDotLabel label = new VertexDotLabel(vertex);
        if (vertex.hasValue() && vertex.getValue().isScalar()) {
            label.setValue(String.valueOf(vertex.getValue().scalar()));
        }
        dotLabels.add(label);
    }

    /**
     * Builds one parent-to-{@code vertex} edge per parent. Each edge is then labelled with the
     * {@link SaveVertexParam} names of the public no-arg getters on the vertex's class that return
     * that parent.
     *
     * @throws IllegalArgumentException if an annotated getter cannot be invoked
     * @throws IllegalStateException    if an annotated getter returns a vertex that is not a parent
     */
    private Set<GraphEdge> getParentEdges(Vertex vertex) {
        Set<GraphEdge> edges = new LinkedHashSet<>();
        // Iterate as Object: on a raw Vertex, getParents() has an erased (raw) element type.
        for (Object parent : vertex.getParents()) {
            edges.add(new GraphEdge((Vertex) parent, vertex));
        }
        for (Method method : vertex.getClass().getMethods()) {
            SaveVertexParam annotation = method.getAnnotation(SaveVertexParam.class);
            if (annotation == null || !Vertex.class.isAssignableFrom(method.getReturnType())) {
                continue;
            }
            String parentName = annotation.value();
            Vertex parentVertex = invokeParentGetter(method, vertex);
            // Every edge built above has `vertex` as its child, so matching on the parent end is
            // enough. This does not depend on how GraphEdge.equals treats edge direction.
            GraphEdge foundEdge = edges.stream()
                .filter(edge -> edge.getParentVertex().equals(parentVertex))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("Did not find parent edge " + parentName));
            foundEdge.appendToLabel(parentName);
        }
        return edges;
    }

    /** Reflectively calls the no-arg {@code getter} on {@code vertex}, wrapping invocation failures. */
    private static Vertex invokeParentGetter(Method getter, Vertex vertex) {
        try {
            return (Vertex) getter.invoke(vertex);
        } catch (ReflectiveOperationException | IllegalArgumentException e) {
            throw new IllegalArgumentException("Invalid parent retrieval function specified", e);
        }
    }
}

