package io.improbable.keanu.util.io;

import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/improbable/keanu/util/io/DotSaver.class */
/**
 * Writes a set of vertices as a Graphviz DOT {@code digraph}.
 *
 * <p>Each call to {@link #save(OutputStream, boolean, Map)} resets and then repopulates the
 * collected labels and edges, so instances are stateful and must not be used for concurrent saves.
 */
public class DotSaver implements NetworkSaver {
    private static final String DOT_HEADER = "digraph BayesianNetwork {\n";
    private static final String DOT_ENDING = "}";
    private static final String DOT_COMMENT_APPENDIX = "// ";
    // Fully qualified because java.io.Writer is also in scope.
    private static final String LINE_END = io.improbable.keanu.util.csv.Writer.DEFAULT_LINE_END;

    /** Node labels collected by the save/saveValue callbacks during the current save. */
    private Set<VertexDotLabel> dotLabels;
    /** Parent-to-child edges collected by the save/saveValue callbacks during the current save. */
    private Set<GraphEdge> graphEdges;
    /** Vertices to render; an edge is only written when both of its endpoints are in this set. */
    private final Set<Vertex> vertices;

    /**
     * Creates a saver covering every vertex of {@code bayesianNetwork}.
     *
     * @param bayesianNetwork network whose vertices are copied into a new set
     */
    public DotSaver(BayesianNetwork bayesianNetwork) {
        this(new HashSet<Vertex>(bayesianNetwork.getAllVertices()));
    }

    /**
     * Creates a saver covering exactly {@code set}.
     *
     * @param set vertices to render; stored by reference, not copied
     */
    public DotSaver(Set<Vertex> set) {
        this.dotLabels = new HashSet<>();
        this.graphEdges = new HashSet<>();
        this.vertices = set;
    }

    /** Equivalent to {@code save(output, saveValues, null)}: writes the graph without metadata. */
    @Override
    public void save(OutputStream output, boolean saveValues) throws IOException {
        save(output, saveValues, null);
    }

    /**
     * Writes the graph in DOT format to {@code output} and closes it.
     *
     * @param output destination stream; closed by this method once writing starts, even on failure
     * @param saveValues if true, each vertex is visited through {@code saveValue} (values shown on
     *     labels where available), otherwise through {@code save}
     * @param metadata optional entries written as DOT comments; may be {@code null} or empty
     * @throws IOException if writing to {@code output} fails
     */
    @Override
    public void save(OutputStream output, boolean saveValues, Map<String, String> metadata) throws IOException {
        dotLabels = new HashSet<>();
        graphEdges = new HashSet<>();
        // Presumably each vertex dispatches back into one of the save/saveValue callbacks below,
        // which fill dotLabels and graphEdges — TODO confirm against Vertex.save/saveValue.
        for (Vertex vertex : vertices) {
            if (saveValues) {
                vertex.saveValue(this);
            } else {
                vertex.save(this);
            }
        }
        // try-with-resources guarantees the writer (and thus the stream) is closed if a write fails;
        // UTF-8 is explicit so the output does not depend on the platform default charset.
        try (Writer writer = new OutputStreamWriter(output, StandardCharsets.UTF_8)) {
            writer.write(DOT_HEADER);
            outputMetadata(metadata, writer);
            outputEdges(graphEdges, writer, vertices);
            outputLabels(dotLabels, writer);
            writer.write(DOT_ENDING);
        }
    }

    /** Writes each metadata entry as a {@code // key=value} comment line; writes nothing if absent. */
    private static void outputMetadata(Map<String, String> metadata, Writer writer) throws IOException {
        if (metadata == null || metadata.isEmpty()) {
            return;
        }
        writer.write("// Model metadata:\n");
        for (Map.Entry<String, String> entry : metadata.entrySet()) {
            String text = entry.getKey() + "=" + entry.getValue();
            writer.write(DOT_COMMENT_APPENDIX + asSingleLine(text) + LINE_END);
        }
    }

    /** Replaces line breaks so a metadata entry cannot end its comment and inject DOT statements. */
    private static String asSingleLine(String text) {
        return text.replace('\r', ' ').replace('\n', ' ');
    }

