/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.tensorflow.conversion;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.tensorflow.Deallocator_Pointer_long_Pointer;
import org.bytedeco.tensorflow.TF_Buffer;
import org.bytedeco.tensorflow.TF_Graph;
import org.bytedeco.tensorflow.TF_ImportGraphDefOptions;
import org.bytedeco.tensorflow.TF_Session;
import org.bytedeco.tensorflow.TF_SessionOptions;
import org.bytedeco.tensorflow.TF_Status;
import org.bytedeco.tensorflow.TF_Tensor;
import org.bytedeco.tensorflow.global.tensorflow;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.tensorflow.conversion.DummyDeAllocator;
import org.nd4j.tensorflow.conversion.graphrunner.SavedModelConfig;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/**
 * Converts between ND4J {@link INDArray} instances and TensorFlow C-API
 * objects ({@link TF_Tensor}, {@link TF_Graph}, {@link TF_Session}) exposed
 * through the JavaCPP TensorFlow bindings.
 *
 * <p>Obtain the shared instance through {@link #getInstance()}. The class has
 * no instance fields; its only state is the static deallocator handed to
 * TensorFlow when an ND4J buffer is wrapped as a tensor.
 */
public class TensorflowConversion {
    /*
     * Tensor element type codes exchanged with the TensorFlow C API
     * (values of the TF_DataType enum).
     */
    private static final int TF_TYPE_FLOAT = 1;
    private static final int TF_TYPE_DOUBLE = 2;
    private static final int TF_TYPE_INT32 = 3;
    private static final int TF_TYPE_UINT8 = 4;
    private static final int TF_TYPE_INT16 = 5;
    private static final int TF_TYPE_INT8 = 6;
    private static final int TF_TYPE_STRING = 7;
    private static final int TF_TYPE_INT64 = 9;
    private static final int TF_TYPE_UINT16 = 17;
    private static final int TF_TYPE_HALF = 19;

    /** Value returned by {@code TF_GetCode} when the last call succeeded. */
    private static final int TF_STATUS_OK = 0;

    /** Width in bytes of one entry of the offset table that starts a string tensor. */
    private static final long STRING_OFFSET_BYTES = 8L;

    /** Deallocator callback passed to {@code TF_NewTensor} for tensors that wrap ND4J memory. */
    private static Deallocator_Pointer_long_Pointer calling;
    private static TensorflowConversion INSTANCE;

