/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Static helpers that convert between Vespa tensors and ONNX Runtime tensors,
 * and between the type descriptions used on each side.
 *
 * <p>Only tensors are handled; ONNX maps and sequences are rejected with an
 * {@link IllegalArgumentException}.
 */
class TensorConverter {
    // Kept package-private (not private) so existing same-package code that
    // instantiates this class keeps compiling; all functionality is static.
    TensorConverter() {
    }

    /**
     * Converts every Vespa tensor in {@code tensorMap} to an ONNX tensor, keyed by
     * the name of the matching input of {@code session}.
     *
     * <p>If any conversion fails, the ONNX tensors created so far are closed before
     * the exception propagates, so no tensor created by this call is left open.
     *
     * @param tensorMap Vespa tensors keyed by input name (either the ONNX name or its
     *                  valid-identifier form, see {@link #toOnnxName})
     * @param env       the ONNX Runtime environment used to create tensors
     * @param session   the session whose input descriptions decide the element types
     * @return a new mutable map from ONNX input name to the created tensor
     * @throws OrtException             if ONNX Runtime fails to describe inputs or create a tensor
     * @throws IllegalArgumentException if a name matches no input, an input is not a tensor,
     *                                  or a tensor cannot be converted
     */
    static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> tensorMap, OrtEnvironment env, OrtSession session) throws OrtException {
        Map<String, OnnxTensor> result = new HashMap<>();
        if (tensorMap.isEmpty()) {
            // Nothing to convert: do not touch the session at all.
            return result;
        }
        // The input description does not change per tensor, so look it up once.
        Map<String, NodeInfo> inputInfo = session.getInputInfo();
        Set<String> onnxNames = inputInfo.keySet();
        boolean succeeded = false;
        try {
            for (Map.Entry<String, Tensor> entry : tensorMap.entrySet()) {
                String onnxName = toOnnxName(entry.getKey(), onnxNames);
                TensorInfo onnxTensorInfo = toTensorInfo(inputInfo.get(onnxName).getInfo());
                result.put(onnxName, toOnnxTensor(entry.getValue(), onnxTensorInfo, env));
            }
            succeeded = true;
            return result;
        } finally {
            if (!succeeded) {
                // Release the tensors already created; the caller never gets to see them.
                for (OnnxTensor created : result.values()) {
                    created.close();
                }
            }
        }
    }

    /**
     * Converts one Vespa tensor to an ONNX tensor with the element type given by
     * {@code onnxTensorInfo}. Cells are copied in direct-index order into a direct
     * buffer in native byte order, and the ONNX tensor gets the Vespa tensor's shape.
     *
     * @param vespaTensor    the tensor to convert; must be an {@link IndexedTensor}
     * @param onnxTensorInfo describes the element type the ONNX tensor must have
     * @param environment    the ONNX Runtime environment used to create the tensor
     * @return the created ONNX tensor
     * @throws OrtException             if ONNX Runtime fails to create the tensor
     * @throws IllegalArgumentException if the tensor is not indexed, its data does not fit
     *                                  in a single buffer, or the element type is unsupported
     */
    static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) throws OrtException {
        if (!(vespaTensor instanceof IndexedTensor tensor)) {
            throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
        }
        OnnxJavaType elementType = onnxTensorInfo.type;
        long cellCount = tensor.size();
        // Do the size arithmetic in long: a plain int cast/multiply would silently
        // wrap around for large tensors and allocate a wrong-sized buffer.
        long byteCount = cellCount > Integer.MAX_VALUE ? Long.MAX_VALUE : cellCount * elementType.size;
        if (byteCount > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor with " + cellCount + " cells of type " + elementType +
                    " does not fit in a single buffer of at most Integer.MAX_VALUE bytes");
        }
        ByteBuffer buffer = ByteBuffer.allocateDirect((int) byteCount).order(ByteOrder.nativeOrder());
        long[] shape = tensor.shape();
        switch (elementType) {
            case FLOAT -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putFloat(tensor.getFloat(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind().asFloatBuffer(), shape);
            }
            case DOUBLE -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putDouble(tensor.get(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind().asDoubleBuffer(), shape);
            }
            case INT8 -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.put((byte) tensor.get(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind(), shape);
            }
            case INT16 -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putShort((short) tensor.get(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind().asShortBuffer(), shape);
            }
            case INT32 -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putInt((int) tensor.get(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind().asIntBuffer(), shape);
            }
            case INT64 -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putLong((long) tensor.get(i));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), shape);
            }
            case FLOAT16 -> {
                // 16-bit floats are written as raw shorts; the element type is passed explicitly.
                for (long i = 0; i < cellCount; i++) {
                    buffer.putShort(Fp16Conversions.floatToFp16((float) tensor.get(i)));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind(), shape, OnnxJavaType.FLOAT16);
            }
            case BFLOAT16 -> {
                for (long i = 0; i < cellCount; i++) {
                    buffer.putShort(Fp16Conversions.floatToBf16((float) tensor.get(i)));
                }
                return OnnxTensor.createTensor(environment, buffer.rewind(), shape, OnnxJavaType.BFLOAT16);
            }
            default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + elementType);
        }
    }

    /** Copies {@code totalSize} floats from {@code buffer} into {@code builder} by direct index. */
    private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, buffer.get(i));
        }
    }

    /** Copies {@code totalSize} doubles from {@code buffer} into {@code builder} by direct index. */
    private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, buffer.get(i));
        }
    }

    /** Copies {@code totalSize} bytes from {@code buffer} into {@code builder}; every byte is exact as a float. */
    private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, (float) buffer.get(i));
        }
    }

    /** Copies {@code totalSize} shorts from {@code buffer} into {@code builder}; every short is exact as a float. */
    private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, (float) buffer.get(i));
        }
    }

    /**
     * Copies {@code totalSize} 16-bit encoded values from {@code buffer} into
     * {@code builder}, decoding each raw short with {@code converter}.
     */
    private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, converter.convert(buffer.get(i)));
        }
    }

    /**
     * Copies {@code totalSize} ints from {@code buffer} into {@code builder}.
     * Values are passed as {@code double}: a float cannot represent every int
     * (magnitudes above 2^24 would be rounded), a double can.
     */
    private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, (double) buffer.get(i));
        }
    }

    /**
     * Copies {@code totalSize} longs from {@code buffer} into {@code builder}.
     * Values are passed as {@code double}, which is exact up to 2^53, instead of
     * {@code float}, which already rounds above 2^24.
     */
    private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
        for (int i = 0; i < totalSize; i++) {
            builder.cellByDirectIndex(i, (double) buffer.get(i));
        }
    }

    /**
     * Converts an ONNX tensor to a Vespa tensor whose type is derived with
     * {@link #toVespaType(ValueInfo)}.
     *
     * @param onnxValue the ONNX value to convert; must be an {@link OnnxTensor}
     * @return the Vespa tensor holding the same cells in direct-index order
     * @throws IllegalArgumentException if the value is not a tensor, has more than
     *                                  {@code Integer.MAX_VALUE} cells, or has an unsupported element type
     */
    static Tensor toVespaTensor(OnnxValue onnxValue) {
        if (!(onnxValue instanceof OnnxTensor onnxTensor)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        TensorInfo tensorInfo = onnxTensor.getInfo();
        TensorType type = toVespaType(tensorInfo);
        DimensionSizes sizes = DimensionSizes.of(type);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
        long totalSizeAsLong = sizes.totalSize();
        if (totalSizeAsLong > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE");
        }
        int totalSize = (int) totalSizeAsLong;
        switch (tensorInfo.type) {
            case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize);
            case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize);
            case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize);
            case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize);
            case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize);
            case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize);
            case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize);
            case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize);
            default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + tensorInfo.type);
        }
        return builder.build();
    }

    /**
     * Maps each ONNX node description to a Vespa tensor type, keyed by the
     * valid-identifier form of the ONNX name (see {@link #asValidName}).
     *
     * @param infoMap ONNX node descriptions keyed by ONNX name
     * @return Vespa tensor types keyed by converted name
     * @throws IllegalStateException    if two ONNX names convert to the same key
     *                                  (the behavior of {@link Collectors#toMap} without a merge function)
     * @throws IllegalArgumentException if a node is not a tensor
     */
    static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
        return infoMap.entrySet().stream()
                .collect(Collectors.toMap(entry -> asValidName(entry.getKey()),
                                          entry -> toVespaType(entry.getValue().getInfo())));
    }

    /** Returns {@code name} converted by {@link OnnxImporter#asValidIdentifier}. */
    static String asValidName(String name) {
        return OnnxImporter.asValidIdentifier(name);
    }

    /**
     * Resolves {@code name} to one of {@code onnxNames}: an exact match wins,
     * otherwise the first ONNX name whose {@link #asValidName} form equals {@code name}.
     *
     * @throws IllegalArgumentException if no ONNX name matches
     */
    static String toOnnxName(String name, Set<String> onnxNames) {
        if (onnxNames.contains(name)) {
            return name;
        }
        for (String onnxName : onnxNames) {
            if (asValidName(onnxName).equals(name)) {
                return onnxName;
            }
        }
        throw new IllegalArgumentException("ONNX model has no input with name " + name);
    }

    /**
     * Builds the Vespa tensor type for an ONNX tensor description. Dimension
     * {@code i} is named {@code "d" + i}; a positive size gives a bound indexed
     * dimension, any other size gives an unbound indexed dimension.
     *
     * @throws IllegalArgumentException if {@code valueInfo} does not describe a tensor
     */
    static TensorType toVespaType(ValueInfo valueInfo) {
        TensorInfo tensorInfo = toTensorInfo(valueInfo);
        TensorType.Builder builder = new TensorType.Builder(toVespaValueType(tensorInfo.onnxType));
        long[] shape = tensorInfo.getShape();
        for (int i = 0; i < shape.length; i++) {
            String dimName = "d" + i;
            if (shape[i] > 0L) {
                builder.indexed(dimName, shape[i]);
            } else {
                builder.indexed(dimName);
            }
        }
        return builder.build();
    }

    /**
     * Picks the Vespa cell type for an ONNX element type. FLOAT16 is widened to
     * FLOAT, and every type without an explicit mapping falls back to DOUBLE.
     */
    private static TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
        return switch (onnxType) {
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
                 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE;
            default -> TensorType.Value.DOUBLE;
        };
    }

    /**
     * Narrows {@code valueInfo} to a {@link TensorInfo}.
     *
     * @throws IllegalArgumentException if it describes something other than a tensor
     */
    private static TensorInfo toTensorInfo(ValueInfo valueInfo) {
        if (!(valueInfo instanceof TensorInfo tensorInfo)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        return tensorInfo;
    }

    /** Decodes a raw 16-bit value (such as fp16 or bf16 bits) to a float. */
    @FunctionalInterface
    static interface Short2Float {
        public float convert(short var1);
    }
}

