/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import org.tensorflow.framework.NodeDef;

/**
 * Imports the TensorFlow {@code Select} operation: an element-wise choice between two
 * inputs driven by a condition input.
 *
 * <p>Inputs are, by position: 0 = condition, 1 = value used where the condition is true (1),
 * 2 = value used where the condition is false (0).
 */
public class Select
extends TensorFlowOperation {

    /**
     * Join combinator {@code a * (1 - b)}, where {@code b} is the condition cell. It passes the
     * "false" branch value through where the condition is 0 and zeroes it where the condition
     * is 1. It is declared in a static context so it holds no reference to an enclosing
     * {@code Select}. Its {@code toString} returns the textual lambda {@code f(a,b)(a * (1-b))}.
     */
    private static final DoubleBinaryOperator KEEP_WHERE_CONDITION_FALSE = new DoubleBinaryOperator() {

        @Override
        public double applyAsDouble(double a, double b) {
            return a * (1.0 - b);
        }

        @Override
        public String toString() {
            return "f(a,b)(a * (1-b))";
        }
    };

    /**
     * Creates the operation; all arguments are passed unchanged to {@link TensorFlowOperation}.
     *
     * @param node   the TensorFlow graph node being imported
     * @param inputs the operations producing this node's inputs (condition, true-branch, false-branch)
     * @param port   passed through to the superclass constructor
     */
    public Select(NodeDef node, List<TensorFlowOperation> inputs, int port) {
        super(node, inputs, port);
    }

    /**
     * Returns the result type, which is the type of input 1 (the true-branch value).
     *
     * <p>NOTE(review): the "same shape" check compares only rank and total element count,
     * so e.g. [2,3] and [3,2] are accepted. Confirm whether a per-dimension size check is wanted.
     *
     * @return the type of input 1, or {@code null} if {@code allInputTypesPresent(3)} is false
     * @throws IllegalArgumentException if inputs 1 and 2 differ in rank or total size
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(3)) {
            return null;
        }
        OrderedTensorType a = this.inputs().get(1).type().get();
        OrderedTensorType b = this.inputs().get(2).type().get();
        // Both branches feed the same output cells, so they must agree in shape.
        if (a.type().rank() != b.type().rank()
                || !TensorConverter.tensorSize(a.type()).equals(TensorConverter.tensorSize(b.type()))) {
            throw new IllegalArgumentException("'Select': input tensors must have the same shape");
        }
        return a;
    }

    /**
     * Builds the tensor function selecting cells from input 1 or input 2 based on input 0.
     *
     * <p>If the condition is a known constant that is a scalar or a single-element vector,
     * the matching branch function is returned directly. A value truncating to 0 picks
     * input 2; anything else picks input 1. Otherwise the result is
     * {@code join(join(a, cond, a*b), join(b, cond, a*(1-b)), a+b)}. This is an exact
     * selection only when condition cells are 0 or 1; other values would blend the two
     * branches.
     *
     * @return the selection function, or {@code null} if {@code allInputFunctionsPresent(3)} is false
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputFunctionsPresent(3)) {
            return null;
        }
        TensorFlowOperation conditionOperation = this.inputs().get(0);
        TensorFunction a = this.inputs().get(1).function().get();
        TensorFunction b = this.inputs().get(2).function().get();

        // Shortcut: when the condition is known at import time, pick the branch now.
        if (conditionOperation.getConstantValue().isPresent()) {
            Tensor condition = conditionOperation.getConstantValue().get().asTensor();
            if (condition.type().rank() == 0) {
                return (int) condition.asDouble() == 0 ? b : a;
            }
            if (condition.type().rank() == 1
                    && TensorConverter.dimensionSize(condition.type().dimensions().get(0)) == 1L) {
                return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
            }
        }

        // General case: a * cond + b * (1 - cond), computed as two joins added together.
        TensorFunction conditionFunction = conditionOperation.function().get();
        TensorFunction aCond = new Join(a, conditionFunction, ScalarFunctions.multiply());
        TensorFunction bCond = new Join(b, conditionFunction, KEEP_WHERE_CONDITION_FALSE);
        return new Join(aCond, bCond, ScalarFunctions.add());
    }

    /**
     * Requires corresponding dimensions of the two branch inputs (1 and 2) to get the same name.
     * The result adds cells from both branches, so positional dimensions must line up.
     *
     * <p>Every dimension position the two inputs share is constrained. Earlier code hard-coded
     * positions 0 and 1. That threw {@code IndexOutOfBoundsException} for scalar or vector
     * inputs and left higher dimensions unconstrained.
     *
     * @param renamer the renamer receiving the constraints
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!this.allInputTypesPresent(3)) {
            return;
        }
        List<TensorType.Dimension> aDimensions = this.inputs().get(1).type().get().dimensions();
        List<TensorType.Dimension> bDimensions = this.inputs().get(2).type().get().dimensions();
        int sharedRank = Math.min(aDimensions.size(), bDimensions.size());
        for (int i = 0; i < sharedRank; i++) {
            renamer.addConstraint(aDimensions.get(i).name(), bDimensions.get(i).name(),
                                  DimensionRenamer::equals, this);
        }
    }
}

