/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.AutoDiffBroadcast;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DoubleBinaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * A differentiable vertex whose value is {@code base.pow(exponent)} of its two parent vertices.
 *
 * <p>Both forward- and reverse-mode auto-differentiation are supported. Partials are corrected for
 * broadcasting between the operand shapes and this vertex's shape via {@link AutoDiffBroadcast}.
 */
public class PowerVertex extends DoubleBinaryOpVertex implements Differentiable {
    private static final String BASE_NAME = "left";
    private static final String EXPONENT_NAME = "right";

    /**
     * @param base     the vertex raised to a power (stored as the left operand)
     * @param exponent the power to raise {@code base} to (stored as the right operand)
     */
    @ExportVertexToPythonBindings
    public PowerVertex(@LoadVertexParam(BASE_NAME) DoubleVertex base,
                       @LoadVertexParam(EXPONENT_NAME) DoubleVertex exponent) {
        super(base, exponent);
    }

    /** @return the base operand (the left operand of this binary op). */
    public DoubleVertex getBase() {
        return super.getLeft();
    }

    /** @return the exponent operand (the right operand of this binary op). */
    public DoubleVertex getExponent() {
        return super.getRight();
    }

    @Override
    protected DoubleTensor op(DoubleTensor base, DoubleTensor exponent) {
        return (DoubleTensor) base.pow(exponent);
    }

    /**
     * Propagates derivatives forward:
     * d(a^b) = (b * a^(b-1)) * da + (a^b * log(a)) * db.
     *
     * <p>If base and exponent are the same vertex, both terms are looked up from the same map
     * entry and summed, which gives the correct total derivative.
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        DoubleVertex base = getBase();
        DoubleVertex exponent = getExponent();

        PartialDerivative fromBase = AutoDiffBroadcast.correctForBroadcastPartialForward(
            derivativeOfParentsWithRespectToInput.getOrDefault(base, PartialDerivative.EMPTY),
            base.getShape(), this.getShape());
        PartialDerivative fromExponent = AutoDiffBroadcast.correctForBroadcastPartialForward(
            derivativeOfParentsWithRespectToInput.getOrDefault(exponent, PartialDerivative.EMPTY),
            exponent.getShape(), this.getShape());

        DoubleTensor baseValue = (DoubleTensor) base.getValue();

        // Only build a local-derivative tensor when there is an incoming partial to scale.
        PartialDerivative partialsFromBase = fromBase.isPresent()
            ? fromBase.multiplyAlongOfDimensions(dSelfWrtBase(baseValue, (DoubleTensor) exponent.getValue()))
            : PartialDerivative.EMPTY;
        PartialDerivative partialsFromExponent = fromExponent.isPresent()
            ? fromExponent.multiplyAlongOfDimensions(dSelfWrtExponent(baseValue, (DoubleTensor) this.getValue()))
            : PartialDerivative.EMPTY;

        return partialsFromBase.add(partialsFromExponent);
    }

    /**
     * Propagates the derivative of the output with respect to this vertex back to base and exponent.
     *
     * <p>If base and exponent are the same vertex (e.g. {@code x^x}), the two contributions are summed
     * into a single map entry rather than one overwriting the other.
     *
     * @return partial derivatives keyed by parent vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        DoubleVertex base = getBase();
        DoubleVertex exponent = getExponent();

        DoubleTensor baseValue = (DoubleTensor) base.getValue();
        DoubleTensor exponentValue = (DoubleTensor) exponent.getValue();
        DoubleTensor basePowExponent = (DoubleTensor) this.getValue();

        PartialDerivative dOutputsWrtBase =
            derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(dSelfWrtBase(baseValue, exponentValue));
        PartialDerivative dOutputsWrtExponent =
            derivativeOfOutputWithRespectToSelf.multiplyAlongWrtDimensions(dSelfWrtExponent(baseValue, basePowExponent));

        PartialDerivative toBase = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            dOutputsWrtBase, this.getShape(), base.getShape());
        PartialDerivative toExponent = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            dOutputsWrtExponent, this.getShape(), exponent.getShape());

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(base, toBase);
        // merge (not put): when exponent == base, a plain put would discard the base contribution.
        partials.merge(exponent, toExponent, PartialDerivative::add);
        return partials;
    }

    /** Local derivative of {@code a^b} with respect to the base: {@code b * a^(b-1)}. */
    private static DoubleTensor dSelfWrtBase(DoubleTensor baseValue, DoubleTensor exponentValue) {
        return (DoubleTensor) exponentValue.times(baseValue.pow(exponentValue.minus(1.0)));
    }

    /**
     * Local derivative of {@code a^b} with respect to the exponent: {@code a^b * log(a)}.
     *
     * <p>NOTE(review): presumably {@code log} is the elementwise natural log. For base values
     * {@code <= 0} this term is then NaN or infinite; confirm whether callers rely on that.
     */
    private static DoubleTensor dSelfWrtExponent(DoubleTensor baseValue, DoubleTensor basePowExponent) {
        return (DoubleTensor) basePowExponent.times(baseValue.log());
    }
}

