package io.improbable.keanu.util.io;

import com.google.common.primitives.Booleans;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import com.google.gson.internal.Primitives;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.network.NetworkLoader;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.ProxyVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.mir.KeanuSavedBayesNet;
import io.improbable.mir.SavedBayesNet;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Parameter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/util/io/ProtobufLoader.class */
/**
 * Rebuilds a {@link BayesianNetwork} from its protobuf form ({@link KeanuSavedBayesNet.ProtoModel}
 * or a bare {@link SavedBayesNet.Graph}).
 *
 * <p>Each saved vertex names its concrete class; that class must be a {@link Vertex} exposing a
 * public constructor whose first parameter is annotated with {@link LoadVertexParam} or
 * {@link LoadShape}. Constructor arguments are decoded from the saved named parameters. After all
 * vertices exist, stored default values are applied by calling {@code Vertex.loadValue(this)} on
 * each target, which is expected to call back into the matching typed {@code loadValue} overload.
 *
 * <p>Stored values are kept in an unsynchronized map, so an instance should not be shared across
 * threads.
 */
public class ProtobufLoader implements NetworkLoader {

    /** Stored values keyed by the vertex they target; read by the typed loadValue callbacks. */
    private final Map<Vertex, SavedBayesNet.StoredValue> savedValues = new HashMap<>();

    /**
     * Rejects vertices that do not resolve to one of the typed overloads (double, boolean, integer).
     *
     * @throws IllegalArgumentException always
     */
    @Override // io.improbable.keanu.network.NetworkLoader
    public void loadValue(Vertex vertex) {
        throw new IllegalArgumentException("Cannot Load value for Untyped Vertex");
    }

    /**
     * Parses a serialized {@link KeanuSavedBayesNet.ProtoModel} from the stream and loads its graph.
     *
     * @throws IOException if the stream cannot be read or parsed
     */
    @Override // io.improbable.keanu.network.NetworkLoader
    public BayesianNetwork loadNetwork(InputStream inputStream) throws IOException {
        return loadNetwork(KeanuSavedBayesNet.ProtoModel.parseFrom(inputStream));
    }

    /** Loads the graph contained in an already-parsed model. */
    public BayesianNetwork loadNetwork(KeanuSavedBayesNet.ProtoModel protoModel) {
        return loadNetwork(protoModel.getGraph());
    }

    /**
     * Instantiates every saved vertex, wraps them in a network, then applies stored default values.
     *
     * @throws IllegalArgumentException if a vertex type, parameter or stored value is invalid
     */
    public BayesianNetwork loadNetwork(SavedBayesNet.Graph graph) {
        Map<SavedBayesNet.VertexID, Vertex> vertices = new HashMap<>();
        // Vertices are built in list order and resolve parent IDs against the map filled so far,
        // so parents must precede their children in the saved graph.
        for (SavedBayesNet.Vertex savedVertex : graph.getVerticesList()) {
            vertices.put(savedVertex.getId(), createVertexFromProtoBuf(savedVertex, vertices));
        }
        BayesianNetwork network = new BayesianNetwork(vertices.values());
        loadDefaultValues(graph.getDefaultStateList(), vertices, network);
        return network;
    }

    /** Applies the stored value for a double vertex, observing it when the value was observed. */
    @Override // io.improbable.keanu.network.NetworkLoader
    public void loadValue(DoubleVertex doubleVertex) {
        SavedBayesNet.StoredValue storedValue = getStoredValue(doubleVertex);
        setOrObserveValue(doubleVertex, extractDoubleValue(storedValue.getValue()), storedValue.getIsObserved());
    }

    /** Applies the stored value for a boolean vertex, observing it when the value was observed. */
    @Override // io.improbable.keanu.network.NetworkLoader
    public void loadValue(BooleanVertex booleanVertex) {
        SavedBayesNet.StoredValue storedValue = getStoredValue(booleanVertex);
        setOrObserveValue(booleanVertex, extractBoolValue(storedValue.getValue()), storedValue.getIsObserved());
    }

    /** Applies the stored value for an integer vertex, observing it when the value was observed. */
    @Override // io.improbable.keanu.network.NetworkLoader
    public void loadValue(IntegerVertex integerVertex) {
        SavedBayesNet.StoredValue storedValue = getStoredValue(integerVertex);
        setOrObserveValue(integerVertex, extractIntValue(storedValue.getValue()), storedValue.getIsObserved());
    }

