/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.tensorflow;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import java.util.List;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;

/**
 * Helpers for translating the output shape recorded on a TensorFlow {@link NodeDef}
 * into an {@link OrderedTensorType}, and for checking that an already-built
 * {@link OrderedTensorType} agrees with that shape.
 *
 * <p>The shape is read from the first entry of the node's {@code _output_shapes}
 * attribute; any further entries in that list are ignored.
 */
class TypeConverter {
    TypeConverter() {
    }

    /**
     * Checks that the given type has the same rank and per-dimension sizes as the
     * shape recorded on the TensorFlow node.
     *
     * <p>A Vespa dimension without a size is compared as {@code -1}, so it only
     * matches a TensorFlow dimension whose size is {@code -1}.
     *
     * @param node the TensorFlow node whose {@code _output_shapes} attribute is read
     * @param type the type to verify against the node's shape
     * @throws IllegalArgumentException if the node's shape cannot be read, if the
     *         ranks differ, or if any dimension size differs
     */
    static void verifyType(NodeDef node, OrderedTensorType type) {
        TensorShapeProto shape = tensorFlowShape(node);
        if (shape.getDimCount() != type.rank()) {
            throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' does not match Vespa shape");
        }
        for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
            // dimensionMap translates the TensorFlow dimension index into the index
            // used by type.type().dimensions().
            int vespaIndex = type.dimensionMap(tensorFlowIndex);
            TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
            TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
            long vespaSize = vespaDimension.size().orElse(-1L);
            if (tensorFlowDimension.getSize() != vespaSize) {
                throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' does not match Vespa dimensions");
            }
        }
    }

    /**
     * Returns the first shape listed in the node's {@code _output_shapes} attribute.
     *
     * @param node the TensorFlow node to read the attribute from
     * @return the first shape in the attribute's shape list, never {@code null}
     * @throws IllegalArgumentException if the attribute is missing, is not a list
     *         value, or holds an empty shape list
     */
    private static TensorShapeProto tensorFlowShape(NodeDef node) {
        AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
        if (attrValueList == null) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' does not exist");
        }
        if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is not of expected type");
        }
        List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
        if (shapeList.isEmpty()) {
            // Report a malformed attribute the same way as the checks above instead
            // of letting get(0) fail with an IndexOutOfBoundsException.
            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' does not contain any shapes");
        }
        return shapeList.get(0);
    }

    /**
     * Builds an ordered tensor type from the node's shape, naming the dimensions
     * {@code d0}, {@code d1}, ... in shape order.
     *
     * @param node the TensorFlow node whose {@code _output_shapes} attribute is read
     * @return the type built from the node's first output shape
     * @throws IllegalArgumentException if the node's shape cannot be read
     */
    static OrderedTensorType fromTensorFlowType(NodeDef node) {
        return fromTensorFlowType(node, "d");
    }

    /**
     * Builds an ordered tensor type from the node's shape, naming each dimension
     * {@code dimensionPrefix} followed by its index in the shape.
     *
     * <p>A dimension with a non-negative size is added as an indexed dimension of
     * that size; a dimension with a negative size is added as an indexed dimension
     * without a size.
     *
     * @param node the TensorFlow node whose {@code _output_shapes} attribute is read
     * @param dimensionPrefix the prefix of every generated dimension name
     * @return the type built from the node's first output shape
     * @throws IllegalArgumentException if the node's shape cannot be read
     */
    private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
        TensorShapeProto shape = tensorFlowShape(node);
        for (int i = 0; i < shape.getDimCount(); ++i) {
            String dimensionName = dimensionPrefix + i;
            long size = shape.getDim(i).getSize();
            if (size >= 0L) {
                builder.add(TensorType.Dimension.indexed(dimensionName, size));
            } else {
                builder.add(TensorType.Dimension.indexed(dimensionName));
            }
        }
        return builder.build();
    }
}

