package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.AutoDiffBroadcast;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/binary/PowerVertex.class */
/**
 * Element-wise power vertex whose value is {@code base ^ exponent}, where the base is the left
 * operand and the exponent is the right operand.
 *
 * <p>Local derivatives used by both auto-differentiation modes (element-wise):
 * <ul>
 *   <li>{@code d(b^e)/db = e * b^(e - 1)}</li>
 *   <li>{@code d(b^e)/de = b^e * log(b)}</li>
 * </ul>
 *
 * <p>NOTE(review): {@code log2()} and {@code minus2(double)} look like decompiler-renamed overloads
 * of {@code log()} / {@code minus(double)}. The exponent derivative above is only correct if
 * {@code log2()} is the natural logarithm — confirm against the tensor API.
 */
public class PowerVertex extends DoubleBinaryOpVertex implements Differentiable {
    private static final String BASE_NAME = "left";
    private static final String EXPONENT_NAME = "right";

    /**
     * Creates a vertex computing {@code doubleVertex ^ doubleVertex2} element-wise.
     *
     * @param doubleVertex  the base (left operand)
     * @param doubleVertex2 the exponent (right operand)
     */
    @ExportVertexToPythonBindings
    public PowerVertex(@LoadVertexParam(BASE_NAME) DoubleVertex doubleVertex,
                       @LoadVertexParam(EXPONENT_NAME) DoubleVertex doubleVertex2) {
        super(doubleVertex, doubleVertex2);
    }

    /** Returns the base (left operand) of the power operation. */
    public DoubleVertex getBase() {
        return super.getLeft();
    }

    /** Returns the exponent (right operand) of the power operation. */
    public DoubleVertex getExponent() {
        return super.getRight();
    }

    /** Computes {@code doubleTensor ^ doubleTensor2} element-wise. */
    @Override
    protected DoubleTensor op(DoubleTensor doubleTensor, DoubleTensor doubleTensor2) {
        return (DoubleTensor) doubleTensor.pow(doubleTensor2);
    }

    /**
     * Forward-mode AD: {@code d(b^e) = db * e * b^(e-1) + de * b^e * log(b)}.
     *
     * @param map partials of each parent with respect to the input; missing parents are treated
     *            as {@link PartialDerivative#EMPTY}
     * @return the partial of this vertex with respect to the input
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> map) {
        PartialDerivative dBaseWrtInput = AutoDiffBroadcast.correctForBroadcastPartialForward(
            map.getOrDefault(getBase(), PartialDerivative.EMPTY), getBase().getShape(), getShape());
        PartialDerivative dExponentWrtInput = AutoDiffBroadcast.correctForBroadcastPartialForward(
            map.getOrDefault(getExponent(), PartialDerivative.EMPTY), getExponent().getShape(), getShape());

        // Only build the local derivative tensors for operands that actually carry a partial.
        PartialDerivative fromBase = dBaseWrtInput.isPresent()
            ? dBaseWrtInput.multiplyAlongOfDimensions(dPowerWrtBase())
            : PartialDerivative.EMPTY;
        PartialDerivative fromExponent = dExponentWrtInput.isPresent()
            ? dExponentWrtInput.multiplyAlongOfDimensions(dPowerWrtExponent())
            : PartialDerivative.EMPTY;

        return fromBase.add(fromExponent);
    }

    /**
     * Reverse-mode AD: propagates the partial of the output with respect to this vertex back to
     * the base and exponent, correcting for broadcasting.
     *
     * @param partialDerivative partial of the output with respect to this vertex
     * @return map from each operand vertex to the partial of the output with respect to it
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative partialDerivative) {
        PartialDerivative dOutputWrtBase = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            partialDerivative.multiplyAlongWrtDimensions(dPowerWrtBase()), getShape(), getBase().getShape());
        PartialDerivative dOutputWrtExponent = AutoDiffBroadcast.correctForBroadcastPartialReverse(
            partialDerivative.multiplyAlongWrtDimensions(dPowerWrtExponent()), getShape(), getExponent().getShape());

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(getBase(), dOutputWrtBase);
        // merge rather than put: when base and exponent are the same vertex (e.g. x^x) both
        // contributions must be summed; a plain put would discard the base term.
        partials.merge(getExponent(), dOutputWrtExponent, PartialDerivative::add);
        return partials;
    }

    /** Element-wise {@code d(b^e)/db = e * b^(e - 1)} at the operands' current values. */
    private DoubleTensor dPowerWrtBase() {
        DoubleTensor base = getBase().getValue();
        DoubleTensor exponent = getExponent().getValue();
        return (DoubleTensor) exponent.times((DoubleTensor) base.pow(exponent.minus2(1.0d)));
    }

    /** Element-wise {@code d(b^e)/de = b^e * log(b)} at the current value of this vertex. */
    private DoubleTensor dPowerWrtExponent() {
        return (DoubleTensor) getValue().times(getBase().getValue().log2());
    }
}