    /**
     * Returns the shared converter, creating it on first use.
     *
     * <p>The lazy initialisation is not synchronized. Concurrent first calls
     * may each construct an instance; the class holds no per-instance state.
     *
     * @return the shared {@code TensorflowConversion}
     */
    public static TensorflowConversion getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new TensorflowConversion();
        }
        return INSTANCE;
    }

    private TensorflowConversion() {
        if (calling == null) {
            // NOTE(review): presumably a no-op deallocator so that TensorFlow never
            // frees memory owned by ND4J - confirm against DummyDeAllocator.
            calling = DummyDeAllocator.getInstance();
        }
    }

    /**
     * Creates a TensorFlow tensor from an ND4J array.
     *
     * <p>Views and arrays that are not in 'c' order are first duplicated into a
     * 'c'-ordered copy. String (UTF8) arrays are encoded into a freshly
     * allocated tensor. For every other type the tensor is created directly
     * over the native pointer of the array's data buffer, without copying.
     * NOTE(review): the caller presumably has to keep that buffer alive for as
     * long as the tensor is used - confirm.
     *
     * @param ndArray the array to convert; must be non-null and have a data buffer
     * @return the newly created tensor
     * @throws IllegalArgumentException if the array or its buffer is null, or
     *         its data type / compression algorithm has no mapping
     * @throws IllegalStateException if TensorFlow fails to encode a string element
     */
    public TF_Tensor tensorFromNDArray(INDArray ndArray) {
        if (ndArray == null) {
            throw new IllegalArgumentException("NDArray must not be null!");
        }
        if (ndArray.data() == null) {
            throw new IllegalArgumentException("Unable to infer data type from null databuffer");
        }
        if (ndArray.isView() || ndArray.ordering() != 'c') {
            ndArray = ndArray.dup('c');
        }
        long[] tfShape = ndArray.shape().clone();
        DataBuffer data = ndArray.data();
        int type = tfTypeFor(data);
        try {
            Nd4j.getAffinityManager().ensureLocation(ndArray, AffinityManager.Location.HOST);
        }
        catch (Exception e) {
            // Deliberate best-effort fallback: the failure is not propagated.
            // NOTE(review): reading one element presumably forces the data to be
            // materialised on the host (and decompressed) - confirm. The buffer
            // and its type are re-read afterwards because they may have changed.
            ndArray.getDouble(0L);
            data = ndArray.data();
            type = tfTypeAfterHostFallback(data.dataType());
        }
        LongPointer dims = new LongPointer(tfShape);
        if (type == TF_TYPE_STRING) {
            return stringTensorFromNDArray(ndArray, dims, tfShape.length);
        }
        long byteSize = data.length() * (long)data.getElementSize();
        return tensorflow.TF_NewTensor(type, dims, tfShape.length, data.pointer(), byteSize, calling, null);
    }

    /** Maps the data type of an ND4J buffer to the TensorFlow type code used for a new tensor. */
    private static int tfTypeFor(DataBuffer data) {
        DataType dataType = data.dataType();
        switch (dataType) {
            case DOUBLE: {
                return TF_TYPE_DOUBLE;
            }
            case FLOAT: {
                return TF_TYPE_FLOAT;
            }
            case INT: {
                return TF_TYPE_INT32;
            }
            case HALF: {
                return TF_TYPE_HALF;
            }
            case COMPRESSED: {
                return tfTypeForCompressed((CompressedDataBuffer)data);
            }
            case LONG: {
                return TF_TYPE_INT64;
            }
            case UTF8: {
                return TF_TYPE_STRING;
            }
            default: {
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
            }
        }
    }

    /** Maps the compression algorithm name of a compressed buffer to a TensorFlow type code. */
    private static int tfTypeForCompressed(CompressedDataBuffer data) {
        CompressionDescriptor descriptor = data.getCompressionDescriptor();
        String algo = descriptor.getCompressionAlgorithm();
        switch (algo) {
            case "FLOAT16": {
                return TF_TYPE_HALF;
            }
            case "INT8": {
                return TF_TYPE_INT8;
            }
            case "UINT8": {
                return TF_TYPE_UINT8;
            }
            case "INT16": {
                return TF_TYPE_INT16;
            }
            case "UINT16": {
                return TF_TYPE_UINT16;
            }
            default: {
                throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
            }
        }
    }

    /**
     * Maps the data type observed after the host-location fallback to a TensorFlow type code.
     * Only the uncompressed types listed here are accepted on that path.
     */
    private static int tfTypeAfterHostFallback(DataType dataType) {
        switch (dataType) {
            case DOUBLE: {
                return TF_TYPE_DOUBLE;
            }
            case FLOAT: {
                return TF_TYPE_FLOAT;
            }
            case INT: {
                return TF_TYPE_INT32;
            }
            case LONG: {
                return TF_TYPE_INT64;
            }
            case UTF8: {
                return TF_TYPE_STRING;
            }
            default: {
                throw new IllegalArgumentException("Unsupported data type: " + dataType);
            }
        }
    }

    /**
     * Allocates a string tensor and fills it with the elements of {@code ndArray}:
     * a table of 8-byte offsets (one per element) followed by the encoded strings.
     */
    private TF_Tensor stringTensorFromNDArray(INDArray ndArray, LongPointer dims, int rank) {
        long length = ndArray.length();
        BytePointer[] strings = new BytePointer[(int)length];
        long encodedSize = 0L;
        for (int i = 0; i < length; ++i) {
            strings[i] = new BytePointer(ndArray.getString((long)i));
            encodedSize += tensorflow.TF_StringEncodedSize(strings[i].capacity());
        }
        long offsetTableBytes = STRING_OFFSET_BYTES * length;
        TF_Tensor tfTensor = tensorflow.TF_AllocateTensor(TF_TYPE_STRING, dims, rank, offsetTableBytes + encodedSize);
        TF_Status status = tensorflow.TF_NewStatus();
        try {
            BytePointer tfData = new BytePointer(tensorflow.TF_TensorData(tfTensor)).capacity(tensorflow.TF_TensorByteSize(tfTensor));
            long offset = 0L;
            for (int i = 0; i < length; ++i) {
                // Offsets are computed in long arithmetic so large element counts cannot overflow.
                tfData.position(STRING_OFFSET_BYTES * i).putLong(offset);
                // The destination pointer is positioned first; the remaining space is then
                // derived from that new position (arguments are evaluated left to right).
                offset += tensorflow.TF_StringEncode(strings[i], strings[i].capacity() - 1L, tfData.position(offsetTableBytes + offset), tfData.capacity() - tfData.position(), status);
                if (tensorflow.TF_GetCode(status) != TF_STATUS_OK) {
                    String message = tensorflow.TF_Message(status).getString();
                    // The half-written tensor is never returned, so release it here.
                    tensorflow.TF_DeleteTensor(tfTensor);
                    throw new IllegalStateException("ERROR: Unable to convert tensor " + message);
                }
            }
        }
        finally {
            // Release the status on the failure path as well.
            tensorflow.TF_DeleteStatus(status);
        }
        return tfTensor;
    }

    /**
     * Creates an ND4J array from a TensorFlow tensor.
     *
     * <p>A rank-0 tensor becomes an array of shape {@code [1]}. String tensors
     * are decoded into a new string array. For every other supported type the
     * array's buffer is created over the tensor's native data pointer.
     *
     * @param tensor the tensor to convert; must not be null
     * @return the resulting array
     * @throws IllegalArgumentException if the tensor is null or its type is not supported
     * @throws IllegalStateException if TensorFlow fails to decode a string element
     */
    public INDArray ndArrayFromTensor(TF_Tensor tensor) {
        if (tensor == null) {
            throw new IllegalArgumentException("Tensor must not be null!");
        }
        int rank = tensorflow.TF_NumDims(tensor);
        int[] ndShape;
        if (rank == 0) {
            ndShape = new int[]{1};
        } else {
            ndShape = new int[rank];
            for (int i = 0; i < rank; ++i) {
                ndShape[i] = (int)tensorflow.TF_Dim(tensor, i);
            }
        }
        DataType nd4jType = this.typeFor(tensorflow.TF_TensorType(tensor));
        int length = ArrayUtil.prod(ndShape);
        if (nd4jType == DataType.UTF8) {
            String[] strings = this.decodeStrings(tensor, length);
            return Nd4j.create(strings);
        }
        Pointer pointer = tensorflow.TF_TensorData(tensor).capacity((long)length);
        Indexer indexer = this.indexerForType(nd4jType, pointer);
        DataBuffer buffer = Nd4j.createBuffer(indexer.pointer(), nd4jType, (long)length, indexer);
        return Nd4j.create(buffer, ndShape);
    }

    /**
     * Decodes the {@code length} elements of a string tensor, reading each
     * element's offset from the 8-byte offset table at the start of the data.
     */
    private String[] decodeStrings(TF_Tensor tensor, int length) {
        String[] strings = new String[length];
        BytePointer data = new BytePointer(tensorflow.TF_TensorData(tensor)).capacity(tensorflow.TF_TensorByteSize(tensor));
        BytePointer decoded = new BytePointer((Pointer)null);
        SizeTPointer decodedSize = new SizeTPointer(1L);
        long offsetTableBytes = STRING_OFFSET_BYTES * length;
        TF_Status status = tensorflow.TF_NewStatus();
        try {
            for (int i = 0; i < length; ++i) {
                long offset = data.position(STRING_OFFSET_BYTES * i).getLong();
                tensorflow.TF_StringDecode(data.position(offsetTableBytes + offset), data.capacity() - data.position(), decoded, decodedSize, status);
                if (tensorflow.TF_GetCode(status) != TF_STATUS_OK) {
                    throw new IllegalStateException("ERROR: Unable to convert tensor " + tensorflow.TF_Message(status).getString());
                }
                strings[i] = decoded.position(0L).capacity(decodedSize.get()).getString();
            }
        }
        finally {
            // Release the status on the failure path as well.
            tensorflow.TF_DeleteStatus(status);
        }
        return strings;
    }

    /** Creates an indexer of the matching primitive type over {@code pointer}. */
    private Indexer indexerForType(DataType type, Pointer pointer) {
        switch (type) {
            case DOUBLE: {
                return DoubleIndexer.create(new DoublePointer(pointer));
            }
            case FLOAT: {
                return FloatIndexer.create(new FloatPointer(pointer));
            }
            case INT: {
                return IntIndexer.create(new IntPointer(pointer));
            }
            case LONG: {
                return LongIndexer.create(new LongPointer(pointer));
            }
            default: {
                throw new IllegalArgumentException("Illegal type " + type);
            }
        }
    }

    /**
     * Maps a TensorFlow type code to the ND4J data type used to read the tensor.
     * 32-bit integer tensors map to {@link DataType#INT} so that they are read
     * with 4-byte elements, mirroring the INT mapping used when creating tensors.
     */
    private DataType typeFor(int tensorflowType) {
        switch (tensorflowType) {
            case TF_TYPE_DOUBLE: {
                return DataType.DOUBLE;
            }
            case TF_TYPE_FLOAT: {
                return DataType.FLOAT;
            }
            case TF_TYPE_INT32: {
                return DataType.INT;
            }
            case TF_TYPE_INT64: {
                return DataType.LONG;
            }
            case TF_TYPE_STRING: {
                return DataType.UTF8;
            }
            default: {
                throw new IllegalArgumentException("Illegal type " + tensorflowType);
            }
        }
    }

    /**
     * Reads a serialized graph definition from a file and imports it.
     *
     * @param filePath path of the file holding the serialized graph
     * @param status status object that receives the result of the import
     * @return the imported graph
     * @throws IOException if the file cannot be read
     * @throws IllegalStateException if TensorFlow reports an import error
     */
    public TF_Graph loadGraph(String filePath, TF_Status status) throws IOException {
        byte[] bytes = Files.readAllBytes(Paths.get(filePath));
        return this.loadGraph(bytes, status);
    }

    /**
     * Builds the TensorFlow device name for the device that ND4J's affinity
     * manager reports for the current thread.
     *
     * @return {@code "/device:gpu:<n>"} when the backend class name contains
     *         {@code "JCublasBackend"}, otherwise {@code "/device:cpu:<n>"}
     */
    public static String defaultDeviceForThread() {
        Integer deviceForThread = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        boolean cudaBackend = Nd4j.getBackend().getClass().getName().contains("JCublasBackend");
        return (cudaBackend ? "/device:gpu:" : "/device:cpu:") + deviceForThread;
    }

    /**
     * Imports a serialized graph definition into a new graph.
     *
     * <p>The temporary native buffer and the import options are always
     * released. The new graph is released too when the import fails.
     *
     * @param content serialized graph definition; must not be null
     * @param status status object that receives the result of the import
     * @return the imported graph
     * @throws IllegalArgumentException if {@code content} is null
     * @throws IllegalStateException if TensorFlow reports an import error
     */
    public TF_Graph loadGraph(byte[] content, TF_Status status) {
        if (content == null) {
            throw new IllegalArgumentException("Graph content must not be null!");
        }
        TF_Buffer graphDef = tensorflow.TF_NewBufferFromString(new BytePointer(content), (long)content.length);
        TF_Graph graph = tensorflow.TF_NewGraph();
        TF_ImportGraphDefOptions opts = tensorflow.TF_NewImportGraphDefOptions();
        try {
            tensorflow.TF_GraphImportGraphDef(graph, graphDef, opts, status);
        }
        finally {
            // Neither object is needed once the import call has returned.
            tensorflow.TF_DeleteImportGraphDefOptions(opts);
            tensorflow.TF_DeleteBuffer(graphDef);
        }
        if (tensorflow.TF_GetCode(status) != TF_STATUS_OK) {
            String message = tensorflow.TF_Message(status).getString();
            // The graph is never handed to the caller on failure, so free it here.
            tensorflow.TF_DeleteGraph(graph);
            throw new IllegalStateException("ERROR: Unable to import graph " + message);
        }
        return graph;
    }

    /**
     * Loads a saved model into a session and copies the tensor names of the
     * configured signature into the supplied maps.
     *
     * @param savedModelConfig supplies the model path, model tag and signature key
     * @param options session options passed through to TensorFlow
     * @param runOptions run options buffer passed through to TensorFlow
     * @param graph graph that receives the loaded model
     * @param inputsMap filled with signature input key to tensor name
     * @param outputsMap filled with signature output key to tensor name
     * @param status status object that receives the result of the load
     * @return the session created for the saved model
     * @throws IllegalStateException if TensorFlow fails to load the model, the
     *         returned meta graph cannot be parsed, or the model has no
     *         signature under the configured key
     */
    public TF_Session loadSavedModel(SavedModelConfig savedModelConfig, TF_SessionOptions options, TF_Buffer runOptions, TF_Graph graph, Map<String, String> inputsMap, Map<String, String> outputsMap, TF_Status status) {
        TF_Buffer metaGraph = TF_Buffer.newBuffer();
        TF_Session session = tensorflow.TF_LoadSessionFromSavedModel(options, runOptions, new BytePointer(savedModelConfig.getSavedModelPath()), new BytePointer(savedModelConfig.getModelTag()), 1, graph, metaGraph, status);
        if (tensorflow.TF_GetCode(status) != TF_STATUS_OK) {
            throw new IllegalStateException("ERROR: Unable to import model " + tensorflow.TF_Message(status).getString());
        }
        MetaGraphDef metaGraphDef;
        try {
            ByteBuffer serialized = metaGraph.data().capacity(metaGraph.length()).asByteBuffer();
            metaGraphDef = MetaGraphDef.parseFrom(serialized);
        }
        catch (InvalidProtocolBufferException ex) {
            // Keep the original message but attach the cause for diagnosis.
            throw new IllegalStateException("ERROR: Unable to import model " + ex, ex);
        }
        Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
        String signatureKey = savedModelConfig.getSignatureKey();
        SignatureDef signatureDef = signatureDefMap.get(signatureKey);
        if (signatureDef == null) {
            throw new IllegalStateException("ERROR: Unable to import model, no signature found for key '" + signatureKey + "'; available signatures: " + signatureDefMap.keySet());
        }
        for (Map.Entry<String, TensorInfo> input : signatureDef.getInputsMap().entrySet()) {
            inputsMap.put(input.getKey(), input.getValue().getName());
        }
        for (Map.Entry<String, TensorInfo> output : signatureDef.getOutputsMap().entrySet()) {
            outputsMap.put(output.getKey(), output.getValue().getName());
        }
        return session;
    }
}

