/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.util.io;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.mir.KeanuSavedBayesNet;
import io.improbable.mir.SavedBayesNet;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.commons.math3.util.Pair;

/**
 * {@link NetworkSaver} that serialises a {@link BayesianNetwork} into a protobuf
 * {@link KeanuSavedBayesNet.ProtoModel}.
 *
 * <p>{@link #save(OutputStream, boolean, Map)} creates a fresh graph builder and passes {@code this} to
 * {@code net.save(...)} (and to {@code net.saveValues(...)} when values are requested). The
 * {@code save(Vertex)} / {@code saveValue(...)} callbacks then append vertices and stored values to that
 * builder. The in-progress builder lives in an unsynchronised field, so one instance should not be used
 * for concurrent saves.
 */
public class ProtobufSaver
implements NetworkSaver {
    private final BayesianNetwork net;
    /** Graph under construction; {@code null} before a graph is created and after {@link #clearGraph()}. */
    private SavedBayesNet.Graph.Builder graphBuilder = null;

    /**
     * @param net the network this saver serialises
     */
    public ProtobufSaver(BayesianNetwork net) {
        this.net = net;
    }

    /**
     * Builds the protobuf model for the network and writes it to {@code output}.
     *
     * @param output     stream the serialised model is written to (it is not closed here)
     * @param saveValues whether vertex values are saved in addition to the graph structure
     * @param metadata   optional key/value metadata; ignored when {@code null}
     * @throws IOException if writing to {@code output} fails
     */
    @Override
    public void save(OutputStream output, boolean saveValues, Map<String, String> metadata) throws IOException {
        try {
            KeanuSavedBayesNet.ProtoModel protobufModel = this.getModel(saveValues, metadata);
            protobufModel.writeTo(output);
        } finally {
            // Release the builder even if building or writing fails, so no partial graph is retained.
            this.clearGraph();
        }
    }

    /**
     * Builds the full model: the graph plus, when {@code metadata} is non-null, the metadata section.
     * The graph builder is left in place afterwards; it is only released by {@link #clearGraph()}.
     */
    protected KeanuSavedBayesNet.ProtoModel getModel(boolean withSavedValues, Map<String, String> metadata) {
        SavedBayesNet.Graph graph = this.getGraph(withSavedValues);
        KeanuSavedBayesNet.ProtoModel.Builder builder = KeanuSavedBayesNet.ProtoModel.newBuilder().setGraph(graph);
        if (metadata != null) {
            builder.setMetadata(this.buildMetadata(metadata));
        }
        return builder.build();
    }

    /** Rebuilds the graph from the network and returns the built protobuf message. */
    protected SavedBayesNet.Graph getGraph(boolean withSavedValues) {
        this.createGraph(withSavedValues);
        return this.graphBuilder.build();
    }

    /** Discards the graph builder created by the last {@link #getGraph(boolean)} call. */
    protected void clearGraph() {
        this.graphBuilder = null;
    }

    /**
     * Starts a new graph builder and lets the network populate it through this saver's callbacks.
     * Values are only requested from the network when {@code saveValues} is true.
     */
    private void createGraph(boolean saveValues) {
        this.graphBuilder = SavedBayesNet.Graph.newBuilder();
        this.net.save(this);
        if (saveValues) {
            this.net.saveValues(this);
        }
    }

    /**
     * Converts the metadata map into its protobuf form. Keys are inserted in sorted order,
     * presumably so the result does not depend on the input map's iteration order.
     */
    private KeanuSavedBayesNet.ModelMetadata buildMetadata(Map<String, String> metadata) {
        KeanuSavedBayesNet.ModelMetadata.Builder metadataBuilder = KeanuSavedBayesNet.ModelMetadata.newBuilder();
        String[] metadataKeys = metadata.keySet().toArray(new String[0]);
        Arrays.sort(metadataKeys);
        for (String metadataKey : metadataKeys) {
            metadataBuilder.putMetadataInfo(metadataKey, metadata.get(metadataKey));
        }
        return metadataBuilder.build();
    }

    /**
     * Appends the structural description of {@code vertex} to the graph being built.
     *
     * @throws IllegalArgumentException if the vertex is a {@link NonSaveableVertex}, or if one of its
     *                                  {@link SaveVertexParam} parameters cannot be encoded
     */
    @Override
    public void save(Vertex vertex) {
        if (vertex instanceof NonSaveableVertex) {
            throw new IllegalArgumentException("Trying to save a vertex that isn't Saveable");
        }
        this.graphBuilder.addVertices(this.buildVertex(vertex));
    }

    /** Encodes a vertex's label (if any), id, concrete class name, shape and annotated parameters. */
    private SavedBayesNet.Vertex buildVertex(Vertex vertex) {
        SavedBayesNet.Vertex.Builder vertexBuilder = SavedBayesNet.Vertex.newBuilder();
        if (vertex.getLabel() != null) {
            vertexBuilder.setLabel(vertex.getLabel().toString());
        }
        vertexBuilder
            .setId(toVertexId(vertex))
            .setVertexType(vertex.getClass().getCanonicalName())
            .addAllShape(Longs.asList(vertex.getShape()));
        this.saveParams(vertexBuilder, vertex);
        return vertexBuilder.build();
    }

    /**
     * Adds one {@code NamedParam} per {@link SaveVertexParam}-annotated getter of {@code vertex}.
     * Parameters are visited in sorted name order so their order does not depend on reflection order;
     * nullable parameters that return {@code null} are omitted.
     */
    private void saveParams(SavedBayesNet.Vertex.Builder vertexBuilder, Vertex vertex) {
        Map<String, Pair<Method, Boolean>> parentRetrievalMethodMap = this.getParentRetrievalMethodMap(vertex);
        String[] parentNames = parentRetrievalMethodMap.keySet().toArray(new String[0]);
        Arrays.sort(parentNames);
        for (String parentName : parentNames) {
            Pair<Method, Boolean> retrieval = parentRetrievalMethodMap.get(parentName);
            SavedBayesNet.NamedParam encodedParam =
                this.getEncodedParam(vertex, parentName, retrieval.getFirst(), retrieval.getSecond());
            if (encodedParam != null) {
                vertexBuilder.addParameters(encodedParam);
            }
        }
    }

    /**
     * Maps each {@link SaveVertexParam#value()} found on the vertex's public methods to the annotated
     * method and its {@code isNullable()} flag. If two methods share a name, the later one returned by
     * {@link Class#getMethods()} wins.
     */
    private Map<String, Pair<Method, Boolean>> getParentRetrievalMethodMap(Vertex vertex) {
        Map<String, Pair<Method, Boolean>> parentRetrievalMethodMap = new HashMap<>();
        for (Method method : vertex.getClass().getMethods()) {
            SaveVertexParam vertexAnnotation = method.getAnnotation(SaveVertexParam.class);
            if (vertexAnnotation != null) {
                parentRetrievalMethodMap.put(vertexAnnotation.value(), new Pair<>(method, vertexAnnotation.isNullable()));
            }
        }
        return parentRetrievalMethodMap;
    }

    /**
     * Invokes a parameter getter reflectively and encodes its result.
     *
     * @return the encoded parameter, or {@code null} if the getter returned {@code null} and
     *         {@code isNullable} is true
     * @throws IllegalArgumentException if the getter cannot be invoked or throws, if it returns
     *                                  {@code null} for a non-nullable parameter, or if its result
     *                                  has an unsupported type
     */
    private SavedBayesNet.NamedParam getEncodedParam(Vertex vertex, String paramName, Method getParamMethod, boolean isNullable) {
        Object param;
        try {
            param = getParamMethod.invoke(vertex);
        } catch (ReflectiveOperationException | IllegalArgumentException e) {
            // Covers inaccessible getters, getters that take arguments, and exceptions thrown by the getter.
            throw new IllegalArgumentException("Invalid parent retrieval function specified", e);
        }
        if (param == null) {
            if (!isNullable) {
                throw new IllegalArgumentException("No value returned from Save Function");
            }
            return null;
        }
        return this.getTypedParam(paramName, param);
    }

    /**
     * Encodes a non-null parameter value according to its runtime type.
     *
     * <p>Supported types are: {@link Vertex} (saved as a parent reference); {@link DoubleTensor};
     * {@link IntegerTensor}; {@link BooleanTensor}; boxed {@code Double}, {@code Integer}, {@code Long}
     * and {@code Boolean}; {@code String}; primitive {@code long[]} and {@code int[]}; and {@code Vertex[]}.
     *
     * @throws IllegalArgumentException for any other type
     */
    private SavedBayesNet.NamedParam getTypedParam(String paramName, Object param) {
        if (param instanceof Vertex) {
            return this.getParam(paramName, (Vertex) param);
        }
        if (param instanceof DoubleTensor) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setDoubleTensorParam(this.getTensor((DoubleTensor) param)));
        }
        if (param instanceof IntegerTensor) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setIntTensorParam(this.getTensor((IntegerTensor) param)));
        }
        if (param instanceof BooleanTensor) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setBoolTensorParam(this.getTensor((BooleanTensor) param)));
        }
        if (param instanceof Double) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setDoubleParam((Double) param));
        }
        if (param instanceof Integer) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setIntParam((Integer) param));
        }
        if (param instanceof Long) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setLongParam((Long) param));
        }
        if (param instanceof String) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setStringParam((String) param));
        }
        if (param instanceof Boolean) {
            return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setBoolParam((Boolean) param));
        }
        // Primitive arrays: the previous Long[]/Integer[] checks never matched long[]/int[] values
        // (and would have failed the cast for boxed arrays), so these branches were unreachable.
        if (param instanceof long[]) {
            return this.getParam(paramName, (long[]) param);
        }
        if (param instanceof Vertex[]) {
            return this.getParam(paramName, (Vertex[]) param);
        }
        if (param instanceof int[]) {
            return this.getParam(paramName, (int[]) param);
        }
        throw new IllegalArgumentException("Unknown Parameter Type to Save: " + param.getClass().toString());
    }

    /** Creates a {@code NamedParam} called {@code paramName} and lets {@code valueSetter} fill in its value. */
    private SavedBayesNet.NamedParam getParam(String paramName, Consumer<SavedBayesNet.NamedParam.Builder> valueSetter) {
        SavedBayesNet.NamedParam.Builder paramBuilder = SavedBayesNet.NamedParam.newBuilder();
        paramBuilder.setName(paramName);
        valueSetter.accept(paramBuilder);
        return paramBuilder.build();
    }

    /** Encodes a parent vertex as a reference to its id. */
    private SavedBayesNet.NamedParam getParam(String paramName, Vertex parent) {
        return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setParentVertex(toVertexId(parent)));
    }

    /** Encodes a {@code long[]} parameter as a {@code LongArray}. */
    private SavedBayesNet.NamedParam getParam(String paramName, long[] param) {
        return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setLongArrayParam(SavedBayesNet.LongArray.newBuilder().addAllValues(Longs.asList(param))));
    }

    /** Encodes an {@code int[]} parameter as an {@code IntArray}. */
    private SavedBayesNet.NamedParam getParam(String paramName, int[] param) {
        return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setIntArrayParam(SavedBayesNet.IntArray.newBuilder().addAllValues(Ints.asList(param))));
    }

    /** Encodes an array of vertices as an ordered list of vertex-id references. */
    private SavedBayesNet.NamedParam getParam(String paramName, Vertex[] param) {
        SavedBayesNet.VertexArray.Builder vertexArray = SavedBayesNet.VertexArray.newBuilder();
        for (Vertex vertex : param) {
            vertexArray.addValues(toVertexId(vertex));
        }
        return this.getParam(paramName, (SavedBayesNet.NamedParam.Builder builder) -> builder.setVertexArrayParam(vertexArray.build()));
    }

    /** Builds the protobuf identifier for {@code vertex} from the string form of its id. */
    private static SavedBayesNet.VertexID toVertexId(Vertex vertex) {
        return SavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()).build();
    }

    /** Copies a double tensor's flattened values and shape into its protobuf form. */
    private SavedBayesNet.DoubleTensor getTensor(DoubleTensor tensor) {
        return SavedBayesNet.DoubleTensor.newBuilder().addAllValues(tensor.asFlatList()).addAllShape(Longs.asList(tensor.getShape())).build();
    }

    /** Copies an integer tensor's flattened values and shape into its protobuf form. */
    private SavedBayesNet.IntegerTensor getTensor(IntegerTensor tensor) {
        return SavedBayesNet.IntegerTensor.newBuilder().addAllValues(tensor.asFlatList()).addAllShape(Longs.asList(tensor.getShape())).build();
    }

    /** Copies a boolean tensor's flattened values and shape into its protobuf form. */
    private SavedBayesNet.BooleanTensor getTensor(BooleanTensor tensor) {
        return SavedBayesNet.BooleanTensor.newBuilder().addAllValues(tensor.asFlatList()).addAllShape(Longs.asList(tensor.getShape())).build();
    }

    /**
     * Records the value of a vertex that has no typed overload. The value is stored as a generic tensor
     * containing its {@code toString()} form. Vertices without a value are skipped.
     */
    @Override
    public void saveValue(Vertex vertex) {
        if (vertex.hasValue()) {
            SavedBayesNet.StoredValue value = this.getValue(vertex, vertex.getValue().toString());
            this.graphBuilder.addDefaultState(value);
        }
    }

    /** Records a double vertex's value as a double tensor; vertices without a value are skipped. */
    @Override
    public void saveValue(DoubleVertex vertex) {
        if (vertex.hasValue()) {
            SavedBayesNet.StoredValue value = this.getValue(vertex);
            this.graphBuilder.addDefaultState(value);
        }
    }

    /** Records an integer vertex's value as an integer tensor; vertices without a value are skipped. */
    @Override
    public void saveValue(IntegerVertex vertex) {
        if (vertex.hasValue()) {
            SavedBayesNet.StoredValue value = this.getValue(vertex);
            this.graphBuilder.addDefaultState(value);
        }
    }

    /** Records a boolean vertex's value as a boolean tensor; vertices without a value are skipped. */
    @Override
    public void saveValue(BooleanVertex vertex) {
        if (vertex.hasValue()) {
            SavedBayesNet.StoredValue value = this.getValue(vertex);
            this.graphBuilder.addDefaultState(value);
        }
    }

    /** Wraps a pre-formatted value in a generic tensor that carries the vertex's shape. */
    private SavedBayesNet.StoredValue getValue(Vertex vertex, String formattedValue) {
        SavedBayesNet.GenericTensor savedValue = SavedBayesNet.GenericTensor.newBuilder().addAllShape(Longs.asList(vertex.getShape())).addValues(formattedValue).build();
        SavedBayesNet.VertexValue value = SavedBayesNet.VertexValue.newBuilder().setGenericVal(savedValue).build();
        return this.getStoredValue(vertex, value);
    }

    private SavedBayesNet.StoredValue getValue(DoubleVertex vertex) {
        SavedBayesNet.DoubleTensor savedValue = this.getTensor((DoubleTensor) vertex.getValue());
        SavedBayesNet.VertexValue value = SavedBayesNet.VertexValue.newBuilder().setDoubleVal(savedValue).build();
        return this.getStoredValue(vertex, value);
    }

    private SavedBayesNet.StoredValue getValue(IntegerVertex vertex) {
        SavedBayesNet.IntegerTensor savedValue = this.getTensor((IntegerTensor) vertex.getValue());
        SavedBayesNet.VertexValue value = SavedBayesNet.VertexValue.newBuilder().setIntVal(savedValue).build();
        return this.getStoredValue(vertex, value);
    }

    private SavedBayesNet.StoredValue getValue(BooleanVertex vertex) {
        SavedBayesNet.BooleanTensor savedValue = this.getTensor((BooleanTensor) vertex.getValue());
        SavedBayesNet.VertexValue value = SavedBayesNet.VertexValue.newBuilder().setBoolVal(savedValue).build();
        return this.getStoredValue(vertex, value);
    }

    /** Pairs an encoded value with its vertex id and the vertex's observed flag. */
    private SavedBayesNet.StoredValue getStoredValue(Vertex vertex, SavedBayesNet.VertexValue value) {
        return SavedBayesNet.StoredValue.newBuilder()
            .setId(toVertexId(vertex))
            .setValue(value)
            .setIsObserved(vertex.isObserved())
            .build();
    }
}

