/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Importer operation "Expand": broadcasts its first input to the shape described by its
 * second input, which must be a constant 1-d tensor.
 *
 * <p>Shape entries and input dimensions are aligned from the right. Where both are present
 * for an output dimension they must be equal, or one of them must be 1; the output dimension
 * takes the larger of the two. Where only one of them is present (the shape vector is longer
 * or shorter than the input rank), the output dimension takes that size as-is.
 *
 * <p>Output dimensions are named {@code vespaName() + "_" + i}, where {@code i} is the position
 * of the dimension in the output type.
 */
public class Expand
extends IntermediateOperation {
    /**
     * Creates the operation; all arguments are passed unchanged to the superclass constructor.
     *
     * @param modelName name of the model this operation belongs to
     * @param nodeName  name of this node
     * @param inputs    the inputs; this class reads index 0 as the tensor to expand and
     *                  index 1 as the target shape
     */
    public Expand(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
    }

    /**
     * Computes the type of the expanded tensor.
     *
     * @return the output type, or {@code null} if the two input types are not both present yet
     * @throws IllegalArgumentException if the shape input is not a constant, is not a 1-d tensor,
     *         or holds a size that is incompatible with the matching input dimension
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        // lazyGetFunction() references the input through its ranking expression function name;
        // presumably that is why the input is flagged for export as a ranking function here.
        ((IntermediateOperation)this.inputs.get(0)).exportAsRankingFunction = true;

        Optional<Value> shapeValue = ((IntermediateOperation)this.inputs.get(1)).getConstantValue();
        if (shapeValue.isEmpty()) {
            throw new IllegalArgumentException("Expand " + this.name + ": shape must be a constant.");
        }
        Tensor shape = shapeValue.get().asTensor();
        if (shape.type().rank() != 1) {
            throw new IllegalArgumentException("Expand " + this.name + ": shape must be a 1-d tensor.");
        }

        OrderedTensorType inputType = ((IntermediateOperation)this.inputs.get(0)).type().get();
        int inputRank = inputType.rank();
        int shapeSize = ((Long)((TensorType.Dimension)shape.type().dimensions().get(0)).size().get()).intValue();

        // Read the requested sizes up front so they can be addressed by index; the shape vector
        // may be shorter or longer than the input rank.
        int[] requestedSizes = new int[shapeSize];
        Iterator<?> shapeValues = shape.valueIterator();
        for (int i = 0; i < shapeSize; ++i) {
            requestedSizes[i] = ((Double)shapeValues.next()).intValue();
        }

        // Right-align both the shape entries and the input dimensions against the output dimensions.
        int outputRank = Math.max(shapeSize, inputRank);
        int shapeOffset = outputRank - shapeSize;
        int inputOffset = outputRank - inputRank;

        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(inputType.type().valueType());
        for (int i = 0; i < outputRank; ++i) {
            int shapeIndex = i - shapeOffset;
            int inputIndex = i - inputOffset;
            int dimSize;
            if (inputIndex < 0) {
                // Leading dimension that only exists in the shape: taken verbatim.
                dimSize = requestedSizes[shapeIndex];
            } else {
                int inputDimSize = ((Long)inputType.dimensions().get(inputIndex).size().get()).intValue();
                if (shapeIndex < 0) {
                    // Leading dimension that only exists in the input: kept unchanged.
                    dimSize = inputDimSize;
                } else {
                    int shapeDimSize = requestedSizes[shapeIndex];
                    if (shapeDimSize != inputDimSize && shapeDimSize != 1 && inputDimSize != 1) {
                        throw new IllegalArgumentException("Expand " + this.name + ": dimension sizes of input and shape are not compatible. Either they must be equal or one must be of size 1.");
                    }
                    dimSize = Math.max(shapeDimSize, inputDimSize);
                }
            }
            typeBuilder.add(TensorType.Dimension.indexed(this.vespaName() + "_" + i, (long)dimSize));
        }
        return typeBuilder.build();
    }

    /**
     * Builds the expression producing the expanded tensor: a generate over the output type whose
     * cell expression slices a single value out of the input's ranking expression function.
     *
     * @return the generating function, or {@code null} if the two input functions are not both
     *         present yet
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputFunctionsPresent(2)) {
            return null;
        }
        IntermediateOperation input = (IntermediateOperation)this.inputs.get(0);
        OrderedTensorType inputType = input.type().get();
        OrderedTensorType type = this.type().get();
        String inputFunctionName = input.rankingExpressionFunctionName();

        // Input dimensions line up with the trailing output dimensions.
        int outputRank = type.rank();
        int sizeDiff = outputRank - inputType.rank();

        ArrayList<Slice.DimensionValue> dimensionValues = new ArrayList<Slice.DimensionValue>();
        for (int i = sizeDiff; i < outputRank; ++i) {
            var inputDimension = inputType.dimensions().get(i - sizeDiff);
            String inputDimensionName = inputDimension.name();
            String typeDimensionName = type.dimensionNames().get(i);
            long inputDimensionSize = (Long)inputDimension.size().get();
            // An input dimension of size 1 is always read at index 0; any other input dimension
            // is read at the index of the matching output dimension.
            ExpressionNode index = inputDimensionSize == 1L
                    ? new ConstantNode((Value)new DoubleValue(0.0))
                    : new EmbracedNode((ExpressionNode)new ReferenceNode(typeDimensionName));
            dimensionValues.add(new Slice.DimensionValue(Optional.of(inputDimensionName), TensorFunctionNode.wrapScalar(index)));
        }

        TensorFunctionNode.ExpressionTensorFunction externalRef = new TensorFunctionNode.ExpressionTensorFunction((ExpressionNode)new ReferenceNode(inputFunctionName));
        Slice sliceIndices = new Slice((TensorFunction)externalRef, dimensionValues);
        TensorFunctionNode sliceExpression = new TensorFunctionNode((TensorFunction)sliceIndices);
        return Generate.bound((TensorType)type.type(), (ScalarFunction)TensorFunctionNode.wrapScalar((ExpressionNode)sliceExpression));
    }

    /**
     * Registers dimension name constraints by passing this operation's {@code type} field
     * to {@code addConstraintsFrom}.
     *
     * @param renamer the renamer to add constraints to
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        this.addConstraintsFrom(this.type, renamer);
    }

    /**
     * Returns a new {@code Expand} with the same model name and node name but the given inputs.
     *
     * @param inputs the inputs of the new operation
     * @return the new operation
     */
    @Override
    public Expand withInputs(List<IntermediateOperation> inputs) {
        return new Expand(this.modelName(), this.name(), inputs);
    }

    /** Returns the name of this operation, {@code "Expand"}. */
    @Override
    public String operationName() {
        return "Expand";
    }
}

