/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Intermediate representation of the {@code ConstantOfShape} operation.
 *
 * <p>The output tensor's shape is read from the single input, which must be a constant tensor with
 * at most one dimension; each of its cell values becomes the size of one indexed output dimension,
 * named {@code <vespaName>_<i>}. Every output cell holds the value taken from the optional
 * {@code value} attribute, or 0.0 (with value type DOUBLE) when the attribute is absent.
 */
public class ConstantOfShape
extends IntermediateOperation {
    private final IntermediateOperation.AttributeMap attributeMap;
    /** Cell value type of the output tensor; taken from the 'value' attribute when present. */
    private final TensorType.Value valueTypeOfTensor;
    /** The value every output cell is filled with. */
    private final double valueToFillWith;

    /**
     * Creates the operation.
     *
     * @param modelName    name of the model this operation belongs to
     * @param nodeName     name of this node
     * @param inputs       operation inputs; the first is expected to be the constant 'shape'
     * @param attributeMap node attributes; an optional "value" entry supplies the fill value
     * @throws IllegalArgumentException if the "value" attribute is present but holds no cells
     */
    public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        Optional<Value> value = attributeMap.get("value");
        if (value.isPresent()) {
            Tensor t = value.get().asTensor();
            Iterator<Double> cells = t.valueIterator();
            // Only the first cell is used; an empty tensor would otherwise fail with a bare
            // NoSuchElementException that gives no hint about which attribute is wrong.
            if (!cells.hasNext()) {
                throw new IllegalArgumentException("ConstantOfShape: 'value' attribute must be a tensor with one element.");
            }
            this.valueTypeOfTensor = t.type().valueType();
            this.valueToFillWith = cells.next();
        } else {
            this.valueTypeOfTensor = TensorType.Value.DOUBLE;
            this.valueToFillWith = 0.0;
        }
    }

    /**
     * Derives the output type from the constant 'shape' input.
     *
     * @return the output type, or null if the input type is not yet available
     * @throws IllegalArgumentException if 'shape' is not constant, has more than one dimension,
     *                                  or contains a negative size
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(1)) {
            return null;
        }
        Tensor shape = this.inputs.get(0).getConstantValue()
                .orElseThrow(() -> new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant."))
                .asTensor();
        if (shape.type().dimensions().size() > 1) {
            throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions.");
        }
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(this.valueTypeOfTensor);
        Iterator<Double> sizes = shape.valueIterator();
        for (int i = 0; sizes.hasNext(); i++) {
            long size = sizes.next().longValue();
            if (size < 0) {
                throw new IllegalArgumentException("ConstantOfShape: 'shape' input must not contain negative sizes, got "
                                                   + size + " for dimension " + i + ".");
            }
            builder.add(TensorType.Dimension.indexed(this.vespaName() + "_" + i, size));
        }
        return builder.build();
    }

    /**
     * Builds a generate function over the output type in which every cell is the fill value.
     *
     * @return the tensor function, or null if the input type is not yet available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputTypesPresent(1)) {
            return null;
        }
        ConstantNode fillValue = new ConstantNode(new DoubleValue(this.valueToFillWith));
        return Generate.bound(this.type.type(), TensorFunctionNode.wrapScalar(fillValue));
    }

    /** Adds dimension-name constraints derived from this operation's output type. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        this.addConstraintsFrom(this.type, renamer);
    }

    /** Returns a copy of this operation with the given inputs and the same attributes. */
    @Override
    public ConstantOfShape withInputs(List<IntermediateOperation> inputs) {
        return new ConstantOfShape(this.modelName(), this.name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "ConstantOfShape";
    }
}