    /** Looks up the value registered for {@code vertex}, failing clearly instead of with an NPE. */
    private SavedBayesNet.StoredValue getStoredValue(Vertex vertex) {
        SavedBayesNet.StoredValue storedValue = savedValues.get(vertex);
        if (storedValue == null) {
            throw new IllegalArgumentException("No stored value available for Vertex: " + vertex);
        }
        return storedValue;
    }

    private DoubleTensor extractDoubleValue(SavedBayesNet.VertexValue vertexValue) {
        if (vertexValue.getValueTypeCase() != SavedBayesNet.VertexValue.ValueTypeCase.DOUBLE_VAL) {
            throw new IllegalArgumentException("Non Double Value specified for Double Vertex");
        }
        return extractDoubleTensor(vertexValue.getDoubleVal());
    }

    private BooleanTensor extractBoolValue(SavedBayesNet.VertexValue vertexValue) {
        if (vertexValue.getValueTypeCase() != SavedBayesNet.VertexValue.ValueTypeCase.BOOL_VAL) {
            throw new IllegalArgumentException("Non Boolean Value specified for Boolean Vertex");
        }
        return extractBoolTensor(vertexValue.getBoolVal());
    }

    private IntegerTensor extractIntValue(SavedBayesNet.VertexValue vertexValue) {
        if (vertexValue.getValueTypeCase() != SavedBayesNet.VertexValue.ValueTypeCase.INT_VAL) {
            throw new IllegalArgumentException("Non Int Value specified for Int Vertex");
        }
        return extractIntTensor(vertexValue.getIntVal());
    }

    /**
     * Registers each stored value against its target vertex, then lets the vertex pull it back
     * through the typed {@code loadValue} callback.
     */
    private void loadDefaultValues(List<SavedBayesNet.StoredValue> storedValues,
                                   Map<SavedBayesNet.VertexID, Vertex> vertices,
                                   BayesianNetwork network) {
        for (SavedBayesNet.StoredValue storedValue : storedValues) {
            Vertex target = getTargetVertex(storedValue, vertices, network);
            savedValues.put(target, storedValue);
            target.loadValue(this);
        }
    }

    /** Resolves the vertex a stored value refers to, by ID and/or label. */
    private Vertex getTargetVertex(SavedBayesNet.StoredValue storedValue,
                                   Map<SavedBayesNet.VertexID, Vertex> vertices,
                                   BayesianNetwork network) {
        return checkTargetsAreValid(
            getTargetByID(storedValue, vertices),
            getTargetByLabel(storedValue, network),
            storedValue);
    }

    /** Returns the vertex with the stored value's ID, or null if no ID is set or it is unknown. */
    private Vertex getTargetByID(SavedBayesNet.StoredValue storedValue, Map<SavedBayesNet.VertexID, Vertex> vertices) {
        return storedValue.hasId() ? vertices.get(storedValue.getId()) : null;
    }

    /** Returns the network's vertex with the stored value's label, or null if no label is set. */
    private Vertex getTargetByLabel(SavedBayesNet.StoredValue storedValue, BayesianNetwork network) {
        if (storedValue.getVertexLabel().isEmpty()) {
            return null;
        }
        return network.getVertexByLabel(new VertexLabel(storedValue.getVertexLabel()));
    }

    /**
     * Reconciles the ID- and label-based lookups: at least one must succeed, and if both do they
     * must name the same vertex.
     *
     * @throws IllegalArgumentException if neither lookup found a vertex or they disagree
     */
    private Vertex checkTargetsAreValid(Vertex byId, Vertex byLabel, SavedBayesNet.StoredValue storedValue) {
        if (byId == null && byLabel == null) {
            throw new IllegalArgumentException("Value specified for unknown Vertex: (" + storedValue.getVertexLabel() + ") (" + storedValue.getId() + ")");
        }
        if (byId != null && byLabel != null && byId != byLabel) {
            throw new IllegalArgumentException("Label and VertexID don't refer to same Vertex: (" + storedValue.getVertexLabel() + ") (" + storedValue.getId().toString() + ")");
        }
        return byId != null ? byId : byLabel;
    }

