/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

/**
 * Element-wise selection between two tensors, driven by a condition tensor.
 *
 * <p>Inputs, in order: {@code condition} (index 0), {@code a} (index 1) and {@code b} (index 2).
 * The result has the type of {@code a}; {@code a} and {@code b} must have the same rank and the
 * same size at every dimension position.
 *
 * <p>In the general case the result is computed as {@code a * condition + b * (1 - condition)}:
 * cells where the condition is 1 come from {@code a}, cells where it is 0 come from {@code b}
 * (condition values other than 0/1 blend the two inputs). When the condition is a known constant
 * with a single cell, the choice is made at import time and the chosen input is returned as is.
 */
public class Select extends IntermediateOperation {

    /**
     * Join combinator computing {@code value * (1 - condition)}, i.e. keeping a cell of the
     * "false" input only where the condition is 0. Its {@code toString()} returns the equivalent
     * lambda expression text (presumably what is emitted when the function is rendered), so the
     * two must be kept in sync. It is stateless and declared in a static context so it does not
     * capture any enclosing {@code Select} instance.
     */
    private static final DoubleBinaryOperator MULTIPLY_BY_COMPLEMENT = new DoubleBinaryOperator() {
        @Override
        public double applyAsDouble(double value, double condition) {
            return value * (1.0 - condition);
        }

        @Override
        public String toString() {
            return "f(a,b)(a * (1-b))";
        }
    };

    public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
    }

    /**
     * Returns the type of input {@code a}, after verifying that {@code a} and {@code b} have the
     * same shape. Returns null while any of the three input types is still missing.
     *
     * @throws IllegalArgumentException if {@code a} and {@code b} differ in rank or in the size
     *                                  of any dimension position
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(3)) {
            return null;
        }
        OrderedTensorType a = inputs().get(1).type().get();
        OrderedTensorType b = inputs().get(2).type().get();
        if (!haveSameShape(a, b)) {
            throw new IllegalArgumentException("'Select': input tensors must have the same shape");
        }
        return a;
    }

    /**
     * Builds the tensor function selecting between {@code a} and {@code b}. Returns null while
     * any of the three input functions is still missing.
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (!allInputFunctionsPresent(3)) {
            return null;
        }
        IntermediateOperation conditionOperation = inputs().get(0);
        TensorFunction<Reference> a = inputs().get(1).function().get();
        TensorFunction<Reference> b = inputs().get(2).function().get();

        // Shortcut: a constant single-cell condition selects one whole input at import time.
        // The value is truncated to int, so any value in (-1, 1) counts as "false".
        if (conditionOperation.getConstantValue().isPresent()) {
            Tensor condition = conditionOperation.getConstantValue().get().asTensor();
            if (hasSingleCell(condition.type())) {
                double value = condition.type().rank() == 0
                        ? condition.asDouble()
                        : condition.cellIterator().next().getValue().doubleValue();
                return (int) value == 0 ? b : a;
            }
        }

        // General case: a * condition + b * (1 - condition).
        TensorFunction<Reference> conditionFunction = conditionOperation.function().get();
        TensorFunction<Reference> aCond = new Join<>(a, conditionFunction, ScalarFunctions.multiply());
        TensorFunction<Reference> bCond = new Join<>(b, conditionFunction, MULTIPLY_BY_COMPLEMENT);
        return new Join<>(aCond, bCond, ScalarFunctions.add());
    }

    /**
     * Requires the i-th dimension of {@code a} and the i-th dimension of {@code b} to be given the
     * same name, for every position both inputs have.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!allInputTypesPresent(3)) {
            return;
        }
        List<TensorType.Dimension> aDimensions = inputs().get(1).type().get().dimensions();
        List<TensorType.Dimension> bDimensions = inputs().get(2).type().get().dimensions();
        // Only the common prefix is constrained when ranks differ; the mismatch itself is reported
        // by lazyGetType rather than surfacing here as an IndexOutOfBoundsException.
        int commonRank = Math.min(aDimensions.size(), bDimensions.size());
        for (int i = 0; i < commonRank; i++) {
            renamer.addConstraint(aDimensions.get(i).name(),
                                  bDimensions.get(i).name(),
                                  DimensionRenamer.Constraint.equal(false),
                                  this);
        }
    }

    @Override
    public Select withInputs(List<IntermediateOperation> inputs) {
        return new Select(modelName(), name(), inputs);
    }

    @Override
    public String operationName() {
        return "Select";
    }

    /**
     * Returns whether the two types have the same rank and the same size at every dimension
     * position. Positional comparison matches how addDimensionNameConstraints pairs dimensions.
     */
    private static boolean haveSameShape(OrderedTensorType a, OrderedTensorType b) {
        List<TensorType.Dimension> aDimensions = a.dimensions();
        List<TensorType.Dimension> bDimensions = b.dimensions();
        if (aDimensions.size() != bDimensions.size()) {
            return false;
        }
        for (int i = 0; i < aDimensions.size(); i++) {
            long aSize = OrderedTensorType.dimensionSize(aDimensions.get(i));
            long bSize = OrderedTensorType.dimensionSize(bDimensions.get(i));
            if (aSize != bSize) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether every dimension of the type has a known size of exactly 1, i.e. the tensor
     * holds a single cell. Vacuously true for rank 0.
     */
    private static boolean hasSingleCell(TensorType type) {
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (!dimension.size().isPresent() || dimension.size().get() != 1L) {
                return false;
            }
        }
        return true;
    }
}

