/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;

public class ConcatReduce
extends IntermediateOperation {
    private static final String tmpDimensionName = "__concat_reduce_tmp_dimension_name__";
    private final Reduce.Aggregator aggregator;

    public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) {
        super(modelName, nodeName, inputs);
        this.aggregator = aggregator;
    }

    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(this.inputs.size())) {
            return null;
        }
        return ((IntermediateOperation)this.inputs.get(0)).type().get();
    }

    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputFunctionsPresent(this.inputs.size())) {
            return null;
        }
        TensorFunction result = ((IntermediateOperation)this.inputs.get(0)).function().get();
        for (int i = 1; i < this.inputs.size(); ++i) {
            TensorFunction b = ((IntermediateOperation)this.inputs.get(i)).function().get();
            result = new Concat(result, b, tmpDimensionName);
        }
        return new Reduce(result, this.aggregator, tmpDimensionName);
    }

    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!this.allInputTypesPresent(this.inputs.size())) {
            return;
        }
        OrderedTensorType a = ((IntermediateOperation)this.inputs.get(0)).type().get();
        for (int i = 1; i < this.inputs.size(); ++i) {
            OrderedTensorType b = ((IntermediateOperation)this.inputs.get(i)).type().get();
            OrderedTensorType largest = this.largestInput(a, b);
            OrderedTensorType smallest = this.smallestInput(a, b);
            int sizeDifference = largest.rank() - smallest.rank();
            for (int j = 0; j < smallest.rank(); ++j) {
                String bDim = smallest.dimensions().get(j).name();
                String aDim = largest.dimensions().get(j + sizeDifference).name();
                renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
            }
            a = b;
        }
    }

    private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) {
        return a.rank() >= b.rank() ? a : b;
    }

    private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) {
        return a.rank() < b.rank() ? a : b;
    }

    @Override
    public ConcatReduce withInputs(List<IntermediateOperation> inputs) {
        return new ConcatReduce(this.modelName(), this.name(), inputs, this.aggregator);
    }

    @Override
    public String operationName() {
        return "ConcatReduce";
    }
}

