package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.ternary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.dbl.DoubleVertex;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/ternary/DoubleSetWithMaskVertex.class */
/**
 * Non-probabilistic vertex whose value is obtained by calling
 * {@code setWithMask} on the value of {@code operand}, passing the value of
 * {@code mask} and the single scalar held by {@code setValue}.
 *
 * <p>The exact element-selection rule is defined by the tensor's
 * {@code setWithMask} implementation and is not restated here.
 */
public class DoubleSetWithMaskVertex extends DoubleVertex implements NonProbabilistic<DoubleTensor> {
    // Names under which the three parents are saved and loaded; shared by the
    // constructor and getter annotations so the two sides cannot drift apart.
    private static final String OPERAND_NAME = "operand";
    private static final String MASK_NAME = "mask";
    private static final String SET_VALUE_NAME = "setValue";

    private final DoubleVertex operand;
    private final DoubleVertex mask;
    private final DoubleVertex setValue;

    /**
     * Creates the vertex and registers all three arguments as its parents.
     *
     * <p>The shapes of {@code operand} and {@code mask} are passed to
     * {@link TensorShapeValidation#checkAllShapesMatch}, and its result becomes
     * this vertex's shape. The shape of {@code setValue} is passed to
     * {@link TensorShapeValidation#checkTensorsAreScalar}; a validation failure
     * is reported by those helpers (presumably by throwing — confirm against
     * {@code TensorShapeValidation}).
     *
     * @param operand  vertex whose value {@code setWithMask} is invoked on
     * @param mask     vertex whose value is passed as the mask
     * @param setValue scalar vertex supplying the value passed to {@code setWithMask}
     */
    @ExportVertexToPythonBindings
    public DoubleSetWithMaskVertex(@LoadVertexParam(OPERAND_NAME) DoubleVertex operand,
                                   @LoadVertexParam(MASK_NAME) DoubleVertex mask,
                                   @LoadVertexParam(SET_VALUE_NAME) DoubleVertex setValue) {
        // Explicit long[][] so each shape is one element of the shape list,
        // whether the helper is declared with varargs or a plain array.
        super(TensorShapeValidation.checkAllShapesMatch(new long[][]{operand.getShape(), mask.getShape()}));
        TensorShapeValidation.checkTensorsAreScalar("setValue must be scalar", new long[][]{setValue.getShape()});
        this.operand = operand;
        this.mask = mask;
        this.setValue = setValue;
        setParents(operand, mask, setValue);
    }

    /**
     * Computes this vertex's value from the current values of its parents.
     *
     * @return the result of {@code operand.getValue().setWithMask(mask value, setValue scalar)}
     */
    @Override
    public DoubleTensor calculate() {
        return operand.getValue().setWithMask(mask.getValue(), setValue.getValue().scalar());
    }

    /**
     * @return the operand parent, saved under the name {@code "operand"}
     */
    @SaveVertexParam(OPERAND_NAME)
    public DoubleVertex getOperand() {
        return operand;
    }

    /**
     * @return the mask parent, saved under the name {@code "mask"}
     */
    @SaveVertexParam(MASK_NAME)
    public DoubleVertex getMask() {
        return mask;
    }

    /**
     * @return the scalar parent, saved under the name {@code "setValue"}
     */
    @SaveVertexParam(SET_VALUE_NAME)
    public DoubleVertex getSetValue() {
        return setValue;
    }
}
