/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;
import org.tensorflow.framework.NodeDef;

/**
 * Importer operation for a TensorFlow {@code Switch} node.
 *
 * <p>Input 0 is the data being routed and input 1 is the predicate. The {@code port} this
 * operation was created for selects which branch it represents. Port 0 is produced when the
 * predicate's constant value equals 0, and port 1 when it equals 1.
 */
public class Switch extends TensorFlowOperation {

    public Switch(NodeDef node, List<TensorFlowOperation> inputs, int port) {
        super(node, inputs, port);
    }

    /**
     * Returns the type of the data input (input 0), after verifying that the predicate
     * (input 1) is a scalar.
     *
     * @return the type of input 0, or {@code null} if the types of the first two inputs are
     *         not available (or input 0 has no type)
     * @throws IllegalArgumentException if the predicate's type has a rank other than 0
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        // NOTE(review): relies on allInputTypesPresent(2) guaranteeing input 1 has a type.
        OrderedTensorType predicateType = inputs.get(1).type().get();
        boolean predicateIsScalar = predicateType.type().rank() == 0;
        if (!predicateIsScalar) {
            throw new IllegalArgumentException("Switch in " + node.getName() + ": predicate must be a scalar");
        }
        return inputs.get(0).type().orElse(null);
    }

    /**
     * Resolves the branch statically using the predicate's constant value.
     *
     * @return the function of input 0 if the predicate value equals this operation's port,
     *         otherwise {@code null} (presumably meaning "this output is not taken"; confirm
     *         against TensorFlowOperation)
     * @throws IllegalArgumentException if the predicate is not a constant, or if the port is
     *         neither 0 nor 1
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        TensorFlowOperation predicateInput = inputs().get(1);
        boolean predicateIsConstant = predicateInput.getConstantValue().isPresent();
        if (!predicateIsConstant) {
            throw new IllegalArgumentException("Switch in " + node.getName() + ": predicate must be a constant");
        }
        boolean portIsBoolean = port >= 0 && port <= 1;
        if (!portIsBoolean) {
            throw new IllegalArgumentException("Switch in " + node.getName() + ": choice should be boolean");
        }
        double predicateValue = predicateInput.getConstantValue().get().asDouble();
        if (predicateValue != port) {
            // This output port is not the one the predicate selects.
            return null;
        }
        return inputs().get(0).function().get();
    }
}