    /** Observes the vertex with {@code tensor} when {@code observe} is true, otherwise just sets it. */
    private void setOrObserveValue(Vertex vertex, Tensor tensor, boolean observe) {
        if (observe) {
            vertex.observe(tensor);
        } else {
            vertex.setValue(tensor);
        }
    }

    /** Builds one vertex from its saved form and applies its label (ProxyVertex labels are skipped). */
    private <T> Vertex<T> createVertexFromProtoBuf(SavedBayesNet.Vertex savedVertex,
                                                   Map<SavedBayesNet.VertexID, Vertex> existingVertices) {
        Class<?> vertexClass = resolveVertexClass(savedVertex.getVertexType());
        Vertex<T> vertex = instantiateVertex(vertexClass, getParameterMap(savedVertex, existingVertices), savedVertex);
        if (!savedVertex.getLabel().isEmpty() && !(vertex instanceof ProxyVertex)) {
            vertex.setLabel(VertexLabel.parseLabel(savedVertex.getLabel()));
        }
        return vertex;
    }

    /**
     * Resolves a vertex class by name without running its static initializers, and rejects any
     * class that is not a {@link Vertex} before it is reflected on or instantiated.
     *
     * @throws IllegalArgumentException if the class is unknown or not a Vertex
     */
    private static Class<?> resolveVertexClass(String typeName) {
        Class<?> vertexClass;
        try {
            // Same loader Class.forName(String) would use from this class, but without initializing.
            vertexClass = Class.forName(typeName, false, ProtobufLoader.class.getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("Unknown Vertex Type Specified: " + typeName, e);
        }
        if (!Vertex.class.isAssignableFrom(vertexClass)) {
            throw new IllegalArgumentException("Specified type is not a Vertex: " + typeName);
        }
        return vertexClass;
    }

    /**
     * Invokes the class's annotated load constructor with arguments decoded from {@code params}.
     *
     * @throws IllegalArgumentException on a type mismatch or if construction fails
     */
    private Vertex instantiateVertex(Class<?> vertexClass, Map<String, Object> params, SavedBayesNet.Vertex savedVertex) {
        Constructor<?> constructor = getAnnotatedConstructor(vertexClass);
        Parameter[] parameters = constructor.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Object arg = getParameter(parameters[i], params, savedVertex);
            // Compare against the boxed type so primitive constructor params accept boxed values.
            Class<?> expected = Primitives.wrap(parameters[i].getType());
            // A null argument has no runtime type, so it is not type-checked here.
            Class<?> actual = arg != null ? arg.getClass() : expected;
            if (!expected.isAssignableFrom(actual)) {
                throw new IllegalArgumentException("Incorrect Parameter Type specified.  Got: " + actual + ", Expected: " + expected);
            }
            args[i] = arg;
        }
        try {
            return (Vertex) constructor.newInstance(args);
        } catch (ReflectiveOperationException | IllegalArgumentException e) {
            throw new IllegalArgumentException("Failed to create new Vertex", e);
        }
    }

    /**
     * Produces the argument for one constructor parameter: the saved shape for {@link LoadShape},
     * or the named decoded value for {@link LoadVertexParam}.
     *
     * @throws IllegalArgumentException if the parameter is unannotated or a required value is missing
     */
    private Object getParameter(Parameter parameter, Map<String, Object> params, SavedBayesNet.Vertex savedVertex) {
        LoadVertexParam loadVertexParam = parameter.getAnnotation(LoadVertexParam.class);
        if (loadVertexParam == null) {
            if (parameter.getAnnotation(LoadShape.class) != null) {
                return savedVertex.getShapeCount() == 0 ? Tensor.SCALAR_SHAPE : Longs.toArray(savedVertex.getShapeList());
            }
            throw new IllegalArgumentException("Cannot create Vertex due to unannotated parameter in constructor");
        }
        Object value = params.get(loadVertexParam.value());
        if (value != null || loadVertexParam.isNullable()) {
            return value;
        }
        throw new IllegalArgumentException("Failed to create vertex due to missing parameter: " + loadVertexParam.value());
    }

    /** Returns the first public constructor whose first parameter carries a load annotation. */
    private Constructor<?> getAnnotatedConstructor(Class<?> vertexClass) {
        for (Constructor<?> constructor : vertexClass.getConstructors()) {
            Parameter[] parameters = constructor.getParameters();
            if (parameters.length > 0
                && (parameters[0].isAnnotationPresent(LoadVertexParam.class)
                    || parameters[0].isAnnotationPresent(LoadShape.class))) {
                return constructor;
            }
        }
        throw new IllegalArgumentException("No Annotated Load Constructor for Vertex of type: " + vertexClass);
    }

