/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.Collections;
import java.util.Map;

/**
 * Double vertex whose value is the scalar determinant of its input vertex.
 * <p>
 * The input's shape must be a square matrix; this is validated at construction time.
 * Only reverse-mode auto-differentiation is supported. It applies Jacobi's formula
 * {@code d det(A) / dA = det(A) * (A^-1)^T}.
 */
public class MatrixDeterminantVertex
extends DoubleUnaryOpVertex
implements Differentiable {

    /**
     * Creates a vertex whose value is the determinant of {@code vertex}.
     *
     * @param vertex the input matrix vertex; its shape is checked with
     *               {@code TensorShapeValidation.checkShapeIsSquareMatrix}
     *               (presumably throws for non-square shapes — confirm in TensorShapeValidation)
     */
    @ExportVertexToPythonBindings
    public MatrixDeterminantVertex(@LoadVertexParam(value="inputVertex") DoubleVertex vertex) {
        // The output of a determinant is always a scalar, independent of the input size.
        super(Tensor.SCALAR_SHAPE, vertex);
        // Java requires super(...) to run first, so shape validation necessarily follows it.
        TensorShapeValidation.checkShapeIsSquareMatrix(vertex.getShape());
    }

    /**
     * Computes the determinant of the input matrix and wraps it as a scalar tensor.
     *
     * @param value the current value of the input vertex
     * @return a scalar tensor holding {@code det(value)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return DoubleTensor.scalar((Double)value.determinant());
    }

    /**
     * Forward-mode auto-differentiation is not implemented for this vertex.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        throw new UnsupportedOperationException(
            "Forward-mode auto-differentiation is not supported by MatrixDeterminantVertex; use reverse mode instead");
    }

    /**
     * Propagates a partial derivative back to the input matrix using Jacobi's formula
     * {@code d det(A) / dA = det(A) * (A^-1)^T}.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of some output with respect to this vertex
     * @return a single-entry map from the input vertex to its partial derivative
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        // Read the input value once and reuse it for both the determinant and the inverse.
        DoubleTensor inputValue = (DoubleTensor)this.inputVertex.getValue();

        PartialDerivative dOutputTimesDeterminant = derivativeOfOutputWithRespectToSelf.multiplyBy((Double)inputValue.determinant());

        // The partial w.r.t. the input spans [outputShape..., inputShape...]. The scaled partial is
        // padded with trailing size-1 dimensions to that rank, then broadcast across the input dims.
        long[] resultShape = TensorShape.concat(derivativeOfOutputWithRespectToSelf.get().getShape(), this.inputVertex.getShape());
        DoubleTensor reshapedPartial = MatrixDeterminantVertex.increaseRankByAppendingOnesToShape(dOutputTimesDeterminant.get(), resultShape.length);
        DoubleTensor broadcastedPartial = (DoubleTensor)reshapedPartial.broadcast(resultShape);

        // (A^T)^-1 equals (A^-1)^T, i.e. the inverse-transpose term of Jacobi's formula.
        DoubleTensor inverseTranspose = (DoubleTensor)((DoubleTensor)inputValue.transpose()).matrixInverse();
        // NOTE(review): assumes multiplyAlongWrtDimensions does an elementwise product over the
        // trailing (wrt) dimensions — confirm against PartialDerivative.
        PartialDerivative toInput = new PartialDerivative(broadcastedPartial).multiplyAlongWrtDimensions(inverseTranspose);
        return Collections.singletonMap(this.inputVertex, toInput);
    }

    /**
     * Reshapes {@code lowRankTensor} to {@code desiredRank} by appending size-1 dimensions.
     *
     * @param lowRankTensor tensor to reshape
     * @param desiredRank   rank of the returned tensor
     * @return the reshaped tensor
     */
    private static DoubleTensor increaseRankByAppendingOnesToShape(DoubleTensor lowRankTensor, int desiredRank) {
        return (DoubleTensor)lowRankTensor.reshape(TensorShape.shapeDesiredToRankByAppendingOnes(lowRankTensor.getShape(), desiredRank));
    }
}

