package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.annotation.DisplayInformationForOutput;
import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.util.csv.Writer;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.AutoDiffBroadcast;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.HashMap;
import java.util.Map;

/**
 * Element-wise difference of two double vertices, {@code left - right}.
 *
 * <p>The operand shapes are validated with {@link TensorShapeValidation#checkIsBroadcastable}
 * when the vertex is constructed. Partial derivatives in both forward and reverse mode are
 * corrected for broadcasting between each operand's shape and this vertex's shape.
 */
@DisplayInformationForOutput(displayName = Writer.DEFAULT_EMPTY_VALUE)
public class DifferenceVertex extends DoubleBinaryOpVertex implements Differentiable {

    /**
     * Creates a vertex whose value is {@code left - right}.
     *
     * @param left  the minuend
     * @param right the subtrahend
     */
    @ExportVertexToPythonBindings
    public DifferenceVertex(@LoadVertexParam("left") DoubleVertex left, @LoadVertexParam("right") DoubleVertex right) {
        super(TensorShapeValidation.checkIsBroadcastable(left.getShape(), right.getShape()), left, right);
    }

    /**
     * Computes the value of this vertex from its operand values.
     *
     * @param leftValue  value of the left operand
     * @param rightValue value of the right operand
     * @return {@code leftValue.minus(rightValue)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor leftValue, DoubleTensor rightValue) {
        return (DoubleTensor) leftValue.minus(rightValue);
    }

    /**
     * Forward-mode derivative: {@code d(left - right) = d(left) - d(right)}.
     *
     * <p>Each operand's partial is first corrected for broadcasting to this vertex's shape. An
     * operand with no entry in the map contributes {@link PartialDerivative#EMPTY}.
     *
     * @param derivativeOfParentsWithRespectToInput partial derivatives of the parents, keyed by vertex
     * @return the partial derivative of this vertex
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative leftPartial = AutoDiffBroadcast.correctForBroadcastPartialForward(
            derivativeOfParentsWithRespectToInput.getOrDefault(this.left, PartialDerivative.EMPTY),
            this.left.getShape(),
            getShape());
        PartialDerivative rightPartial = AutoDiffBroadcast.correctForBroadcastPartialForward(
            derivativeOfParentsWithRespectToInput.getOrDefault(this.right, PartialDerivative.EMPTY),
            this.right.getShape(),
            getShape());
        return leftPartial.subtract(rightPartial);
    }

    /**
     * Reverse-mode derivative.
     *
     * <p>The incoming partial flows unchanged to the left operand and negated to the right
     * operand. Each result is corrected for broadcasting back to that operand's shape.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return the partial for each operand, keyed by operand vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        PartialDerivative leftPartial = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            derivativeOfOutputWithRespectToSelf, getShape(), this.left.getShape());
        PartialDerivative rightPartial = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            derivativeOfOutputWithRespectToSelf.multiplyBy(-1.0d), getShape(), this.right.getShape());

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.left, leftPartial);
        // Merge rather than put: if left and right are the same vertex (e.g. x - x), both
        // contributions must be summed instead of the right one overwriting the left.
        partials.merge(this.right, rightPartial, PartialDerivative::add);
        return partials;
    }
}