    /** Writes one DOT node statement per collected label. */
    private static void outputLabels(Collection<VertexDotLabel> labels, Writer writer) throws IOException {
        for (VertexDotLabel label : labels) {
            writer.write(label.inDotFormat() + LINE_END);
        }
    }

    /** Writes every collected edge whose parent and child are both in {@code included}. */
    private static void outputEdges(Collection<GraphEdge> edges, Writer writer, Set<Vertex> included) throws IOException {
        for (GraphEdge edge : edges) {
            if (included.contains(edge.getParentVertex()) && included.contains(edge.getChildVertex())) {
                writer.write(EdgeDotLabel.inDotFormat(edge) + LINE_END);
            }
        }
    }

    /** Records {@code vertex} as a plain node plus its incoming parent edges. */
    @Override
    public void save(Vertex vertex) {
        dotLabels.add(new VertexDotLabel(vertex));
        graphEdges.addAll(getParentEdges(vertex));
    }

    /** Constant vertices are always recorded with their value. */
    @Override
    public void save(ConstantVertex constantVertex) {
        saveValue((Vertex) constantVertex);
    }

    /**
     * Records {@code vertex} with its value (when it holds a {@link Tensor}) plus its incoming
     * parent edges, so the graph structure is the same whether or not values are saved.
     */
    @Override
    public void saveValue(Vertex vertex) {
        if (vertex.hasValue() && vertex.getValue() instanceof Tensor) {
            setDotLabelWithValue(vertex);
        } else {
            dotLabels.add(new VertexDotLabel(vertex));
        }
        graphEdges.addAll(getParentEdges(vertex));
    }

    @Override
    public void saveValue(DoubleVertex doubleVertex) {
        setDotLabelWithValue(doubleVertex);
        graphEdges.addAll(getParentEdges(doubleVertex));
    }

    @Override
    public void saveValue(IntegerVertex integerVertex) {
        setDotLabelWithValue(integerVertex);
        graphEdges.addAll(getParentEdges(integerVertex));
    }

    @Override
    public void saveValue(BooleanVertex booleanVertex) {
        setDotLabelWithValue(booleanVertex);
        graphEdges.addAll(getParentEdges(booleanVertex));
    }

    /** Adds a label for {@code vertex}, showing its value only when it has a scalar value. */
    private void setDotLabelWithValue(Vertex<? extends Tensor> vertex) {
        VertexDotLabel label = new VertexDotLabel(vertex);
        if (vertex.hasValue() && vertex.getValue().isScalar()) {
            label.setValue(String.valueOf(vertex.getValue().scalar()));
        }
        dotLabels.add(label);
    }

    /**
     * Builds one edge per parent of {@code vertex}, then appends the {@link SaveVertexParam} name to
     * each edge reached through an annotated, vertex-returning public method.
     *
     * @throws IllegalArgumentException if an annotated method cannot be invoked, or its result does
     *     not correspond to one of the parent edges (the cause carries the detail)
     */
    private Set<GraphEdge> getParentEdges(Vertex vertex) {
        Set<GraphEdge> edges = new HashSet<>();
        for (Object parent : vertex.getParents()) {
            edges.add(new GraphEdge((Vertex) parent, vertex));
        }
        for (Method method : vertex.getClass().getMethods()) {
            SaveVertexParam annotation = method.getAnnotation(SaveVertexParam.class);
            if (annotation == null || !Vertex.class.isAssignableFrom(method.getReturnType())) {
                continue;
            }
            String paramName = annotation.value();
            try {
                Vertex parent = (Vertex) method.invoke(vertex);
                // Same (parent, child) orientation as the edges built above, so the lookup matches.
                GraphEdge candidate = new GraphEdge(parent, vertex);
                GraphEdge existing = edges.stream()
                    .filter(candidate::equals)
                    .findFirst()
                    .orElseThrow(() -> new IllegalStateException("Did not find parent edge " + paramName));
                existing.appendToLabel(paramName);
            } catch (Exception e) {
                // Reflection failures and missing edges alike surface as IllegalArgumentException.
                throw new IllegalArgumentException("Invalid parent retrieval function specified", e);
            }
        }
        return edges;
    }
}
