/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import onnx.Onnx;

/**
 * Converts ONNX type descriptions ({@code onnx.Onnx} protos) into Vespa {@link OrderedTensorType}s
 * and verifies ONNX shapes against existing Vespa types.
 *
 * <p>Dimensions created from ONNX shapes are named {@code d0, d1, ...} in ONNX order. An ONNX
 * dimension value of 0 is read as size 1; a negative size yields an unbound indexed dimension.
 */
class TypeConverter {

    /** Prefix of the generated Vespa dimension names; the ONNX dimension index is appended. */
    private static final String DIMENSION_PREFIX = "d";

    TypeConverter() {
    }

    /**
     * Verifies that an ONNX type has the same rank and dimension sizes as a Vespa type.
     *
     * @param typeProto the ONNX type whose tensor shape is checked
     * @param type      the Vespa type to compare against; ONNX dimension {@code i} is compared with
     *                  Vespa dimension {@code type.dimensionMap(i)}
     * @throws IllegalArgumentException if the ranks differ, or if an ONNX dimension whose size is not
     *                                  -1 differs from the size of the corresponding Vespa dimension
     *                                  (an unbound Vespa dimension is compared as size -1)
     */
    static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
        // NOTE(review): generated protobuf getters normally return a default instance rather than
        // null, so an absent shape probably arrives here as rank 0 instead of taking this branch.
        if (shape == null) {
            return;
        }
        if (shape.getDimCount() != type.rank()) {
            throw new IllegalArgumentException("Onnx shape of rank " + shape.getDimCount() +
                                               " does not match Vespa type " + type.type() +
                                               " of rank " + type.rank());
        }
        for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
            int vespaIndex = type.dimensionMap(onnxIndex);
            long onnxDimensionSize = dimensionSize(shape.getDim(onnxIndex));
            if (onnxDimensionSize == -1L) {
                continue; // an ONNX size of -1 is accepted against any Vespa dimension size
            }
            TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
            long vespaDimensionSize = vespaDimension.size().orElse(-1L);
            if (onnxDimensionSize != vespaDimensionSize) {
                throw new IllegalArgumentException("Onnx dimension " + onnxIndex + " of size " + onnxDimensionSize +
                                                   " does not match Vespa dimension '" + vespaDimension.name() +
                                                   "' of size " + vespaDimensionSize + " in " + type.type());
            }
        }
    }

    /**
     * Creates a Vespa type from an ONNX type proto.
     *
     * @param type the ONNX type; its element type determines the Vespa value type
     * @return a type with one indexed dimension {@code d<i>} per ONNX dimension, bound to the ONNX size
     *         when that size is non-negative and unbound otherwise
     * @throws IllegalArgumentException if the ONNX element type has no Vespa equivalent
     */
    static OrderedTensorType typeFrom(Onnx.TypeProto type) {
        Onnx.TensorShapeProto shape = type.getTensorType().getShape();
        OrderedTensorType.Builder builder =
                new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
        for (int i = 0; i < shape.getDimCount(); ++i) {
            String dimensionName = DIMENSION_PREFIX + i;
            long onnxDimensionSize = dimensionSize(shape.getDim(i));
            if (onnxDimensionSize >= 0L) {
                builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimensionSize));
            } else {
                builder.add(TensorType.Dimension.indexed(dimensionName));
            }
        }
        return builder.build();
    }

    /**
     * Creates a Vespa type from an ONNX tensor proto, using its data type and its list of dimension sizes.
     *
     * @throws IllegalArgumentException if the ONNX data type has no Vespa equivalent
     */
    static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
        return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()), tensor.getDimsList());
    }

    /**
     * Returns the size of an ONNX dimension, reading a dim value of 0 as 1.
     * NOTE(review): 0 is presumably also what an unset or symbolic (dim_param) dimension reads as — confirm.
     */
    private static long dimensionSize(Onnx.TensorShapeProto.Dimension onnxDimension) {
        long value = onnxDimension.getDimValue();
        return value == 0L ? 1L : value;
    }

    /**
     * Maps an ONNX element data type to a Vespa tensor cell value type.
     *
     * <p>Integer types of at most 16 bits (and BOOL) map to FLOAT, which represents all their values
     * exactly. 32- and 64-bit integer types map to DOUBLE. Note that 64-bit integer magnitudes above
     * 2^53 are not exactly representable in a double.
     *
     * @throws IllegalArgumentException for any other data type (e.g. strings or 16-bit floats)
     */
    private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
        switch (dataType) {
            case FLOAT:
            case BOOL:
            case INT8:
            case INT16:
            case UINT8:
            case UINT16:
                return TensorType.Value.FLOAT;
            case DOUBLE:
            case INT32:
            case INT64:
            case UINT32:
            case UINT64:
                return TensorType.Value.DOUBLE;
            default:
                throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
                                                   " cannot be converted to a Vespa tensor type");
        }
    }
}

