/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;

/**
 * Imported TensorFlow constant operation: its output is the fixed value stored in the node's
 * {@code "value"} attribute.
 *
 * <p>The value is read eagerly in the constructor and re-read after dimension renaming, and is
 * stored on the base class via {@code setConstantValue}.
 */
public class Const extends TensorFlowOperation {

    /**
     * Creates the operation and immediately decodes its constant value.
     *
     * @param modelName name of the model this operation belongs to (also used in {@link #vespaName()})
     * @param node the TensorFlow node definition; must carry a {@code "value"} attribute
     * @param inputs input operations, passed through to the base class
     * @param port output port, passed through to the base class
     * @throws IllegalArgumentException if the node has no {@code "value"} attribute, or the
     *         attribute holds a kind of value that is not supported
     */
    public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
        super(modelName, node, inputs, port);
        // NOTE(review): value() resolves type() here, which calls the overridable
        // lazyGetType()/vespaName() while the object is still being constructed.
        this.setConstantValue(this.value());
    }

    /**
     * Derives the tensor type from the node definition.
     * The second argument is presumably a prefix for generated dimension names — TODO confirm
     * against {@code OrderedTensorType.fromTensorFlowType}.
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        return OrderedTensorType.fromTensorFlowType(this.node, this.vespaName() + "_");
    }

    /**
     * Returns the function for this constant, computing it via {@link #lazyGetFunction()} on
     * first use and caching it in the {@code function} field.
     * NOTE(review): presumably overridden to bypass the base class's handling of constant
     * operations — confirm against {@code TensorFlowOperation.function()}.
     */
    @Override
    public Optional<TensorFunction> function() {
        if (this.function == null) {
            this.function = this.lazyGetFunction();
        }
        return Optional.ofNullable(this.function);
    }

    /**
     * Builds the expression for this constant. A rank-0 (scalar) constant whose value is known
     * is inlined as a literal {@link ConstantNode}. Anything else becomes a reference to
     * {@code constant(<vespaName>)}.
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        ExpressionNode expressionNode;
        // Resolve the type through the lazy accessor (as value() does): the raw field may not
        // have been populated yet if type() has never been called on this operation.
        if (this.type().get().type().rank() == 0 && this.getConstantValue().isPresent()) {
            expressionNode = new ConstantNode(this.getConstantValue().get().asDoubleValue());
        } else {
            expressionNode = new ReferenceNode(Reference.simple("constant", this.vespaName()));
        }
        return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
    }

    /**
     * Returns the base name prefixed with the model name and an underscore, presumably so
     * constants from different imported models do not collide.
     */
    @Override
    public String vespaName() {
        return this.modelName() + "_" + super.vespaName();
    }

    /**
     * Registers every dimension of this constant's type with the renamer. No constraints
     * between dimensions are added.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        for (TensorType.Dimension dimension : this.type().get().type().dimensions()) {
            renamer.addDimension(dimension.name());
        }
    }

    /**
     * Renames dimensions via the base class, then re-decodes the constant value. Because
     * {@link #value()} converts tensors using the current type, the stored value picks up the
     * renamed dimensions (assuming {@code super.renameDimensions} updates the type).
     */
    @Override
    public void renameDimensions(DimensionRenamer renamer) {
        super.renameDimensions(renamer);
        this.setConstantValue(this.value());
    }

    /** Always {@code true}: this operation's output is fixed. */
    @Override
    public boolean isConstant() {
        return true;
    }

    /**
     * Decodes the node's {@code "value"} attribute into a Vespa {@link Value}:
     * tensors become a {@link TensorValue} typed by this operation's type, booleans a
     * {@link BooleanValue}, and integer/float scalars a {@link DoubleValue}.
     *
     * @throws IllegalArgumentException if the attribute is missing or of an unsupported kind
     */
    private Value value() {
        if (!this.node.getAttrMap().containsKey("value")) {
            throw new IllegalArgumentException("Node '" + this.node.getName() + "' of type const has missing 'value' attribute");
        }
        AttrValue attrValue = this.node.getAttrMap().get("value");
        switch (attrValue.getValueCase()) {
            case TENSOR:
                return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), this.type().get().type()));
            case B:
                return new BooleanValue(attrValue.getB());
            case I:
                return new DoubleValue(attrValue.getI());
            case F:
                return new DoubleValue(attrValue.getF());
            default:
                throw new IllegalArgumentException("Requesting value of constant in " + this.node.getName() + " but type is not recognized.");
        }
    }
}

