/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion;

import java.util.Locale;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;

/**
 * TensorFlow tensor element types, named like TensorFlow's {@code DT_*} proto values with the
 * {@code DT_} prefix removed (see {@link #fromProtoValue(String)}), plus helpers for converting
 * to and from ND4J {@link DataType}s and Python dtype names.
 *
 * <p>NOTE(review): the declaration order appears to mirror the numeric values of TensorFlow's
 * {@code DataType} proto (INVALID = 0 ... UINT64 = 23). Confirm before reordering, in case any
 * caller relies on {@link #ordinal()}.
 */
public enum TensorDataType {
    INVALID,
    FLOAT,
    DOUBLE,
    INT32,
    UINT8,
    INT16,
    INT8,
    STRING,
    COMPLEX64,
    INT64,
    BOOL,
    QINT8,
    QUINT8,
    QINT32,
    BFLOAT16,
    QINT16,
    QUINT16,
    UINT16,
    COMPLEX128,
    HALF,
    RESOURCE,
    VARIANT,
    UINT32,
    UINT64;

    /** Prefix carried by TensorFlow proto data-type names, e.g. {@code DT_FLOAT}. */
    private static final String PROTO_PREFIX = "DT_";

    /**
     * Parses a TensorFlow proto data-type name such as {@code "DT_FLOAT"}.
     *
     * <p>Only a leading {@code "DT_"} is stripped. A name without the prefix (e.g.
     * {@code "FLOAT"}) is looked up as-is.
     *
     * @param value the proto data-type name
     * @return the constant whose name equals {@code value} minus its {@code DT_} prefix
     * @throws IllegalArgumentException if no constant has that name
     * @throws NullPointerException if {@code value} is null
     */
    public static TensorDataType fromProtoValue(String value) {
        // Strip the prefix at the start only. String.replace would also delete any interior
        // occurrence of "DT_".
        String name = value.startsWith(PROTO_PREFIX)
                ? value.substring(PROTO_PREFIX.length())
                : value;
        return TensorDataType.valueOf(name);
    }

    /**
     * Returns the Python dtype name for a tensor type.
     *
     * <p>DOUBLE, FLOAT and HALF map to {@code float64}, {@code float32} and {@code float16}.
     * Every other type maps to its constant name in lower case (e.g. {@code INT32 -> "int32"}).
     *
     * @param tensorDataType the type to convert
     * @return the Python-side name
     */
    public static String toPythonName(TensorDataType tensorDataType) {
        switch (tensorDataType) {
            case DOUBLE:
                return "float64";
            case FLOAT:
                return "float32";
            case HALF:
                return "float16";
            default:
                // Locale.ROOT keeps the result locale-independent. Under e.g. a Turkish default
                // locale, "INT32".toLowerCase() would yield "ınt32" (dotless i).
                return tensorDataType.name().toLowerCase(Locale.ROOT);
        }
    }

    /**
     * Converts a tensor type to the corresponding ND4J data type.
     *
     * @param tensorDataType the type to convert
     * @return the ND4J data type
     * @throws IllegalArgumentException if the type has no ND4J mapping here
     */
    public static DataType toNd4jType(TensorDataType tensorDataType) {
        switch (tensorDataType) {
            case FLOAT:
                return DataType.FLOAT;
            case DOUBLE:
                return DataType.DOUBLE;
            case BOOL:
                return DataType.BOOL;
            case INT32:
                return DataType.INT;
            case INT64:
                return DataType.LONG;
            case STRING:
                return DataType.UTF8;
            case HALF:
                return DataType.HALF;
            case INT16:
                // Inverse of the SHORT -> INT16 mapping in fromNd4jType(DataType).
                return DataType.SHORT;
            default:
                throw new IllegalArgumentException("Unsupported type " + tensorDataType.name());
        }
    }

    /**
     * Converts an ND4J data type to the corresponding tensor type.
     *
     * @param dataType the ND4J data type
     * @return the tensor type
     * @throws IllegalStateException if {@code dataType} is {@code COMPRESSED}, because the
     *     underlying element type cannot be determined from the data type alone (use
     *     {@link #fromNd4jType(INDArray)} instead)
     * @throws IllegalArgumentException if the data type has no mapping here
     */
    public static TensorDataType fromNd4jType(DataType dataType) {
        switch (dataType) {
            case FLOAT:
                return FLOAT;
            case LONG:
                return INT64;
            case INT:
                return INT32;
            case BOOL:
                return BOOL;
            case DOUBLE:
                return DOUBLE;
            case HALF:
                return HALF;
            case UTF8:
                return STRING;
            case COMPRESSED:
                throw new IllegalStateException("Unable to work with compressed data type. Could be 1 or more types.");
            case SHORT:
                return INT16;
            default:
                throw new IllegalArgumentException("Unknown data type " + dataType);
        }
    }

    /**
     * Determines the tensor type of an array.
     *
     * <p>For uncompressed arrays this delegates to {@link #fromNd4jType(DataType)}. For
     * {@code COMPRESSED} arrays, the type is derived from the buffer's compression algorithm
     * name ({@code FLOAT16}, {@code INT8}, {@code UINT8}, {@code INT16} or {@code UINT16}).
     *
     * @param array the array to inspect
     * @return the tensor type
     * @throws IllegalArgumentException if the compression algorithm is missing or unsupported,
     *     or if the data type has no mapping
     */
    public static TensorDataType fromNd4jType(INDArray array) {
        DataType dataType = array.dataType();
        if (dataType != DataType.COMPRESSED) {
            return fromNd4jType(dataType);
        }

        // Assumes a COMPRESSED array is always backed by a CompressedDataBuffer.
        CompressedDataBuffer compressedData = (CompressedDataBuffer) array.data();
        CompressionDescriptor desc = compressedData.getCompressionDescriptor();
        String algo = desc.getCompressionAlgorithm();
        if (algo == null) {
            // A String switch on null would throw a bare NullPointerException.
            throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
        }
        switch (algo) {
            case "FLOAT16":
                return HALF;
            case "INT8":
                return INT8;
            case "UINT8":
                return UINT8;
            case "INT16":
                return INT16;
            case "UINT16":
                return UINT16;
            default:
                throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
        }
    }
}

