/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.generic.nonprobabilistic;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.generic.GenericTensorVertex;

/**
 * Non-probabilistic vertex whose value is produced by {@link BooleanTensor#where} applied to the
 * value of a boolean predicate vertex and two candidate generic-tensor vertices.
 *
 * <p>The three inputs are registered as this vertex's parents. The vertex's shape comes from
 * {@link TensorShapeValidation#checkTernaryConditionShapeIsValid}, which is called with the
 * predicate, "then" and "else" shapes.
 *
 * @param <T> element type of the generic tensors being selected between
 */
public class IfVertex<T>
extends GenericTensorVertex<T>
implements NonProbabilistic<GenericTensor<T>> {
    // Parameter names shared by the load (constructor) and save (getter) annotations. Referencing
    // these constants on both sides keeps serialization and deserialization in sync.
    private static final String PREDICATE_NAME = "predicate";
    private static final String THEN_NAME = "then";
    private static final String ELSE_NAME = "else";

    private final BooleanVertex predicate;
    private final Vertex<GenericTensor<T>> thn;
    private final Vertex<GenericTensor<T>> els;

    /**
     * Creates a vertex that chooses between {@code thn} and {@code els} according to
     * {@code predicate}.
     *
     * @param predicate boolean vertex controlling the selection
     * @param thn vertex supplying values used where the predicate holds
     * @param els vertex supplying values used where the predicate does not hold
     */
    public IfVertex(@LoadVertexParam(value=PREDICATE_NAME) BooleanVertex predicate, @LoadVertexParam(value=THEN_NAME) Vertex<GenericTensor<T>> thn, @LoadVertexParam(value=ELSE_NAME) Vertex<GenericTensor<T>> els) {
        // The shape check runs before any field is assigned; its result becomes this vertex's shape.
        super(TensorShapeValidation.checkTernaryConditionShapeIsValid(predicate.getShape(), thn.getShape(), els.getShape()));
        this.predicate = predicate;
        this.thn = thn;
        this.els = els;
        this.setParents(predicate, thn, els);
    }

    /**
     * Selects between two tensors using a boolean tensor.
     *
     * @param condition selector tensor
     * @param whenTrue tensor passed as the first argument to {@code where}
     * @param whenFalse tensor passed as the second argument to {@code where}
     * @return the result of {@code condition.where(whenTrue, whenFalse)}
     */
    private GenericTensor<T> op(BooleanTensor condition, GenericTensor<T> whenTrue, GenericTensor<T> whenFalse) {
        return condition.where(whenTrue, whenFalse);
    }

    /**
     * Computes this vertex's value from the current values of its three parents.
     *
     * @return the predicate's value {@code where}-combined with the "then" and "else" values
     */
    @Override
    public GenericTensor<T> calculate() {
        // Cast retained from the decompiled original; the declared return type of
        // BooleanVertex.getValue() is not visible here. NOTE(review): may be redundant.
        return this.op((BooleanTensor)this.predicate.getValue(), this.thn.getValue(), this.els.getValue());
    }

    /** @return the boolean vertex controlling the selection */
    @SaveVertexParam(value=PREDICATE_NAME)
    public BooleanVertex getPredicate() {
        return this.predicate;
    }

    /** @return the vertex supplying values where the predicate holds */
    @SaveVertexParam(value=THEN_NAME)
    public Vertex<GenericTensor<T>> getThn() {
        return this.thn;
    }

    /** @return the vertex supplying values where the predicate does not hold */
    @SaveVertexParam(value=ELSE_NAME)
    public Vertex<GenericTensor<T>> getEls() {
        return this.els;
    }
}

