package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Collections;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/unary/MatrixDeterminantVertex.class */
/**
 * A vertex whose value is the determinant of its input vertex's matrix value.
 *
 * <p>The output shape is always {@link Tensor#SCALAR_SHAPE}. Only reverse-mode automatic
 * differentiation is supported; forward mode throws {@link UnsupportedOperationException}.
 */
public class MatrixDeterminantVertex extends DoubleUnaryOpVertex implements Differentiable {

    /**
     * Creates a vertex computing the determinant of {@code doubleVertex}.
     *
     * @param doubleVertex the input vertex; its shape is passed to
     *     {@link TensorShapeValidation#checkShapeIsSquareMatrix}, which presumably rejects
     *     non-square shapes (TODO confirm the exception type thrown)
     */
    @ExportVertexToPythonBindings
    public MatrixDeterminantVertex(@LoadVertexParam("inputVertex") DoubleVertex doubleVertex) {
        super(Tensor.SCALAR_SHAPE, doubleVertex);
        TensorShapeValidation.checkShapeIsSquareMatrix(doubleVertex.getShape());
    }

    /**
     * Computes the determinant of {@code doubleTensor} and wraps it in a scalar tensor.
     *
     * @param doubleTensor the input matrix value
     * @return a scalar tensor holding {@code det(doubleTensor)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor doubleTensor) {
        return DoubleTensor.scalar(doubleTensor.determinant().doubleValue());
    }

    /**
     * Forward-mode automatic differentiation is not implemented for this vertex.
     *
     * @param derivativeOfParentsWithRespectToInput unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        throw new UnsupportedOperationException(
            "Forward mode automatic differentiation is not supported by " + getClass().getSimpleName());
    }

    /**
     * Propagates a partial derivative back through the determinant using the identity
     * {@code d det(A) / dA = det(A) * (A^-1)^T}.
     *
     * <p>The incoming partial is scaled by {@code det(A)}, padded with trailing unit dimensions
     * and broadcast so that its trailing "wrt" dimensions match the input's shape, then multiplied
     * along those dimensions by {@code (A^T)^-1}, which equals {@code (A^-1)^T}.
     *
     * @param derivativeOfOutputWithRespectToSelf the accumulated partial derivative with respect
     *     to this vertex's scalar value
     * @return a single-entry map from the input vertex to its partial derivative
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        // Read the input matrix once; it is used both for det(A) and for (A^T)^-1.
        DoubleTensor inputMatrix = this.inputVertex.getValue();

        PartialDerivative scaledByDeterminant =
            derivativeOfOutputWithRespectToSelf.multiplyBy(inputMatrix.determinant().doubleValue());

        // Target shape: [shape of the incoming partial..., shape of the input matrix...].
        long[] resultShape = TensorShape.concat(
            derivativeOfOutputWithRespectToSelf.get().getShape(), this.inputVertex.getShape());

        DoubleTensor broadcastPartial = (DoubleTensor) increaseRankByAppendingOnesToShape(
            scaledByDeterminant.get(), resultShape.length).broadcast(resultShape);

        // (A^T)^-1 == (A^-1)^T, the remaining factor of the determinant's gradient.
        DoubleTensor inverseOfTranspose = ((DoubleTensor) inputMatrix.transpose()).matrixInverse();

        return Collections.singletonMap(
            this.inputVertex,
            new PartialDerivative(broadcastPartial).multiplyAlongWrtDimensions(inverseOfTranspose));
    }

    /**
     * Reshapes {@code tensor} to the given rank using
     * {@link TensorShape#shapeDesiredToRankByAppendingOnes}, which presumably pads the shape with
     * trailing 1s (confirm), so that it can subsequently be broadcast.
     *
     * @param tensor the tensor to reshape
     * @param desiredRank the rank of the returned tensor
     * @return the reshaped tensor
     */
    private static DoubleTensor increaseRankByAppendingOnesToShape(DoubleTensor tensor, int desiredRank) {
        return (DoubleTensor) tensor.reshape(TensorShape.shapeDesiredToRankByAppendingOnes(tensor.getShape(), desiredRank));
    }
}
