/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.ternary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.dbl.DoubleVertex;

/**
 * Non-probabilistic double vertex with three parents: an operand, a mask and a
 * replacement value. Its value is obtained by calling {@code setWithMask} on
 * the operand's tensor with the mask's tensor and the replacement scalar.
 *
 * <p>NOTE(review): the element-wise meaning of {@code setWithMask} (which mask
 * entries trigger a replacement, and whether the operand tensor is modified in
 * place) is defined by the tensor implementation and is not visible here —
 * confirm against {@code DoubleTensor#setWithMask}.
 */
public class DoubleSetWithMaskVertex
        extends DoubleVertex
        implements NonProbabilistic<DoubleTensor> {

    // Parameter names shared by the load (constructor) and save (getter) annotations.
    private static final String OPERAND_NAME = "operand";
    private static final String MASK_NAME = "mask";
    private static final String SET_VALUE_NAME = "setValue";

    private final DoubleVertex operand;
    private final DoubleVertex mask;
    private final DoubleVertex setValue;

    /**
     * Builds the vertex and registers all three inputs as its parents.
     *
     * <p>The shape handed to the superclass is whatever
     * {@code TensorShapeValidation.checkAllShapesMatch} returns for the operand
     * and mask shapes; the shape of {@code setValue} is passed to
     * {@code TensorShapeValidation.checkTensorsAreScalar}. Both helpers
     * presumably throw when their check fails — TODO confirm the exception type.
     *
     * @param operand  vertex supplying the tensor that {@code setWithMask} is invoked on
     * @param mask     vertex supplying the mask tensor; its shape is checked against the operand's
     * @param setValue vertex supplying the replacement value; checked to be scalar
     */
    @ExportVertexToPythonBindings
    public DoubleSetWithMaskVertex(
            @LoadVertexParam(OPERAND_NAME) DoubleVertex operand,
            @LoadVertexParam(MASK_NAME) DoubleVertex mask,
            @LoadVertexParam(SET_VALUE_NAME) DoubleVertex setValue) {
        super(TensorShapeValidation.checkAllShapesMatch(operand.getShape(), mask.getShape()));

        final long[][] shapesRequiredToBeScalar = {setValue.getShape()};
        TensorShapeValidation.checkTensorsAreScalar("setValue must be scalar", shapesRequiredToBeScalar);

        this.operand = operand;
        this.mask = mask;
        this.setValue = setValue;
        setParents(operand, mask, setValue);
    }

    /**
     * Computes this vertex's value from the current values of its parents.
     *
     * @return the result of {@code setWithMask} applied to the operand's value,
     *         using the mask's value and the scalar read from {@code setValue}
     */
    @Override
    public DoubleTensor calculate() {
        // The explicit casts mirror the decompiled form of this class; the static
        // types handed to setWithMask are kept exactly as they were.
        final DoubleTensor operandValue = (DoubleTensor) this.operand.getValue();
        final NumberTensor maskValue = (NumberTensor) this.mask.getValue();
        final Number replacement = (Number) ((DoubleTensor) this.setValue.getValue()).scalar();
        return (DoubleTensor) operandValue.setWithMask(maskValue, replacement);
    }

    /**
     * @return the parent vertex whose value {@code setWithMask} is invoked on
     */
    @SaveVertexParam(OPERAND_NAME)
    public DoubleVertex getOperand() {
        return operand;
    }

    /**
     * @return the parent vertex providing the mask
     */
    @SaveVertexParam(MASK_NAME)
    public DoubleVertex getMask() {
        return mask;
    }

    /**
     * @return the parent vertex providing the scalar replacement value
     */
    @SaveVertexParam(SET_VALUE_NAME)
    public DoubleVertex getSetValue() {
        return setValue;
    }
}

