package io.improbable.keanu.vertices.dbl.nonprobabilistic;

import com.google.common.collect.Iterables;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadShape;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.ProxyVertex;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

/**
 * A non-probabilistic {@link DoubleVertex} that stands in for another ("parent") double vertex.
 * <p>
 * The proxy is created with a fixed shape and a label, and may be connected to a parent later via
 * {@link #setParent(DoubleVertex)}. Once connected, its calculated value is the parent's value, and
 * partial derivatives are passed through unchanged in both forward- and reverse-mode
 * auto-differentiation.
 * <p>
 * The label's unqualified name cannot be changed once it has been set (see {@link #setLabel}).
 */
public class DoubleProxyVertex extends DoubleVertex implements Differentiable, ProxyVertex<DoubleVertex>, NonProbabilistic<DoubleTensor> {
    private static final String LABEL_PARAM_NAME = "label";
    private static final String PARENT_NAME = "parent";

    /**
     * Creates a scalar-shaped proxy with the given label and no parent.
     *
     * @param vertexLabel the label for this proxy
     */
    public DoubleProxyVertex(VertexLabel vertexLabel) {
        this(Tensor.SCALAR_SHAPE, vertexLabel);
    }

    /**
     * Creates a proxy of the given shape with the given label and no parent.
     *
     * @param jArr        the shape of this vertex's value
     * @param vertexLabel the label for this proxy
     */
    @ExportVertexToPythonBindings
    public DoubleProxyVertex(long[] jArr, VertexLabel vertexLabel) {
        super(jArr);
        setLabel(vertexLabel);
    }

    /**
     * Constructor used by the vertex save/load mechanism (see the {@code @Load*} annotations).
     * <p>
     * The parameter names are the same constants used by the {@code @SaveVertexParam} getters below,
     * so saving and loading stay in sync.
     *
     * @param jArr         the shape of this vertex's value
     * @param str          the label in string form, parsed with {@link VertexLabel#parseLabel}
     * @param doubleVertex the parent to connect to, or {@code null} to leave the proxy unconnected
     */
    public DoubleProxyVertex(@LoadShape long[] jArr,
                             @LoadVertexParam(LABEL_PARAM_NAME) String str,
                             @LoadVertexParam(value = PARENT_NAME, isNullable = true) DoubleVertex doubleVertex) {
        super(jArr);
        setLabel(VertexLabel.parseLabel(str));
        if (doubleVertex != null) {
            setParent(doubleVertex);
        }
    }

    /**
     * Sets the label, refusing any change to the unqualified name once a label is present.
     * <p>
     * If no label has been set yet, any label (including {@code null}) is accepted. Otherwise the
     * new label must have the same unqualified name as the current one. Only qualifying parts
     * (e.g. a namespace) may differ.
     *
     * @param vertexLabel the new label
     * @return the result of {@code super.setLabel}
     * @throws RuntimeException if a label is already set and the new label's unqualified name
     *                          differs from it (a {@code null} label counts as a different name)
     */
    @Override // io.improbable.keanu.vertices.Vertex
    @SuppressWarnings("unchecked") // mirrors the generic return type declared by Vertex#setLabel
    public <V extends Vertex<DoubleTensor>> V setLabel(VertexLabel vertexLabel) {
        VertexLabel currentLabel = getLabel();
        if (currentLabel != null) {
            // Compare null-safely so that clearing an existing label reports the intended error
            // instead of an NPE.
            String requestedName = vertexLabel == null ? null : vertexLabel.getUnqualifiedName();
            if (!Objects.equals(currentLabel.getUnqualifiedName(), requestedName)) {
                throw new RuntimeException("You should not change the label on a Proxy Vertex");
            }
        }
        return (V) super.setLabel(vertexLabel);
    }

    /**
     * Creates a proxy of the given shape labelled with {@code new VertexLabel(str)}.
     *
     * @param jArr the shape of this vertex's value
     * @param str  the label name
     */
    public DoubleProxyVertex(long[] jArr, String str) {
        this(jArr, new VertexLabel(str));
    }

    /**
     * Returns the parent's current value.
     *
     * @return the value of the parent vertex
     * @throws NullPointerException if no parent has been set
     */
    @Override // io.improbable.keanu.vertices.NonProbabilistic
    public DoubleTensor calculate() {
        DoubleVertex parent = Objects.requireNonNull(getParent(),
            "Cannot calculate the value of a Proxy Vertex that has no parent");
        return parent.getValue();
    }

    /**
     * Connects this proxy to the given parent.
     * <p>
     * The parent's shape is first validated against this proxy's shape with
     * {@link TensorShapeValidation#checkTensorsMatchNonLengthOneShapeOrAreLengthOne}. That call
     * presumably throws on incompatible shapes; confirm against its documentation.
     *
     * @param doubleVertex the vertex this proxy should mirror
     * @throws NullPointerException if {@code doubleVertex} is {@code null}
     */
    @Override // io.improbable.keanu.vertices.ProxyVertex
    public void setParent(DoubleVertex doubleVertex) {
        Objects.requireNonNull(doubleVertex, "The parent of a Proxy Vertex must not be null");
        // The validator takes the candidate shapes as a long[][] (one shape per entry).
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(getShape(), new long[][]{doubleVertex.getShape()});
        setParents(doubleVertex);
    }

    /**
     * Returns the parent this proxy mirrors.
     *
     * @return the single parent, or {@code null} if none has been set
     * @throws IllegalArgumentException if the vertex somehow has more than one parent
     *                                  (raised by {@link Iterables#getOnlyElement(Iterable, Object)})
     */
    @SaveVertexParam(value = PARENT_NAME, isNullable = true)
    public DoubleVertex getParent() {
        return (DoubleVertex) Iterables.getOnlyElement(getParents(), null);
    }

    /**
     * @return {@code true} if this proxy has been connected to a parent
     */
    @Override // io.improbable.keanu.vertices.ProxyVertex
    public boolean hasParent() {
        return !getParents().isEmpty();
    }

    /**
     * Forward-mode AD: this proxy's partial derivative is exactly its parent's.
     *
     * @param map partial derivatives already computed for other vertices
     * @return the entry for the parent, or {@code null} if {@code map} has none
     */
    @Override // io.improbable.keanu.vertices.dbl.Differentiable
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> map) {
        return map.get(getParent());
    }

    /**
     * Reverse-mode AD: passes the incoming partial derivative unchanged to the parent.
     *
     * @param partialDerivative the partial derivative with respect to this proxy
     * @return a single-entry map from the parent to {@code partialDerivative}
     */
    @Override // io.improbable.keanu.vertices.dbl.Differentiable
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative partialDerivative) {
        return Collections.singletonMap(getParent(), partialDerivative);
    }

    /**
     * Returns the label in string form, for the save/load mechanism.
     * <p>
     * NOTE(review): throws an NPE if the proxy was created with a {@code null} label.
     *
     * @return {@code getLabel().toString()}
     */
    @SaveVertexParam(LABEL_PARAM_NAME)
    public String getLabelParameter() {
        return getLabel().toString();
    }
}
