/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxSparseTensor;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.ValueInfo;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;

/**
 * Describes a tensor: its dimensions, the Java type used to hold its elements and the
 * corresponding ONNX element type.
 *
 * <p>Instances are built either from native type information (package-private constructors)
 * or from Java values via the {@code constructFrom*} factory methods.
 */
public class TensorInfo implements ValueInfo {
    /** The maximum number of dimensions accepted by {@link #constructFromJavaArray(Object)}. */
    public static final int MAX_DIMENSIONS = 8;

    /** The tensor dimensions. {@link #getShape()} hands out a defensive copy. */
    final long[] shape;

    /** The Java type used to hold each element. */
    public final OnnxJavaType type;

    /** The ONNX element type. */
    public final OnnxTensorType onnxType;

    /** Product of all entries of {@link #shape}; 1 for a scalar (zero-length shape). */
    final long numElements;

    /**
     * Creates a tensor description from explicit Java and ONNX types.
     *
     * @param shape the dimensions; stored without copying
     * @param type the Java element type
     * @param onnxType the ONNX element type
     */
    TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
        this.shape = shape;
        this.type = type;
        this.onnxType = onnxType;
        this.numElements = TensorInfo.elementCount(shape);
    }

    /**
     * Creates a tensor description from a numeric ONNX element type code.
     *
     * @param shape the dimensions; stored without copying
     * @param typeInt the ONNX element type code; unmapped codes become
     *     {@link OnnxTensorType#ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED}
     */
    TensorInfo(long[] shape, int typeInt) {
        this.shape = shape;
        this.onnxType = OnnxTensorType.mapFromInt(typeInt);
        this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
        this.numElements = TensorInfo.elementCount(shape);
    }

    /**
     * Returns a copy of the tensor dimensions.
     *
     * @return a fresh array; mutating it does not affect this object
     */
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    @Override
    public String toString() {
        return "TensorInfo(javaType=" + this.type.toString() + ",onnxType=" + this.onnxType.toString() + ",shape=" + Arrays.toString(this.shape) + ")";
    }

    /**
     * Returns true if this tensor has no dimensions.
     *
     * @return true when the shape array is empty
     */
    public boolean isScalar() {
        return this.shape.length == 0;
    }

    /** Delegates shape validation to {@link OrtUtil#validateShape}. */
    private boolean validateShape() {
        return OrtUtil.validateShape(this.shape);
    }

    /**
     * Multiplies all dimensions together. No validation is performed, so a negative
     * (e.g. symbolic/unknown) dimension yields a product that is not a meaningful count.
     */
    private static long elementCount(long[] shape) {
        long output = 1L;
        for (long dim : shape) {
            output *= dim;
        }
        return output;
    }

    /**
     * Returns the number of elements, i.e. the product of all dimensions.
     *
     * @return the element count computed at construction time
     */
    public long getNumElements() {
        return this.numElements;
    }

    /**
     * Allocates an empty Java array suitable for receiving this tensor's values.
     *
     * <p>Numeric and boolean types are allocated through the matching {@code OrtUtil.new*Array}
     * factory with this tensor's shape. Strings are returned as a flat {@code String[]} with one
     * slot per element.
     *
     * @return the newly allocated carrier array
     * @throws OrtException if the shape fails validation (unless the tensor has zero elements),
     *     or if the element type is unknown or unsupported
     */
    public Object makeCarrier() throws OrtException {
        // Zero-element tensors are let through even when the shape does not validate.
        if (!this.validateShape() && this.numElements != 0L) {
            throw new OrtException("This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(this.shape));
        }
        switch (this.type) {
            case FLOAT:
                return OrtUtil.newFloatArray(this.shape);
            case DOUBLE:
                return OrtUtil.newDoubleArray(this.shape);
            case INT8:
            case UINT8:
                // Signed and unsigned bytes share the same Java carrier type.
                return OrtUtil.newByteArray(this.shape);
            case INT16:
                return OrtUtil.newShortArray(this.shape);
            case INT32:
                return OrtUtil.newIntArray(this.shape);
            case INT64:
                return OrtUtil.newLongArray(this.shape);
            case BOOL:
                return OrtUtil.newBooleanArray(this.shape);
            case STRING:
                return new String[(int) OrtUtil.elementCount(this.shape)];
            case UNKNOWN:
                throw new OrtException("Can't construct a carrier for an invalid type.");
            default:
                throw new OrtException("Unsupported type - " + this.type);
        }
    }

    /**
     * Infers a tensor description from a Java scalar or (possibly nested) array.
     *
     * @param obj a boxed scalar, a String, or a rectangular array of primitives or Strings
     * @return the inferred tensor description
     * @throws OrtException if the value's type cannot be mapped, the array has more than
     *     {@link #MAX_DIMENSIONS} dimensions, or it is ragged, has a zero-length dimension or
     *     contains a null sub-array
     */
    public static TensorInfo constructFromJavaArray(Object obj) throws OrtException {
        Class<?> objClass = obj.getClass();
        if (!objClass.isArray()) {
            OnnxJavaType javaType = OnnxJavaType.mapFromClass(objClass);
            if (javaType == OnnxJavaType.UNKNOWN) {
                throw new OrtException("Cannot convert " + objClass + " to a OnnxTensor.");
            }
            return new TensorInfo(new long[0], javaType, OnnxTensorType.mapFromJavaType(javaType));
        }
        // Peel off array levels to find the base element type and the rank.
        int dimensions = 0;
        while (objClass.isArray()) {
            objClass = objClass.getComponentType();
            ++dimensions;
        }
        if (!objClass.isPrimitive() && !objClass.equals(String.class)) {
            throw new OrtException("Cannot create an OnnxTensor from a base type of " + objClass);
        }
        if (dimensions > MAX_DIMENSIONS) {
            throw new OrtException("Cannot create an OnnxTensor with more than " + MAX_DIMENSIONS + " dimensions. Found " + dimensions + " dimensions.");
        }
        OnnxJavaType javaType = OnnxJavaType.mapFromClass(objClass);
        // Primitive base types without a mapping (e.g. char) are rejected here, matching the
        // scalar path above, instead of producing an UNKNOWN-typed tensor.
        if (javaType == OnnxJavaType.UNKNOWN) {
            throw new OrtException("Cannot create an OnnxTensor from a base type of " + objClass);
        }
        long[] shape = new long[dimensions];
        TensorInfo.extractShape(shape, 0, obj);
        return new TensorInfo(shape, javaType, OnnxTensorType.mapFromJavaType(javaType));
    }

    /**
     * Builds a tensor description for a dense buffer with an explicit shape.
     *
     * @param buffer the data buffer; its {@code remaining()} count must equal the element count
     * @param shape the dimensions; copied
     * @param type the Java element type
     * @return the tensor description
     * @throws OrtException if {@code type} is STRING or UNKNOWN, or the element count does not
     *     match {@code buffer.remaining()}
     */
    public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJavaType type) throws OrtException {
        if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
            throw new OrtException("Cannot create a tensor from a string or unknown buffer.");
        }
        long elementCount = OrtUtil.elementCount(shape);
        long bufferRemaining = buffer.remaining();
        if (elementCount != bufferRemaining) {
            throw new OrtException("Shape " + Arrays.toString(shape) + ", requires " + elementCount + " elements but the buffer has " + bufferRemaining + " elements.");
        }
        return new TensorInfo(Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type));
    }

    /**
     * Builds a tensor description for a sparse tensor from its dense shape and element type.
     *
     * @param tensor the sparse tensor
     * @param <T> the buffer type holding the values
     * @return a description using the dense shape (copied)
     * @throws OrtException if the values buffer holds more elements than the dense shape allows
     */
    public static <T extends Buffer> TensorInfo constructFromSparseTensor(OnnxSparseTensor.SparseTensor<T> tensor) throws OrtException {
        long[] shape = tensor.getDenseShape();
        long elementCount = OrtUtil.elementCount(shape);
        long bufferRemaining = tensor.getValues().remaining();
        if (elementCount < bufferRemaining) {
            throw new OrtException("Shape " + Arrays.toString(shape) + ", has at most " + elementCount + " elements but the buffer has " + bufferRemaining + " elements.");
        }
        return new TensorInfo(Arrays.copyOf(shape, shape.length), tensor.getType(), OnnxTensorType.mapFromJavaType(tensor.getType()));
    }

    /**
     * Recursively records the length of each array level into {@code shape}, checking that
     * the array is rectangular.
     *
     * @param shape output dimensions; entries still 0 are filled on first visit
     * @param curDim the dimension {@code obj} represents
     * @param obj the (sub-)array at {@code curDim}
     * @throws OrtException on a null sub-array, a zero-length dimension or a ragged array
     */
    private static void extractShape(long[] shape, int curDim, Object obj) throws OrtException {
        if (curDim == shape.length) {
            return;
        }
        if (obj == null) {
            throw new OrtException("Supplied array has a null sub-array at dimension " + curDim);
        }
        int curLength = Array.getLength(obj);
        if (curLength == 0) {
            throw new OrtException("Supplied array has a zero dimension at " + curDim + ", all dimensions must be positive");
        }
        if (shape[curDim] == 0L) {
            shape[curDim] = curLength;
        } else if (shape[curDim] != (long) curLength) {
            throw new OrtException("Supplied array is ragged, expected " + shape[curDim] + ", found " + curLength);
        }
        // The innermost level holds leaf elements with nothing further to measure, so do not
        // box every one of them through Array.get just to return immediately.
        if (curDim + 1 < shape.length) {
            for (int i = 0; i < curLength; ++i) {
                TensorInfo.extractShape(shape, curDim + 1, Array.get(obj, i));
            }
        }
    }

    /** ONNX tensor element types, each carrying its numeric type code in {@link #value}. */
    public static enum OnnxTensorType {
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED(0),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8(1),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8(2),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16(3),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16(4),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32(5),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32(6),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64(7),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64(8),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16(9),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT(10),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE(11),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING(12),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL(13),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64(14),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128(15),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16(16);

        /** The numeric type code. */
        public final int value;

        /** Lookup table indexed by {@link #value}; sized from the largest code. */
        private static final OnnxTensorType[] BY_VALUE;

        private OnnxTensorType(int value) {
            this.value = value;
        }

        /**
         * Maps a numeric type code to its enum constant.
         *
         * @param value the type code
         * @return the matching constant, or UNDEFINED for unknown codes
         */
        public static OnnxTensorType mapFromInt(int value) {
            if (value > 0 && value < BY_VALUE.length && BY_VALUE[value] != null) {
                return BY_VALUE[value];
            }
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
        }

        /**
         * Maps a Java element type to its ONNX element type.
         *
         * @param type the Java element type
         * @return the matching constant, or UNDEFINED for types without a mapping
         */
        public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
            switch (type) {
                case FLOAT:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
                case DOUBLE:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
                case INT8:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
                case UINT8:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
                case INT16:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
                case INT32:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
                case INT64:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
                case BOOL:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
                case STRING:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
                default:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
            }
        }

        static {
            OnnxTensorType[] all = OnnxTensorType.values();
            int maxValue = 0;
            for (OnnxTensorType t : all) {
                maxValue = Math.max(maxValue, t.value);
            }
            BY_VALUE = new OnnxTensorType[maxValue + 1];
            for (OnnxTensorType t : all) {
                BY_VALUE[t.value] = t;
            }
        }
    }
}

