package io.improbable.keanu.vertices.generic.nonprobabilistic;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.generic.GenericTensorVertex;

/* loaded from: input_file:io/improbable/keanu/vertices/generic/nonprobabilistic/IfVertex.class */
/**
 * A non-probabilistic vertex that combines a boolean predicate vertex with a
 * "then" vertex and an "else" vertex.
 *
 * <p>The value of this vertex is obtained by calling
 * {@link BooleanTensor#where} on the predicate's current value, passing the
 * current values of the "then" and "else" parents. The shape of this vertex is
 * whatever {@link TensorShapeValidation#checkTernaryConditionShapeIsValid}
 * returns for the three parent shapes.
 *
 * <p>The three parents are exposed through getters annotated with
 * {@link SaveVertexParam} and accepted by the constructor through
 * {@link LoadVertexParam}, using the same parameter names on both sides.
 *
 * @param <T> the element type of the tensors held by the "then"/"else" parents
 *            and by this vertex
 */
public class IfVertex<T> extends GenericTensorVertex<T> implements NonProbabilistic<GenericTensor<T>> {
    // Parameter names shared by the @LoadVertexParam and @SaveVertexParam
    // annotations below, so the save and load sides cannot drift apart.
    private static final String PREDICATE_NAME = "predicate";
    private static final String THEN_NAME = "then";
    private static final String ELSE_NAME = "else";

    private final BooleanVertex predicate;
    private final Vertex<GenericTensor<T>> thn;
    private final Vertex<GenericTensor<T>> els;

    /**
     * Creates the vertex and registers the three arguments as its parents.
     *
     * <p>The shapes of the three arguments are passed to
     * {@link TensorShapeValidation#checkTernaryConditionShapeIsValid}; its result
     * becomes the shape of this vertex.
     *
     * @param predicate the boolean vertex whose value drives the selection
     * @param thn       the vertex passed as the first argument to {@code where}
     * @param els       the vertex passed as the second argument to {@code where}
     */
    public IfVertex(@LoadVertexParam(PREDICATE_NAME) BooleanVertex predicate,
                    @LoadVertexParam(THEN_NAME) Vertex<GenericTensor<T>> thn,
                    @LoadVertexParam(ELSE_NAME) Vertex<GenericTensor<T>> els) {
        super(TensorShapeValidation.checkTernaryConditionShapeIsValid(predicate.getShape(), thn.getShape(), els.getShape()));
        this.predicate = predicate;
        this.thn = thn;
        this.els = els;
        setParents(predicate, thn, els);
    }

    /**
     * Applies the predicate tensor to the two candidate tensors.
     *
     * @param predicateValue the current value of the predicate parent
     * @param thenValue      the current value of the "then" parent
     * @param elseValue      the current value of the "else" parent
     * @return the result of {@code predicateValue.where(thenValue, elseValue)}
     */
    private GenericTensor<T> op(BooleanTensor predicateValue, GenericTensor<T> thenValue, GenericTensor<T> elseValue) {
        // The cast is kept because the declared return type of where(...) is not
        // visible from this file; it is parameterized rather than raw so the
        // result stays type-checked as GenericTensor<T>.
        return (GenericTensor<T>) predicateValue.where(thenValue, elseValue);
    }

    /**
     * Computes this vertex's value from the current values of its three parents.
     *
     * @return the tensor produced by applying the predicate's value to the
     *         "then" and "else" values
     */
    @Override
    public GenericTensor<T> calculate() {
        return op(this.predicate.getValue(), this.thn.getValue(), this.els.getValue());
    }

    /**
     * @return the predicate parent supplied at construction
     */
    @SaveVertexParam(PREDICATE_NAME)
    public BooleanVertex getPredicate() {
        return this.predicate;
    }

    /**
     * @return the "then" parent supplied at construction
     */
    @SaveVertexParam(THEN_NAME)
    public Vertex<GenericTensor<T>> getThn() {
        return this.thn;
    }

    /**
     * @return the "else" parent supplied at construction
     */
    @SaveVertexParam(ELSE_NAME)
    public Vertex<GenericTensor<T>> getEls() {
        return this.els;
    }
}
