/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DoubleBinaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * Vertex whose value is the matrix product {@code left · right} of its two parent vertices.
 *
 * <p>The output shape is computed once, at construction time, from the operands' shapes via
 * {@link TensorShapeValidation#getMatrixMultiplicationResultingShape}. Both forward-mode and
 * reverse-mode auto-differentiation are supported.
 */
public class MatrixMultiplicationVertex
extends DoubleBinaryOpVertex
implements Differentiable {
    /**
     * Creates a vertex representing {@code left} matrix-multiplied by {@code right}.
     *
     * @param left  the left-hand operand
     * @param right the right-hand operand
     */
    @ExportVertexToPythonBindings
    public MatrixMultiplicationVertex(@LoadVertexParam(value="left") DoubleVertex left, @LoadVertexParam(value="right") DoubleVertex right) {
        super(TensorShapeValidation.getMatrixMultiplicationResultingShape(left.getShape(), right.getShape()), left, right);
    }

    /**
     * Computes the matrix product of the two operand values.
     *
     * @param l value of the left operand
     * @param r value of the right operand
     * @return {@code l.matrixMultiply(r)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor l, DoubleTensor r) {
        return l.matrixMultiply(r);
    }

    /**
     * Propagates the derivative of some output with respect to this vertex back to both operands.
     *
     * <p>The contribution to {@code left} is formed against the current value of {@code right}, and
     * the contribution to {@code right} against the current value of {@code left}.
     *
     * <p>If {@code left} and {@code right} are the same vertex (for example {@code A·A}), both
     * contributions are summed under that single key. A plain {@code put} would let the second
     * overwrite the first and drop half of the gradient.
     *
     * @param derivativeOfOutputWithRespectToSelf derivative of the output with respect to this vertex
     * @return map from each parent vertex to the derivative of the output with respect to it
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        PartialDerivative dOutputsWrtLeft = PartialDerivative.matrixMultiplyAlongWrtDimensions(derivativeOfOutputWithRespectToSelf, (DoubleTensor)this.right.getValue(), true);
        PartialDerivative dOutputsWrtRight = PartialDerivative.matrixMultiplyAlongWrtDimensions(derivativeOfOutputWithRespectToSelf, (DoubleTensor)this.left.getValue(), false);
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.left, dOutputsWrtLeft);
        // merge == put when the operands are distinct vertices. When they are the same vertex,
        // the two contributions are accumulated instead of the right one clobbering the left one.
        partials.merge(this.right, dOutputsWrtRight, (fromLeft, fromRight) -> fromLeft.add(fromRight));
        return partials;
    }

    /**
     * Pushes the derivatives of the operands (with respect to some input) forward through the product.
     *
     * <p>Uses the product rule {@code d(L·R) = dL·R + L·dR}. An operand with no recorded derivative
     * contributes {@link PartialDerivative#EMPTY}.
     *
     * @param derivativeOfParentsWithRespectToInput derivatives of the parent vertices with respect to the input
     * @return derivative of this vertex with respect to the input
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative dLeftWrtInput = derivativeOfParentsWithRespectToInput.getOrDefault(this.left, PartialDerivative.EMPTY);
        PartialDerivative dRightWrtInput = derivativeOfParentsWithRespectToInput.getOrDefault(this.right, PartialDerivative.EMPTY);
        // dL · R
        PartialDerivative partialsFromLeft = PartialDerivative.matrixMultiplyAlongOfDimensions(dLeftWrtInput, (DoubleTensor)this.right.getValue(), true);
        // L · dR
        PartialDerivative partialsFromRight = PartialDerivative.matrixMultiplyAlongOfDimensions(dRightWrtInput, (DoubleTensor)this.left.getValue(), false);
        return partialsFromLeft.add(partialsFromRight);
    }
}

