/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * A non-probabilistic {@link DoubleVertex} whose value is produced by applying a
 * caller-supplied binary function to the values of two parent vertices.
 *
 * <p>Automatic differentiation is optional: the forward- and reverse-mode behaviour is
 * delegated to caller-supplied lambdas. When a lambda was not supplied (i.e. it is
 * {@code null}), the corresponding differentiation method throws
 * {@link UnsupportedOperationException}.
 *
 * @param <A> value type of the left parent vertex
 * @param <B> value type of the right parent vertex
 */
public class DoubleBinaryOpLambda<A, B> extends DoubleVertex
        implements Differentiable, NonProbabilistic<DoubleTensor>, NonSaveableVertex {

    protected final Vertex<A> left;
    protected final Vertex<B> right;
    protected final BiFunction<A, B, DoubleTensor> op;
    /** Forward-mode autodiff implementation; {@code null} when forward mode is not supported. */
    protected final Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda;
    /** Reverse-mode autodiff implementation; {@code null} when reverse mode is not supported. */
    protected final Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda;

    /**
     * Creates the vertex with an explicit shape and registers {@code left} and {@code right}
     * as its parents.
     *
     * @param shape                     shape passed to the {@link DoubleVertex} constructor
     * @param left                      left parent vertex
     * @param right                     right parent vertex
     * @param op                        function applied to the parents' values in {@link #calculate()}
     * @param forwardModeAutoDiffLambda forward-mode autodiff implementation, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda reverse-mode autodiff implementation, or {@code null} if unsupported
     */
    public DoubleBinaryOpLambda(long[] shape,
                                Vertex<A> left,
                                Vertex<B> right,
                                BiFunction<A, B, DoubleTensor> op,
                                Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda,
                                Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
        super(shape);
        this.left = left;
        this.right = right;
        this.op = op;
        this.forwardModeAutoDiffLambda = forwardModeAutoDiffLambda;
        this.reverseModeAutoDiffLambda = reverseModeAutoDiffLambda;
        this.setParents(left, right);
    }

    /**
     * Creates the vertex with an explicit shape and no autodiff support in either mode.
     *
     * @param shape shape passed to the {@link DoubleVertex} constructor
     * @param left  left parent vertex
     * @param right right parent vertex
     * @param op    function applied to the parents' values in {@link #calculate()}
     */
    public DoubleBinaryOpLambda(long[] shape, Vertex<A> left, Vertex<B> right, BiFunction<A, B, DoubleTensor> op) {
        this(shape, left, right, op, null, null);
    }

    /**
     * Creates the vertex using the shape returned by
     * {@code TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne} for the two
     * parent shapes.
     *
     * @param left                      left parent vertex
     * @param right                     right parent vertex
     * @param op                        function applied to the parents' values in {@link #calculate()}
     * @param forwardModeAutoDiffLambda forward-mode autodiff implementation, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda reverse-mode autodiff implementation, or {@code null} if unsupported
     */
    public DoubleBinaryOpLambda(Vertex<A> left,
                                Vertex<B> right,
                                BiFunction<A, B, DoubleTensor> op,
                                Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda,
                                Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(left.getShape(), right.getShape()), left, right, op, forwardModeAutoDiffLambda, reverseModeAutoDiffLambda);
    }

    /**
     * Creates the vertex using the shape returned by
     * {@code TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne} for the two
     * parent shapes, with no autodiff support in either mode.
     *
     * @param left  left parent vertex
     * @param right right parent vertex
     * @param op    function applied to the parents' values in {@link #calculate()}
     */
    public DoubleBinaryOpLambda(Vertex<A> left, Vertex<B> right, BiFunction<A, B, DoubleTensor> op) {
        this(TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(left.getShape(), right.getShape()), left, right, op, null, null);
    }

    /**
     * Applies {@code op} to the current values of the left and right parents.
     *
     * @return the tensor returned by {@code op}
     */
    @Override
    public DoubleTensor calculate() {
        return this.op.apply(this.left.getValue(), this.right.getValue());
    }

    /**
     * Delegates forward-mode differentiation to the lambda supplied at construction.
     *
     * @param derivativeOfParentsWithRespectToInput argument forwarded unchanged to the lambda
     * @return the result of the forward-mode lambda
     * @throws UnsupportedOperationException if no forward-mode lambda was supplied
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        if (this.forwardModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException(
                "Forward mode auto differentiation is not supported: no forward mode lambda was provided");
        }
        return this.forwardModeAutoDiffLambda.apply(derivativeOfParentsWithRespectToInput);
    }

    /**
     * Delegates reverse-mode differentiation to the lambda supplied at construction.
     *
     * @param derivativeOfOutputWithRespectToSelf argument forwarded unchanged to the lambda
     * @return the result of the reverse-mode lambda
     * @throws UnsupportedOperationException if no reverse-mode lambda was supplied
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        // Two constructors pass null for this lambda; fail the same way forward mode does
        // rather than surfacing a bare NullPointerException.
        if (this.reverseModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException(
                "Reverse mode auto differentiation is not supported: no reverse mode lambda was provided");
        }
        return this.reverseModeAutoDiffLambda.apply(derivativeOfOutputWithRespectToSelf);
    }
}