    /** Decodes every named parameter of a saved vertex into a name-to-value map. */
    private Map<String, Object> getParameterMap(SavedBayesNet.Vertex savedVertex,
                                                Map<SavedBayesNet.VertexID, Vertex> existingVertices) {
        Map<String, Object> params = new HashMap<>();
        for (SavedBayesNet.NamedParam namedParam : savedVertex.getParametersList()) {
            params.put(namedParam.getName(), getDecodedParam(namedParam, existingVertices));
        }
        return params;
    }

    /**
     * Converts a protobuf named parameter into the Java value a vertex constructor expects.
     *
     * @throws IllegalArgumentException for an unset/unknown param case or an unknown parent vertex
     */
    private Object getDecodedParam(SavedBayesNet.NamedParam namedParam,
                                   Map<SavedBayesNet.VertexID, Vertex> existingVertices) {
        switch (namedParam.getParamCase()) {
            case PARENT_VERTEX:
                return getExistingVertex(namedParam.getParentVertex(), existingVertices);
            case DOUBLE_TENSOR_PARAM:
                return extractDoubleTensor(namedParam.getDoubleTensorParam());
            case INT_TENSOR_PARAM:
                return extractIntTensor(namedParam.getIntTensorParam());
            case BOOL_TENSOR_PARAM:
                return extractBoolTensor(namedParam.getBoolTensorParam());
            case DOUBLE_PARAM:
                return Double.valueOf(namedParam.getDoubleParam());
            case INT_PARAM:
                return Integer.valueOf(namedParam.getIntParam());
            case LONG_PARAM:
                return Long.valueOf(namedParam.getLongParam());
            case STRING_PARAM:
                return namedParam.getStringParam();
            case BOOL_PARAM:
                return Boolean.valueOf(namedParam.getBoolParam());
            case LONG_ARRAY_PARAM:
                return Longs.toArray(namedParam.getLongArrayParam().getValuesList());
            case INT_ARRAY_PARAM:
                return Ints.toArray(namedParam.getIntArrayParam().getValuesList());
            case VERTEX_ARRAY_PARAM:
                return extractVertexArray(namedParam, existingVertices);
            default:
                throw new IllegalArgumentException("Unknown Param Type Received: " + namedParam.getParamCase().toString());
        }
    }

    /** Resolves all parent IDs in a vertex-array parameter, in order. */
    private Vertex[] extractVertexArray(SavedBayesNet.NamedParam namedParam,
                                        Map<SavedBayesNet.VertexID, Vertex> existingVertices) {
        List<SavedBayesNet.VertexID> ids = namedParam.getVertexArrayParam().getValuesList();
        Vertex[] vertices = new Vertex[ids.size()];
        for (int i = 0; i < vertices.length; i++) {
            vertices[i] = getExistingVertex(ids.get(i), existingVertices);
        }
        return vertices;
    }

    /**
     * Looks up an already-created vertex by ID.
     *
     * @throws IllegalArgumentException if the ID has not been created yet (dangling or out-of-order)
     */
    private static Vertex getExistingVertex(SavedBayesNet.VertexID id, Map<SavedBayesNet.VertexID, Vertex> existingVertices) {
        Vertex vertex = existingVertices.get(id);
        if (vertex == null) {
            throw new IllegalArgumentException("Saved Structure references unknown Parent: " + id.toString());
        }
        return vertex;
    }

    private DoubleTensor extractDoubleTensor(SavedBayesNet.DoubleTensor doubleTensor) {
        return DoubleTensor.create(Doubles.toArray(doubleTensor.getValuesList()), Longs.toArray(doubleTensor.getShapeList()));
    }

    private IntegerTensor extractIntTensor(SavedBayesNet.IntegerTensor integerTensor) {
        return IntegerTensor.create(Ints.toArray(integerTensor.getValuesList()), Longs.toArray(integerTensor.getShapeList()));
    }

    private BooleanTensor extractBoolTensor(SavedBayesNet.BooleanTensor booleanTensor) {
        return BooleanTensor.create(Booleans.toArray(booleanTensor.getValuesList()), Longs.toArray(booleanTensor.getShapeList()));
    }
}
