package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/unary/DoubleUnaryOpLambda.class */
/**
 * A non-probabilistic {@link DoubleVertex} whose value is computed by applying a caller-supplied
 * function to the value of a single input vertex.
 *
 * <p>Automatic differentiation is delegated to optional caller-supplied lambdas. If the lambda for a
 * given mode (forward or reverse) was not supplied, i.e. is {@code null}, invoking that mode throws
 * an {@link UnsupportedOperationException}.
 *
 * @param <IN> the value type of the input vertex
 */
public class DoubleUnaryOpLambda<IN> extends DoubleVertex implements Differentiable, NonProbabilistic<DoubleTensor>, NonSaveableVertex {

    /** The sole parent of this vertex; its value is the argument to {@link #op}. */
    private final Vertex<IN> inputVertex;

    /** Maps the input vertex's value to this vertex's value. */
    private final Function<IN, DoubleTensor> op;

    /** Forward-mode derivative function; {@code null} means forward mode is unsupported. */
    private final Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda;

    /** Reverse-mode derivative function; {@code null} means reverse mode is unsupported. */
    private final Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda;

    /**
     * Creates a vertex with an explicit shape and optional autodiff lambdas.
     *
     * @param shape                     the shape passed to the {@link DoubleVertex} super constructor
     * @param inputVertex               the single parent vertex whose value is fed to {@code op}
     * @param op                        function computing this vertex's value from the input value
     * @param forwardModeAutoDiffLambda forward-mode derivative function, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda reverse-mode derivative function, or {@code null} if unsupported
     * @throws NullPointerException if {@code inputVertex} or {@code op} is {@code null}
     */
    public DoubleUnaryOpLambda(long[] shape,
                               Vertex<IN> inputVertex,
                               Function<IN, DoubleTensor> op,
                               Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda,
                               Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
        super(shape);
        // Fail fast: a null op or input would otherwise only surface later as an opaque NPE in calculate().
        this.inputVertex = Objects.requireNonNull(inputVertex, "inputVertex");
        this.op = Objects.requireNonNull(op, "op");
        this.forwardModeAutoDiffLambda = forwardModeAutoDiffLambda;
        this.reverseModeAutoDiffLambda = reverseModeAutoDiffLambda;
        setParents(this.inputVertex);
    }

    /**
     * Creates a vertex with an explicit shape and no autodiff support.
     *
     * @param shape       the shape passed to the {@link DoubleVertex} super constructor
     * @param inputVertex the single parent vertex whose value is fed to {@code op}
     * @param op          function computing this vertex's value from the input value
     * @throws NullPointerException if {@code inputVertex} or {@code op} is {@code null}
     */
    public DoubleUnaryOpLambda(long[] shape, Vertex<IN> inputVertex, Function<IN, DoubleTensor> op) {
        this(shape, inputVertex, op, null, null);
    }

    /**
     * Creates a vertex that takes its shape from {@code inputVertex}, with optional autodiff lambdas.
     *
     * @param inputVertex               the single parent vertex; its shape is reused for this vertex
     * @param op                        function computing this vertex's value from the input value
     * @param forwardModeAutoDiffLambda forward-mode derivative function, or {@code null} if unsupported
     * @param reverseModeAutoDiffLambda reverse-mode derivative function, or {@code null} if unsupported
     * @throws NullPointerException if {@code inputVertex} or {@code op} is {@code null}
     */
    public DoubleUnaryOpLambda(Vertex<IN> inputVertex,
                               Function<IN, DoubleTensor> op,
                               Function<Map<Vertex, PartialDerivative>, PartialDerivative> forwardModeAutoDiffLambda,
                               Function<PartialDerivative, Map<Vertex, PartialDerivative>> reverseModeAutoDiffLambda) {
        this(inputVertex.getShape(), inputVertex, op, forwardModeAutoDiffLambda, reverseModeAutoDiffLambda);
    }

    /**
     * Creates a vertex that takes its shape from {@code inputVertex}, with no autodiff support.
     *
     * @param inputVertex the single parent vertex; its shape is reused for this vertex
     * @param op          function computing this vertex's value from the input value
     * @throws NullPointerException if {@code inputVertex} or {@code op} is {@code null}
     */
    public DoubleUnaryOpLambda(Vertex<IN> inputVertex, Function<IN, DoubleTensor> op) {
        this(inputVertex, op, null, null);
    }

    /**
     * Computes this vertex's value by applying {@code op} to the input vertex's current value.
     *
     * @return the result of {@code op}
     */
    @Override
    public DoubleTensor calculate() {
        return op.apply(inputVertex.getValue());
    }

    /**
     * Delegates forward-mode autodiff to the supplied lambda.
     *
     * @param derivativeOfParentsWithRespectToInput partial derivatives passed straight to the lambda
     *                                              (presumably keyed by parent vertex; confirm against callers)
     * @return the lambda's result
     * @throws UnsupportedOperationException if no forward-mode lambda was supplied
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        if (forwardModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException(
                "Forward mode automatic differentiation is not supported: no forward-mode lambda was supplied to "
                    + getClass().getSimpleName());
        }
        return forwardModeAutoDiffLambda.apply(derivativeOfParentsWithRespectToInput);
    }

    /**
     * Delegates reverse-mode autodiff to the supplied lambda.
     *
     * @param derivativeOfOutputWithRespectToSelf partial derivative passed straight to the lambda
     * @return the lambda's result
     * @throws UnsupportedOperationException if no reverse-mode lambda was supplied
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        if (reverseModeAutoDiffLambda == null) {
            throw new UnsupportedOperationException(
                "Reverse mode automatic differentiation is not supported: no reverse-mode lambda was supplied to "
                    + getClass().getSimpleName());
        }
        return reverseModeAutoDiffLambda.apply(derivativeOfOutputWithRespectToSelf);
    }
}
