/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;

import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;

/**
 * A two-input conditional routing operation from the ML model importer.
 *
 * <p>Input 0 is the data being routed and input 1 is the predicate. This node represents a single
 * output {@code port} (0 or 1). It exposes the data input's type. It exposes the data input's
 * function only when the constant predicate value equals {@code port}; otherwise it has no function.
 * NOTE(review): this appears to model a graph-level "Switch" control-flow op — confirm against the importer.
 */
public class Switch extends IntermediateOperation {

    /** The output port this node represents; only 0 or 1 is accepted (validated in {@link #lazyGetFunction()}). */
    private final int port;

    /**
     * Creates a switch node for one output port.
     *
     * @param modelName the name of the model this operation belongs to
     * @param nodeName  the name of this node
     * @param inputs    the inputs; index 0 is read as the data and index 1 as the predicate
     * @param port      the output port (expected 0 or 1) this node represents
     */
    public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
        super(modelName, nodeName, inputs);
        this.port = port;
    }

    /**
     * Resolves the output type, which is the type of the data input.
     *
     * @return the type of input 0, or null if the two input types are not (yet) available
     * @throws IllegalArgumentException if the predicate input is not a scalar (rank 0)
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! this.allInputTypesPresent(2)) {
            return null;
        }
        // Presumably allInputTypesPresent(2) guarantees both input types are present, making get() safe
        // here — confirm in IntermediateOperation.
        Optional<OrderedTensorType> predicate = this.inputs().get(1).type();
        if (predicate.get().type().rank() != 0) {
            throw new IllegalArgumentException("Switch in " + this.name + ": predicate must be a scalar");
        }
        return this.inputs().get(0).type().orElse(null);
    }

    /**
     * Selects the data input's function if this port is the one chosen by the predicate.
     *
     * @return the function of input 0 when the constant predicate equals {@code port};
     *         null when this port is not selected, or when input 0 has no function
     * @throws IllegalArgumentException if the predicate is not a constant, or if {@code port} is not 0 or 1
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        IntermediateOperation predicateOperation = this.inputs().get(1);
        if ( ! predicateOperation.getConstantValue().isPresent()) {
            throw new IllegalArgumentException("Switch in " + this.name + ": predicate must be a constant");
        }
        if (this.port < 0 || this.port > 1) {
            throw new IllegalArgumentException("Switch in " + this.name + ": choice should be boolean");
        }
        double predicate = predicateOperation.getConstantValue().get().asDouble();
        if (predicate != this.port) {
            return null; // this port is not the branch selected by the predicate
        }
        // Propagate absence instead of throwing a context-free NoSuchElementException via an unchecked get():
        // null already means "no function" for the non-selected branch above.
        return this.inputs().get(0).function().orElse(null);
    }
}

