/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import org.tensorflow.framework.NodeDef;

/**
 * Imports a binary TensorFlow node as a Vespa tensor {@code join}.
 *
 * <p>The two inputs of the node are combined with the scalar {@link DoubleBinaryOperator}
 * supplied at construction time. Operands of different rank are reconciled by aligning the
 * lower-rank operand against the trailing dimensions of the higher-rank one (see
 * {@link #addDimensionNameConstraints(DimensionRenamer)}).
 */
public class Join extends TensorFlowOperation {

    /** Scalar function used to combine each pair of cells from the two inputs. */
    private final DoubleBinaryOperator operator;

    /**
     * Creates a join operation.
     *
     * @param node     the TensorFlow node definition being imported
     * @param inputs   the operations producing this node's inputs; two are expected
     * @param port     the output port of the node this operation represents
     * @param operator the scalar combiner applied cell by cell
     */
    public Join(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
        super(node, inputs, port);
        this.operator = operator;
    }

    /**
     * Resolves the result type as the type of whichever input has the higher rank,
     * preferring the first input when both ranks are equal.
     *
     * @return the chosen input type, or {@code null} if the two input types are not yet known
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType first = inputTypeAt(0);
        OrderedTensorType second = inputTypeAt(1);
        // NOTE(review): with equal ranks this returns the first input's type unchanged, even if
        // its dimension sizes differ from the second's (e.g. size-1 broadcast dims) — confirm.
        if (first.type().rank() < second.type().rank()) {
            return second;
        }
        return first;
    }

    /**
     * Builds the tensor join over the functions of both inputs.
     *
     * @return the join function, or {@code null} if input types or either input function is missing
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        // Both input functions are requested before either is checked.
        Optional<TensorFunction> left = ((TensorFlowOperation) this.inputs.get(0)).function();
        Optional<TensorFunction> right = ((TensorFlowOperation) this.inputs.get(1)).function();
        if (left.isPresent() && right.isPresent()) {
            return new com.yahoo.tensor.functions.Join(left.get(), right.get(), this.operator);
        }
        return null;
    }

    /**
     * Registers naming constraints that pair each dimension of the lower-rank input with the
     * correspondingly positioned trailing dimension of the higher-rank input (the first input
     * is treated as the higher-rank one when ranks are equal). Each pair is constrained with
     * {@code DimensionRenamer::equals}, presumably forcing both to resolve to the same name.
     *
     * @param renamer the renamer collecting dimension name constraints
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!this.allInputTypesPresent(2)) {
            return;
        }
        OrderedTensorType first = inputTypeAt(0);
        OrderedTensorType second = inputTypeAt(1);
        boolean secondIsLarger = first.rank() < second.rank();
        OrderedTensorType larger = secondIsLarger ? second : first;
        OrderedTensorType smaller = secondIsLarger ? first : second;

        int offset = larger.rank() - smaller.rank();
        for (int largerIndex = offset; largerIndex < larger.rank(); largerIndex++) {
            String smallerName = smaller.dimensions().get(largerIndex - offset).name();
            String largerName = larger.dimensions().get(largerIndex).name();
            renamer.addConstraint(largerName, smallerName, DimensionRenamer::equals, this);
        }
    }

    /** Returns the (assumed present) type of the input at the given position. */
    private OrderedTensorType inputTypeAt(int index) {
        return ((TensorFlowOperation) this.inputs.get(index)).type().get();
    }
}

