/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * Double vertex whose value is the matrix inverse of its single input vertex.
 *
 * <p>The input's shape is validated when the vertex is constructed. It must be rank 2 with
 * equal dimensions, and that same shape is handed to the superclass as this vertex's shape.
 */
public class MatrixInverseVertex extends DoubleUnaryOpVertex implements Differentiable {

    /**
     * Creates a vertex that inverts {@code inputVertex}.
     *
     * @param inputVertex the vertex holding the matrix to invert
     * @throws IllegalArgumentException if the input's shape is not rank 2, or is rank 2 but not square
     */
    @ExportVertexToPythonBindings
    public MatrixInverseVertex(@LoadVertexParam(value="inputVertex") DoubleVertex inputVertex) {
        // Validation happens inside the super(...) argument list because the explicit
        // constructor invocation must be the first statement of the constructor.
        super(checkInputIsSquareMatrix(inputVertex.getShape()), inputVertex);
    }

    /**
     * Computes the matrix inverse of the supplied tensor.
     *
     * @param value the input tensor (presumably the input vertex's current value)
     * @return the result of {@code value.matrixInverse()}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        final DoubleTensor inverse = (DoubleTensor) value.matrixInverse();
        return inverse;
    }

    /**
     * Propagates the input's partial derivative forward through the inversion.
     *
     * @param derivativeOfParentsWithRespectToInput partials of each parent vertex, keyed by vertex
     * @return the partial of this vertex, built from the input's partial and this vertex's value
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        final PartialDerivative parentPartial = derivativeOfParentsWithRespectToInput.get(inputVertex);
        final DoubleTensor inverse = (DoubleTensor) getValue();
        final DoubleTensor negatedInverse = (DoubleTensor) inverse.unaryMinus();
        // Two chained products with the inverse, mirroring d(A^-1) = -A^-1 dA A^-1.
        // NOTE(review): the trailing boolean appears to select the multiplication side;
        // confirm against PartialDerivative.matrixMultiplyAlongOfDimensions.
        final PartialDerivative afterFirstProduct =
            PartialDerivative.matrixMultiplyAlongOfDimensions(parentPartial, negatedInverse, false);
        return PartialDerivative.matrixMultiplyAlongOfDimensions(afterFirstProduct, inverse, true);
    }

    /**
     * Pulls the partial of some output with respect to this vertex back onto the input vertex.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return a mutable map with a single entry keyed by the input vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        // This is the vertex's own value (the inverse), not the parent's value.
        final DoubleTensor inverse = (DoubleTensor) getValue();
        final DoubleTensor negatedInverse = (DoubleTensor) inverse.unaryMinus();

        PartialDerivative wrtInput =
            PartialDerivative.matrixMultiplyAlongWrtDimensions(derivativeOfOutputWithRespectToSelf, negatedInverse, false);
        wrtInput = PartialDerivative.matrixMultiplyAlongWrtDimensions(wrtInput, inverse, true);

        final Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(inputVertex, wrtInput);
        return partials;
    }

    /**
     * Checks that a shape describes a square matrix and returns it unchanged.
     *
     * @param shape the shape to check
     * @return the same {@code shape} array that was passed in
     * @throws IllegalArgumentException if the rank is not 2 or the two dimensions differ
     */
    private static long[] checkInputIsSquareMatrix(long[] shape) {
        final int rank = shape.length;
        if (rank != 2) {
            throw new IllegalArgumentException("Can only invert a Matrix (received rank: " + rank + ")");
        }
        final long rows = shape[0];
        final long columns = shape[1];
        if (rows == columns) {
            return shape;
        }
        throw new IllegalArgumentException("Can only invert a square Matrix (received: " + rows + ", " + columns + ")");
    }
}

