/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.AutoDiffBroadcast;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DoubleBinaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * Differentiable vertex for the two-argument arctangent of its parents.
 *
 * <p>The left parent is treated as {@code x} and the right parent as {@code y}. The value is
 * produced by {@code x.atan2(y)}. The gradients implemented below, {@code -y / (x^2 + y^2)} with
 * respect to {@code x} and {@code x / (x^2 + y^2)} with respect to {@code y}, are those of
 * {@code atan2(y, x)}.
 */
public class ArcTan2Vertex
extends DoubleBinaryOpVertex
implements Differentiable {
    private static final String X_NAME = "left";
    private static final String Y_NAME = "right";

    /**
     * Creates the vertex.
     *
     * @param x the vertex supplying the x coordinate (stored as the left operand)
     * @param y the vertex supplying the y coordinate (stored as the right operand)
     */
    @ExportVertexToPythonBindings
    public ArcTan2Vertex(@LoadVertexParam(value=X_NAME) DoubleVertex x, @LoadVertexParam(value=Y_NAME) DoubleVertex y) {
        super(x, y);
    }

    /**
     * Applies the arctangent operation to concrete parent values.
     *
     * @param x value of the left (x) parent
     * @param y value of the right (y) parent
     * @return the result of {@code x.atan2(y)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor x, DoubleTensor y) {
        return x.atan2(y);
    }

    /**
     * Propagates derivatives forward from the parents to this vertex.
     *
     * <p>Each parent's partial is first broadcast-corrected to this vertex's shape. It is then
     * scaled element-wise by the local derivative of the output with respect to that parent, and
     * the two contributions are summed.
     *
     * @param derivativeOfParentsWithRespectToInput partials of the parents with respect to the
     *     input. A missing parent is treated as {@link PartialDerivative#EMPTY}.
     * @return the partial of this vertex with respect to the input
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative dxWrtInput = derivativeOfParentsWithRespectToInput.getOrDefault(this.left, PartialDerivative.EMPTY);
        PartialDerivative dyWrtInput = derivativeOfParentsWithRespectToInput.getOrDefault(this.right, PartialDerivative.EMPTY);

        DoubleTensor[] local = localPartials();

        PartialDerivative fromX = AutoDiffBroadcast.correctForBroadcastPartialForward(dxWrtInput, this.left.getShape(), this.getShape());
        PartialDerivative fromY = AutoDiffBroadcast.correctForBroadcastPartialForward(dyWrtInput, this.right.getShape(), this.getShape());

        PartialDerivative diffFromX = fromX.multiplyAlongOfDimensions(local[0]);
        PartialDerivative diffFromY = fromY.multiplyAlongOfDimensions(local[1]);
        return diffFromX.add(diffFromY);
    }

    /**
     * Propagates the derivative of some output back from this vertex to its two parents.
     *
     * <p>The incoming partial is scaled by each local derivative along the "wrt" dimensions. It is
     * then broadcast-corrected down to the parent's own shape.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return a fresh map with one entry for the left parent and one for the right parent
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        DoubleTensor[] local = localPartials();

        PartialDerivative dOutputsWrtLeft = derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(local[0]);
        PartialDerivative dOutputsWrtRight = derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(local[1]);

        PartialDerivative toLeft = AutoDiffBroadcast.correctForBroadcastPartialReverse(dOutputsWrtLeft, this.getShape(), this.left.getShape());
        PartialDerivative toRight = AutoDiffBroadcast.correctForBroadcastPartialReverse(dOutputsWrtRight, this.getShape(), this.right.getShape());

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.left, toLeft);
        partials.put(this.right, toRight);
        return partials;
    }

    /**
     * Computes the local derivatives of atan2(y, x) at the parents' current values.
     *
     * <p>This is shared by forward and reverse mode so that the two cannot drift apart.
     *
     * @return a two-element array {@code {-y / (x^2 + y^2), x / (x^2 + y^2)}}, i.e. the partials
     *     with respect to the left (x) and right (y) parents, in that order
     */
    private DoubleTensor[] localPartials() {
        DoubleTensor xValue = (DoubleTensor)this.left.getValue();
        DoubleTensor yValue = (DoubleTensor)this.right.getValue();
        // x^2 + y^2. NOTE(review): relies on pow(...) returning a new tensor, so plusInPlace only
        // mutates that temporary and not the parent's stored value. Confirm against DoubleTensor.
        DoubleTensor denominator = yValue.pow(2.0).plusInPlace(xValue.pow(2.0));
        DoubleTensor dOutWrtX = (DoubleTensor)yValue.div(denominator).unaryMinusInPlace();
        DoubleTensor dOutWrtY = xValue.div(denominator);
        return new DoubleTensor[]{dOutWrtX, dOutWrtY};
    }
}

