package io.improbable.keanu.util.io;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.mir.KeanuSavedBayesNet;
import io.improbable.mir.SavedBayesNet;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Consumer;
import org.apache.commons.math3.util.Pair;

/* loaded from: input_file:io/improbable/keanu/util/io/ProtobufSaver.class */
public class ProtobufSaver implements NetworkSaver {
    private final BayesianNetwork net;
    private SavedBayesNet.Graph.Builder graphBuilder = null;

    public ProtobufSaver(BayesianNetwork bayesianNetwork) {
        this.net = bayesianNetwork;
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void save(OutputStream outputStream, boolean z, Map<String, String> map) throws IOException {
        getModel(z, map).writeTo(outputStream);
        clearGraph();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public KeanuSavedBayesNet.ProtoModel getModel(boolean z, Map<String, String> map) {
        KeanuSavedBayesNet.ProtoModel.Builder graph = KeanuSavedBayesNet.ProtoModel.newBuilder().setGraph(getGraph(z));
        if (map != null) {
            graph.setMetadata(buildMetadata(map));
        }
        return graph.m141build();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public SavedBayesNet.Graph getGraph(boolean z) {
        createGraph(z);
        return this.graphBuilder.m860build();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void clearGraph() {
        this.graphBuilder = null;
    }

    private void createGraph(boolean z) {
        this.graphBuilder = SavedBayesNet.Graph.newBuilder();
        this.net.save(this);
        if (z) {
            this.net.saveValues(this);
        }
    }

    private KeanuSavedBayesNet.ModelMetadata buildMetadata(Map<String, String> map) {
        KeanuSavedBayesNet.ModelMetadata.Builder newBuilder = KeanuSavedBayesNet.ModelMetadata.newBuilder();
        String[] strArr = (String[]) map.keySet().toArray(new String[0]);
        Arrays.sort(strArr);
        for (String str : strArr) {
            newBuilder.putMetadataInfo(str, map.get(str));
        }
        return newBuilder.m93build();
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void save(Vertex vertex) {
        if (vertex instanceof NonSaveableVertex) {
            throw new IllegalArgumentException("Trying to save a vertex that isn't Saveable");
        }
        this.graphBuilder.addVertices(buildVertex(vertex));
    }

    private SavedBayesNet.Vertex buildVertex(Vertex vertex) {
        SavedBayesNet.Vertex.Builder newBuilder = SavedBayesNet.Vertex.newBuilder();
        if (vertex.getLabel() != null) {
            newBuilder = newBuilder.setLabel(vertex.getLabel().toString());
        }
        SavedBayesNet.Vertex.Builder addAllShape = newBuilder.setId(SavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString())).setVertexType(vertex.getClass().getCanonicalName()).addAllShape(Longs.asList(vertex.getShape()));
        saveParams(addAllShape, vertex);
        return addAllShape.m1190build();
    }

    private void saveParams(SavedBayesNet.Vertex.Builder builder, Vertex vertex) {
        Map<String, Pair<Method, Boolean>> parentRetrievalMethodMap = getParentRetrievalMethodMap(vertex);
        String[] strArr = (String[]) parentRetrievalMethodMap.keySet().toArray(new String[0]);
        Arrays.sort(strArr);
        for (String str : strArr) {
            SavedBayesNet.NamedParam encodedParam = getEncodedParam(vertex, str, (Method) parentRetrievalMethodMap.get(str).getFirst(), ((Boolean) parentRetrievalMethodMap.get(str).getSecond()).booleanValue());
            if (encodedParam != null) {
                builder.addParameters(encodedParam);
            }
        }
    }

    private Map<String, Pair<Method, Boolean>> getParentRetrievalMethodMap(Vertex vertex) {
        Method[] methods = vertex.getClass().getMethods();
        HashMap hashMap = new HashMap();
        for (Method method : methods) {
            SaveVertexParam saveVertexParam = (SaveVertexParam) method.getAnnotation(SaveVertexParam.class);
            if (saveVertexParam != null) {
                hashMap.put(saveVertexParam.value(), new Pair(method, Boolean.valueOf(saveVertexParam.isNullable())));
            }
        }
        return hashMap;
    }

    private SavedBayesNet.NamedParam getEncodedParam(Vertex vertex, String str, Method method, boolean z) {
        try {
            Object invoke = method.invoke(vertex, new Object[0]);
            if (invoke != null) {
                return getTypedParam(str, invoke);
            }
            if (z) {
                return null;
            }
            throw new IllegalArgumentException("No value returned from Save Function");
        } catch (Exception e) {
            throw new IllegalArgumentException("Invalid parent retrieval function specified", e);
        }
    }

    private SavedBayesNet.NamedParam getTypedParam(String str, Object obj) {
        if (Vertex.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, (Vertex) obj);
        }
        if (DoubleTensor.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder -> {
                builder.setDoubleTensorParam(getTensor((DoubleTensor) obj));
            });
        }
        if (IntegerTensor.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder2 -> {
                builder2.setIntTensorParam(getTensor((IntegerTensor) obj));
            });
        }
        if (BooleanTensor.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder3 -> {
                builder3.setBoolTensorParam(getTensor((BooleanTensor) obj));
            });
        }
        if (Double.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder4 -> {
                builder4.setDoubleParam(((Double) obj).doubleValue());
            });
        }
        if (Integer.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder5 -> {
                builder5.setIntParam(((Integer) obj).intValue());
            });
        }
        if (Long.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder6 -> {
                builder6.setLongParam(((Long) obj).longValue());
            });
        }
        if (String.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder7 -> {
                builder7.setStringParam((String) obj);
            });
        }
        if (Boolean.class.isAssignableFrom(obj.getClass())) {
            return getParam(str, builder8 -> {
                builder8.setBoolParam(((Boolean) obj).booleanValue());
            });
        }
        if (Long[].class.isAssignableFrom(obj.getClass())) {
            return getParam(str, (long[]) obj);
        }
        if (Vertex[].class.isAssignableFrom(obj.getClass())) {
            return getParam(str, (Vertex[]) obj);
        }
        if (Integer[].class.isAssignableFrom(obj.getClass())) {
            return getParam(str, (int[]) obj);
        }
        throw new IllegalArgumentException("Unknown Parameter Type to Save: " + obj.getClass().toString());
    }

    private SavedBayesNet.NamedParam getParam(String str, Consumer<SavedBayesNet.NamedParam.Builder> consumer) {
        SavedBayesNet.NamedParam.Builder newBuilder = SavedBayesNet.NamedParam.newBuilder();
        newBuilder.setName(str);
        consumer.accept(newBuilder);
        return newBuilder.m1095build();
    }

    private SavedBayesNet.NamedParam getParam(String str, Vertex vertex) {
        return getParam(str, builder -> {
            builder.setParentVertex(SavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()));
        });
    }

    private SavedBayesNet.NamedParam getParam(String str, long[] jArr) {
        return getParam(str, builder -> {
            builder.setLongArrayParam(SavedBayesNet.LongArray.newBuilder().addAllValues(Longs.asList(jArr)));
        });
    }

    private SavedBayesNet.NamedParam getParam(String str, int[] iArr) {
        return getParam(str, builder -> {
            builder.setIntArrayParam(SavedBayesNet.IntArray.newBuilder().addAllValues(Ints.asList(iArr)));
        });
    }

    private SavedBayesNet.NamedParam getParam(String str, Vertex[] vertexArr) {
        SavedBayesNet.VertexArray.Builder newBuilder = SavedBayesNet.VertexArray.newBuilder();
        for (Vertex vertex : vertexArr) {
            newBuilder.addValues(SavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()));
        }
        return getParam(str, builder -> {
            builder.setVertexArrayParam(newBuilder.m1237build());
        });
    }

    private SavedBayesNet.DoubleTensor getTensor(DoubleTensor doubleTensor) {
        return SavedBayesNet.DoubleTensor.newBuilder().addAllValues(doubleTensor.asFlatList()).addAllShape(Longs.asList(doubleTensor.getShape())).m765build();
    }

    private SavedBayesNet.IntegerTensor getTensor(IntegerTensor integerTensor) {
        return SavedBayesNet.IntegerTensor.newBuilder().addAllValues(integerTensor.asFlatList()).addAllShape(Longs.asList(integerTensor.getShape())).m954build();
    }

    private SavedBayesNet.BooleanTensor getTensor(BooleanTensor booleanTensor) {
        return SavedBayesNet.BooleanTensor.newBuilder().addAllValues(booleanTensor.asFlatList()).addAllShape(Longs.asList(booleanTensor.getShape())).m716build();
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void saveValue(Vertex vertex) {
        if (vertex.hasValue()) {
            this.graphBuilder.addDefaultState(getValue(vertex, vertex.getValue().toString()));
        }
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void saveValue(DoubleVertex doubleVertex) {
        if (doubleVertex.hasValue()) {
            this.graphBuilder.addDefaultState(getValue(doubleVertex));
        }
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void saveValue(IntegerVertex integerVertex) {
        if (integerVertex.hasValue()) {
            this.graphBuilder.addDefaultState(getValue(integerVertex));
        }
    }

    @Override // io.improbable.keanu.network.NetworkSaver
    public void saveValue(BooleanVertex booleanVertex) {
        if (booleanVertex.hasValue()) {
            this.graphBuilder.addDefaultState(getValue(booleanVertex));
        }
    }

    private SavedBayesNet.StoredValue getValue(Vertex vertex, String str) {
        return getStoredValue(vertex, SavedBayesNet.VertexValue.newBuilder().setGenericVal(SavedBayesNet.GenericTensor.newBuilder().addAllShape(Longs.asList(vertex.getShape())).addValues(str).m813build()).m1331build());
    }

    private SavedBayesNet.StoredValue getValue(DoubleVertex doubleVertex) {
        return getStoredValue(doubleVertex, SavedBayesNet.VertexValue.newBuilder().setDoubleVal(getTensor(doubleVertex.getValue())).m1331build());
    }

    private SavedBayesNet.StoredValue getValue(IntegerVertex integerVertex) {
        return getStoredValue(integerVertex, SavedBayesNet.VertexValue.newBuilder().setIntVal(getTensor(integerVertex.getValue())).m1331build());
    }

    private SavedBayesNet.StoredValue getValue(BooleanVertex booleanVertex) {
        return getStoredValue(booleanVertex, SavedBayesNet.VertexValue.newBuilder().setBoolVal(getTensor(booleanVertex.getValue())).m1331build());
    }

    private SavedBayesNet.StoredValue getStoredValue(Vertex vertex, SavedBayesNet.VertexValue vertexValue) {
        return SavedBayesNet.StoredValue.newBuilder().setId(SavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()).m1284build()).setValue(vertexValue).setIsObserved(vertex.isObserved()).m1143build();
    }
}
