/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.jna;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.mxnet.engine.CachedOp;
import ai.djl.mxnet.engine.MxDeviceType;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxSymbolBlock;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.mxnet.jna.FunctionInfo;
import ai.djl.mxnet.jna.LibFeature;
import ai.djl.mxnet.jna.LibUtils;
import ai.djl.mxnet.jna.MxnetLibrary;
import ai.djl.mxnet.jna.NativeSize;
import ai.djl.mxnet.jna.NativeSizeByReference;
import ai.djl.mxnet.jna.PointerArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.Parameter;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.ptr.PointerByReference;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public final class JnaUtils {
    /** Shared zero-length {@code String[]}: the empty key/value list and a typed-array template. */
    public static final String[] EMPTY_ARRAY = new String[0];

    /** Prefixes stripped from native operator names when the operator table is built. */
    private static final String[] OP_NAME_PREFIX =
            new String[] {"_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"};

    /**
     * System property that enables MXNet's thread-safe cached-op mode. It is read with {@link
     * Boolean#getBoolean(String)}, so it is on only when set to {@code "true"} (any case).
     */
    public static final String MXNET_THREAD_SAFE_PREDICTOR = "ai.djl.mxnet.use_thread_safe_predictor";

    // Declaration order matters: OPS and FEATURES query LIB during class initialization.
    private static final MxnetLibrary LIB = LibUtils.loadLibrary();
    private static final Map<String, FunctionInfo> OPS = getNdArrayFunctions();
    private static final Set<String> FEATURES = getFeaturesInternal();

    private JnaUtils() {
    }

    /**
     * Returns the version number reported by {@code MXGetVersion}.
     *
     * @return the native library version
     */
    public static int getVersion() {
        IntBuffer version = IntBuffer.allocate(1);
        checkCall(LIB.MXGetVersion(version));
        return version.get();
    }

    /**
     * Lists every operator name registered in the native library.
     *
     * @return a new mutable set of operator names
     */
    public static Set<String> getAllOpNames() {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference outArray = new PointerByReference();
        checkCall(LIB.MXListAllOpNames(outSize, outArray));
        // toStringArray also copes with an empty (possibly NULL) native array.
        return new HashSet<>(Arrays.asList(toStringArray(outArray, outSize.get())));
    }

    /**
     * Builds the operator lookup table.
     *
     * <p>Keys are native operator names with one known prefix (see {@code OP_NAME_PREFIX}) removed.
     * If two native names reduce to the same key, the one visited last wins.
     *
     * @return a new concurrent map from (prefix-stripped) operator name to its descriptor
     */
    public static Map<String, FunctionInfo> getNdArrayFunctions() {
        Set<String> opNames = getAllOpNames();
        Map<String, FunctionInfo> map = new ConcurrentHashMap<>();
        for (String opName : opNames) {
            PointerByReference ref = new PointerByReference();
            checkCall(LIB.NNGetOpHandle(opName, ref));
            String functionName = stripOpNamePrefix(opName);
            map.put(functionName, getFunctionByName(opName, functionName, ref.getValue()));
        }
        return map;
    }

    /**
     * Looks up an operator in the table built at class-initialization time.
     *
     * @param opName the prefix-stripped operator name
     * @return the operator descriptor
     * @throws IllegalArgumentException if no such operator is registered
     */
    public static FunctionInfo op(String opName) {
        // OPS is a ConcurrentHashMap, which never stores null values, so null means "absent".
        FunctionInfo info = OPS.get(opName);
        if (info == null) {
            throw new IllegalArgumentException("Unknown operator: " + opName);
        }
        return info;
    }

    /** Reads an operator's argument names/types via {@code MXSymbolGetAtomicSymbolInfo}. */
    private static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
        String[] nameRef = new String[] {name};
        String[] description = new String[1];
        IntBuffer numArgs = IntBuffer.allocate(1);
        PointerByReference argNameRef = new PointerByReference();
        PointerByReference argTypeRef = new PointerByReference();
        PointerByReference argDescRef = new PointerByReference();
        String[] keyVarArgs = new String[1];
        String[] returnType = new String[1];
        checkCall(
                LIB.MXSymbolGetAtomicSymbolInfo(
                        handle,
                        nameRef,
                        description,
                        numArgs,
                        argNameRef,
                        argTypeRef,
                        argDescRef,
                        keyVarArgs,
                        returnType));
        int count = numArgs.get();
        PairList<String, String> arguments = new PairList<>();
        if (count != 0) {
            String[] argNames =
                    argNameRef.getValue().getStringArray(0L, count, StandardCharsets.UTF_8.name());
            String[] argTypes =
                    argTypeRef.getValue().getStringArray(0L, count, StandardCharsets.UTF_8.name());
            for (int i = 0; i < argNames.length; ++i) {
                arguments.add(argNames[i], argTypes[i]);
            }
        }
        return new FunctionInfo(handle, functionName, arguments);
    }

    /**
     * Returns the GPU count reported by {@code MXGetGPUCount}.
     *
     * @return the number of GPUs visible to MXNet
     */
    public static int getGpuCount() {
        IntBuffer count = IntBuffer.allocate(1);
        checkCall(LIB.MXGetGPUCount(count));
        return count.get();
    }

    /**
     * Queries memory information for a GPU device.
     *
     * @param device a device whose type is {@code "gpu"}
     * @return a two-element array: {@code [0]} free memory, {@code [1]} total memory, as reported
     *     by {@code MXGetGPUMemoryInformation64}
     * @throws IllegalArgumentException if {@code device} is not a GPU
     */
    public static long[] getGpuMemory(Device device) {
        if (!"gpu".equals(device.getDeviceType())) {
            throw new IllegalArgumentException("Only GPU device is allowed.");
        }
        int deviceId = device.getDeviceId();
        long[] ret = new long[2];
        // Both buffers are views over ret, so the native call fills the returned array directly.
        LongBuffer freeMem = LongBuffer.wrap(ret, 0, 1);
        LongBuffer totalMem = LongBuffer.wrap(ret, 1, 1);
        checkCall(LIB.MXGetGPUMemoryInformation64(deviceId, freeMem, totalMem));
        return ret;
    }

    /**
     * Returns the names of the library features reported as enabled.
     *
     * @return an unmodifiable set computed once at class initialization
     */
    public static Set<String> getFeatures() {
        return FEATURES;
    }

    /** Reads {@code MXLibInfoFeatures} and keeps the names of features flagged as enabled. */
    private static Set<String> getFeaturesInternal() {
        PointerByReference ref = new PointerByReference();
        NativeSizeByReference outSize = new NativeSizeByReference();
        checkCall(LIB.MXLibInfoFeatures(ref, outSize));
        int size = outSize.getValue().intValue();
        if (size == 0) {
            return Collections.emptySet();
        }
        LibFeature pointer = new LibFeature(ref.getValue());
        pointer.read();
        LibFeature[] features = (LibFeature[]) pointer.toArray(size);
        Set<String> set = new HashSet<>();
        for (LibFeature feature : features) {
            if (feature.getEnabled() == 1) {
                set.add(feature.getName());
            }
        }
        // Shared static state: callers must not be able to mutate it.
        return Collections.unmodifiableSet(set);
    }

    /**
     * Seeds MXNet's random number generators.
     *
     * <p>Unlike the other wrappers, this method does not go through {@link #checkCall(int)}.
     *
     * @param seed the seed value
     * @return the raw status code returned by {@code MXRandomSeed} (0 on success)
     */
    public static int randomSeed(int seed) {
        return LIB.MXRandomSeed(seed);
    }

    /**
     * Creates a dense native NDArray.
     *
     * @param device target device; its id is replaced by -1 when the MXNet device-type code is 1
     *     (presumably CPU; confirm against {@code MxDeviceType})
     * @param shape array shape; each dimension must fit in an {@code int}
     * @param dtype data type; its ordinal is passed as the MXNet dtype code
     * @param size passed to {@code MXNDArrayCreateEx} right after the shape array (presumably the
     *     number of dimensions; verify against callers)
     * @param delayedAlloc whether MXNet may defer allocating storage
     * @return the new native handle
     * @throws ArithmeticException if a dimension overflows {@code int}
     */
    public static Pointer createNdArray(
            Device device, Shape shape, DataType dtype, int size, boolean delayedAlloc) {
        int deviceType = MxDeviceType.toDeviceType(device);
        int deviceId = deviceType != 1 ? device.getDeviceId() : -1;
        int delay = delayedAlloc ? 1 : 0;
        PointerByReference ref = new PointerByReference();
        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
        checkCall(
                LIB.MXNDArrayCreateEx(
                        shapeArray, size, deviceType, deviceId, delay, dtype.ordinal(), ref));
        return ref.getValue();
    }

    /**
     * Creates a sparse native NDArray.
     *
     * <p>Only the first dimension ({@code head()}) of each auxiliary shape is forwarded.
     * NOTE(review): this assumes every aux shape is one-dimensional; confirm for all formats.
     *
     * @param fmt sparse storage format
     * @param device target device (id forced to -1 for device-type code 1, as in {@link
     *     #createNdArray})
     * @param shape logical shape of the array
     * @param dtype data type of the values
     * @param auxDTypes data types of the auxiliary arrays
     * @param auxShapes shapes of the auxiliary arrays
     * @param delayedAlloc whether MXNet may defer allocating storage
     * @return the new native handle
     */
    public static Pointer createSparseNdArray(
            SparseFormat fmt,
            Device device,
            Shape shape,
            DataType dtype,
            DataType[] auxDTypes,
            Shape[] auxShapes,
            boolean delayedAlloc) {
        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
        int deviceType = MxDeviceType.toDeviceType(device);
        int deviceId = deviceType != 1 ? device.getDeviceId() : -1;
        int delay = delayedAlloc ? 1 : 0;
        PointerByReference ref = new PointerByReference();
        IntBuffer auxDTypesInt =
                IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(Enum::ordinal).toArray());
        IntBuffer auxNDims =
                IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
        int[] auxShapesInt = Arrays.stream(auxShapes).mapToInt(ele -> (int) ele.head()).toArray();
        checkCall(
                LIB.MXNDArrayCreateSparseEx(
                        fmt.getValue(),
                        shapeArray,
                        shapeArray.length,
                        deviceType,
                        deviceId,
                        delay,
                        dtype.ordinal(),
                        auxDTypes.length,
                        auxDTypesInt,
                        auxNDims,
                        auxShapesInt,
                        ref));
        return ref.getValue();
    }

    /**
     * Synchronously copies {@code src} into {@code dest}.
     *
     * @param dest destination array
     * @param src source array
     * @param location passed through as the last argument of {@code MXNDArraySyncCopyFromNDArray}
     *     (presumably selects the data vs. auxiliary blob; confirm against MXNet docs)
     */
    public static void ndArraySyncCopyFromNdArray(MxNDArray dest, MxNDArray src, int location) {
        checkCall(LIB.MXNDArraySyncCopyFromNDArray(dest.getHandle(), src.getHandle(), location));
    }

    /**
     * Loads the arrays stored in an MXNet NDArray file.
     *
     * <p>When the file carries names, each array is named accordingly. For any device other than
     * {@link Device#cpu()}, the arrays are copied to that device and the loaded originals are
     * closed, even if the copy fails.
     *
     * @param manager manager that will own the loaded arrays
     * @param path file to load
     * @param device device the returned arrays should live on
     * @return the loaded arrays
     * @throws IllegalStateException if the file has names but not one per array
     */
    public static NDList loadNdArray(MxNDManager manager, Path path, Device device) {
        IntBuffer handlesSize = IntBuffer.allocate(1);
        PointerByReference handlesRef = new PointerByReference();
        PointerByReference namesRef = new PointerByReference();
        IntBuffer namesSize = IntBuffer.allocate(1);
        checkCall(LIB.MXNDArrayLoad(path.toString(), handlesSize, handlesRef, namesSize, namesRef));
        int ndArrayCount = handlesSize.get();
        int nameCount = namesSize.get();
        if (nameCount > 0 && ndArrayCount != nameCount) {
            throw new IllegalStateException(
                    "Mismatch between names and arrays in checkpoint file: " + path.toString());
        }
        // An empty native result array may come back as NULL; never dereference it in that case.
        Pointer[] handles =
                ndArrayCount == 0
                        ? new Pointer[0]
                        : handlesRef.getValue().getPointerArray(0L, ndArrayCount);
        // Decode names as UTF-8, consistent with every other string read in this class.
        String[] names =
                nameCount == 0
                        ? null
                        : namesRef.getValue()
                                .getStringArray(0L, nameCount, StandardCharsets.UTF_8.name());
        NDList ndList = new NDList();
        for (int i = 0; i < ndArrayCount; ++i) {
            MxNDArray array = manager.create(handles[i]);
            if (names != null) {
                array.setName(names[i]);
            }
            ndList.add(array);
        }
        if (Device.cpu().equals(device)) {
            return ndList;
        }
        try {
            return ndList.toDevice(device, true);
        } finally {
            // The copies are independent, so the loaded originals are released either way.
            ndList.close();
        }
    }

    /**
     * Frees a native NDArray handle.
     *
     * @param ndArray the handle
     * @throws IllegalArgumentException if the handle is null (already closed)
     */
    public static void freeNdArray(Pointer ndArray) {
        checkNDArray(ndArray, "free");
        checkCall(LIB.MXNDArrayFree(ndArray));
    }

    /**
     * Blocks until the array can be read, via {@code MXNDArrayWaitToRead}.
     *
     * @param ndArray the handle
     */
    public static void waitToRead(Pointer ndArray) {
        checkNDArray(ndArray, "wait to read");
        checkCall(LIB.MXNDArrayWaitToRead(ndArray));
    }

    /**
     * Blocks until the array can be written, via {@code MXNDArrayWaitToWrite}.
     *
     * @param ndArray the handle
     */
    public static void waitToWrite(Pointer ndArray) {
        checkNDArray(ndArray, "wait to write");
        checkCall(LIB.MXNDArrayWaitToWrite(ndArray));
    }

    /** Blocks until all pending native operations complete ({@code MXNDArrayWaitAll}). */
    public static void waitAll() {
        checkCall(LIB.MXNDArrayWaitAll());
    }

    /**
     * Synchronously copies an array's contents into native memory.
     *
     * @param ndArray source handle
     * @param data destination memory
     * @param len size argument forwarded to {@code MXNDArraySyncCopyToCPU}
     */
    public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
        NativeSize size = new NativeSize(len);
        checkNDArray(ndArray, "copy from");
        checkNDArray(data, "copy to");
        checkCall(LIB.MXNDArraySyncCopyToCPU(ndArray, data, size));
    }

    /**
     * Synchronously copies a direct buffer's contents into an array.
     *
     * @param ndArray destination handle
     * @param data direct buffer holding the source data
     * @param len size argument forwarded to {@code MXNDArraySyncCopyFromCPU}
     * @throws IllegalArgumentException if {@code ndArray} is null (already closed)
     */
    public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
        // Guard like every other NDArray entry point so a closed handle never reaches native code.
        checkNDArray(ndArray, "copy to");
        NativeSize size = new NativeSize(len);
        Pointer pointer = Native.getDirectBufferPointer(data);
        checkCall(LIB.MXNDArraySyncCopyFromCPU(ndArray, pointer, size));
    }

    /**
     * Invokes an operator imperatively through {@code MXImperativeInvokeEx}.
     *
     * @param function operator handle
     * @param inputs input array handles
     * @param destRef receives the output handle array (pre-populated outputs, if any, are used
     *     as destinations by MXNet)
     * @param params operator parameters; values are passed as {@code toString()}; may be null
     * @return each output handle paired with its storage format
     */
    public static PairList<Pointer, SparseFormat> imperativeInvoke(
            Pointer function,
            PointerArray inputs,
            PointerByReference destRef,
            PairList<String, ?> params) {
        String[] keys;
        String[] values;
        if (params == null) {
            keys = EMPTY_ARRAY;
            values = EMPTY_ARRAY;
        } else {
            keys = params.keyArray(EMPTY_ARRAY);
            values = params.values().stream().map(Object::toString).toArray(String[]::new);
        }
        PointerByReference destSType = new PointerByReference();
        IntBuffer numOutputs = IntBuffer.allocate(1);
        numOutputs.put(0, 1);
        checkCall(
                LIB.MXImperativeInvokeEx(
                        function,
                        inputs.numElements(),
                        inputs,
                        numOutputs,
                        destRef,
                        keys.length,
                        keys,
                        values,
                        destSType));
        int numOfOutputs = numOutputs.get(0);
        Pointer[] ptrArray = destRef.getValue().getPointerArray(0L, numOfOutputs);
        int[] sTypes = destSType.getValue().getIntArray(0L, numOfOutputs);
        PairList<Pointer, SparseFormat> pairList = new PairList<>();
        for (int i = 0; i < numOfOutputs; ++i) {
            pairList.add(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
        }
        return pairList;
    }

    /**
     * Returns the storage format of an array.
     *
     * @param ndArray the handle
     * @return the storage format reported by {@code MXNDArrayGetStorageType}
     */
    public static SparseFormat getStorageType(Pointer ndArray) {
        IntBuffer type = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the storage type of");
        checkCall(LIB.MXNDArrayGetStorageType(ndArray, type));
        return SparseFormat.fromValue(type.get());
    }

    /**
     * Returns the device an array lives on.
     *
     * @param ndArray the handle
     * @return the device built from the native context's type and id
     */
    public static Device getDevice(Pointer ndArray) {
        IntBuffer deviceType = IntBuffer.allocate(1);
        IntBuffer deviceId = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the device of");
        checkCall(LIB.MXNDArrayGetContext(ndArray, deviceType, deviceId));
        String deviceTypeStr = MxDeviceType.fromDeviceType(deviceType.get(0));
        return Device.of(deviceTypeStr, deviceId.get(0));
    }

    /**
     * Returns the shape of an array.
     *
     * @param ndArray the handle
     * @return the shape; native 32-bit dimensions are widened to {@code long}. A 0-d result
     *     yields an empty shape.
     */
    public static Shape getShape(Pointer ndArray) {
        IntBuffer dim = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkNDArray(ndArray, "get the shape of");
        checkCall(LIB.MXNDArrayGetShapeEx(ndArray, dim, ref));
        int nDim = dim.get();
        if (nDim == 0) {
            return new Shape(new long[0]);
        }
        int[] shape = ref.getValue().getIntArray(0L, nDim);
        return new Shape(Arrays.stream(shape).asLongStream().toArray());
    }

    /**
     * Returns the data type of an array.
     *
     * <p>The native dtype code is used directly as an index into {@link DataType#values()}, so
     * this relies on the enum order matching MXNet's codes.
     *
     * @param ndArray the handle
     * @return the data type
     */
    public static DataType getDataType(Pointer ndArray) {
        IntBuffer dataType = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the data type of");
        checkCall(LIB.MXNDArrayGetDType(ndArray, dataType));
        return DataType.values()[dataType.get()];
    }

    /**
     * Turns autograd recording on or off.
     *
     * @param isRecording the new state
     * @return the previous state
     */
    public static boolean autogradSetIsRecording(boolean isRecording) {
        IntBuffer prev = IntBuffer.allocate(1);
        checkCall(LIB.MXAutogradSetIsRecording(isRecording ? 1 : 0, prev));
        return prev.get(0) == 1;
    }

    /**
     * Turns autograd training mode on or off.
     *
     * @param isTraining the new state
     * @return the previous state
     */
    public static boolean autogradSetTraining(boolean isTraining) {
        IntBuffer prev = IntBuffer.allocate(1);
        checkCall(LIB.MXAutogradSetIsTraining(isTraining ? 1 : 0, prev));
        return prev.get(0) == 1;
    }

    /**
     * Returns whether autograd is recording.
     *
     * @return the recording flag (read as a single byte)
     */
    public static boolean autogradIsRecording() {
        ByteBuffer isRecording = ByteBuffer.allocate(1);
        checkCall(LIB.MXAutogradIsRecording(isRecording));
        return isRecording.get(0) == 1;
    }

    /**
     * Returns whether autograd is in training mode.
     *
     * @return the training flag (read as a single byte)
     */
    public static boolean autogradIsTraining() {
        ByteBuffer isTraining = ByteBuffer.allocate(1);
        checkCall(LIB.MXAutogradIsTraining(isTraining));
        return isTraining.get(0) == 1;
    }

    /**
     * Marks arrays as variables for gradient computation.
     *
     * <p>{@code varHandles} and {@code gradHandles} are each wrapped in a one-slot reference.
     * NOTE(review): this looks correct only for {@code numVar == 1}; confirm with callers.
     *
     * @param numVar number of variables
     * @param varHandles variable handle
     * @param reqsArray gradient request codes
     * @param gradHandles gradient buffer handle
     */
    public static void autogradMarkVariables(
            int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
        PointerByReference varRef = new PointerByReference(varHandles);
        PointerByReference gradRef = new PointerByReference(gradHandles);
        checkCall(LIB.MXAutogradMarkVariables(numVar, varRef, reqsArray, gradRef));
    }

    /**
     * Runs backward from the given heads using default head gradients.
     *
     * @param array head arrays
     * @param retainGraph non-zero to keep the graph for another backward pass
     */
    public static void autogradBackward(NDList array, int retainGraph) {
        // NULL head gradients ask MXNet for the default. The old one-slot array was read
        // out of bounds when there was more than one head.
        checkCall(
                LIB.MXAutogradBackward(
                        array.size(),
                        toPointerArray(array),
                        (PointerByReference) null,
                        retainGraph));
    }

    /**
     * Runs the extended backward pass ({@code MXAutogradBackwardEx}).
     *
     * <p>NOTE(review): {@code outgrad} is not forwarded; default head gradients are always used,
     * as in the previous implementation.
     *
     * @param numOutput number of heads
     * @param array head arrays
     * @param outgrad currently unused
     * @param numVariables number of variables
     * @param varHandles variable handle (wrapped in a one-slot reference)
     * @param retainGraph non-zero to keep the graph
     * @param createGraph non-zero to record the backward graph
     * @param isTrain non-zero for training mode
     * @param gradHandles gradient output handle (wrapped in a one-slot reference)
     * @param gradSparseFormat gradient storage-type output (wrapped in a one-slot reference)
     */
    public static void autogradBackwardExecute(
            int numOutput,
            NDList array,
            NDArray outgrad,
            int numVariables,
            Pointer varHandles,
            int retainGraph,
            int createGraph,
            int isTrain,
            Pointer gradHandles,
            Pointer gradSparseFormat) {
        PointerByReference varRef = new PointerByReference(varHandles);
        PointerByReference gradRef = new PointerByReference(gradHandles);
        PointerByReference gradSparseFormatRef = new PointerByReference(gradSparseFormat);
        // NULL head gradients: same result as the old empty array for one head, without reading
        // past that array's single slot when numOutput > 1.
        checkCall(
                LIB.MXAutogradBackwardEx(
                        numOutput,
                        toPointerArray(array),
                        (PointerArray) null,
                        numVariables,
                        varRef,
                        retainGraph,
                        createGraph,
                        isTrain,
                        gradRef,
                        gradSparseFormatRef));
    }

    /**
     * Returns the symbol recorded for an array.
     *
     * @param array an {@link MxNDArray}
     * @return the native symbol handle
     */
    public static Pointer autogradGetSymbol(NDArray array) {
        Pointer handle = ((MxNDArray) array).getHandle();
        PointerByReference out = new PointerByReference();
        checkCall(LIB.MXAutogradGetSymbol(handle, out));
        return out.getValue();
    }

    /**
     * Returns the value reported by {@code MXIsNumpyShape}.
     *
     * @return the native numpy-shape mode code
     */
    public static int isNumpyMode() {
        IntBuffer ret = IntBuffer.allocate(1);
        checkCall(LIB.MXIsNumpyShape(ret));
        return ret.get();
    }

    /**
     * Sets the numpy-shape mode; the enum ordinal is passed as the native code.
     *
     * @param mode the new mode
     */
    public static void setNumpyMode(NumpyMode mode) {
        IntBuffer ret = IntBuffer.allocate(1);
        checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), ret));
    }

    /**
     * Returns the gradient array attached to an array.
     *
     * @param handle the array handle
     * @return the gradient handle
     */
    public static Pointer getGradient(Pointer handle) {
        PointerByReference ref = new PointerByReference();
        checkNDArray(handle, "get the gradient for");
        checkCall(LIB.MXNDArrayGetGrad(handle, ref));
        return ref.getValue();
    }

    /**
     * Creates a KVStore of the given type.
     *
     * @param type the KVStore type string
     * @return the native KVStore handle
     */
    public static Pointer parameterStoreCreate(String type) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXKVStoreCreate(type, ref));
        return ref.getValue();
    }

    /**
     * Frees a KVStore.
     *
     * @param handle the KVStore handle
     */
    public static void parameterStoreClose(Pointer handle) {
        checkCall(LIB.MXKVStoreFree(handle));
    }

    /**
     * Initializes KVStore entries ({@code MXKVStoreInitEx}).
     *
     * @param handle the KVStore handle
     * @param num number of keys
     * @param keys the keys
     * @param vals initial values
     */
    public static void parameterStoreInit(Pointer handle, int num, String[] keys, NDList vals) {
        checkNDArray(handle, "initialize the parameter store with");
        checkCall(LIB.MXKVStoreInitEx(handle, num, keys, toPointerArray(vals)));
    }

    /**
     * Pushes values into a KVStore ({@code MXKVStorePushEx}).
     *
     * @param handle the KVStore handle
     * @param num number of keys
     * @param keys the keys
     * @param vals values to push
     * @param priority scheduling priority
     */
    public static void parameterStorePush(
            Pointer handle, int num, String[] keys, NDList vals, int priority) {
        checkNDArray(handle, "push to the parameter store with");
        checkCall(LIB.MXKVStorePushEx(handle, num, keys, toPointerArray(vals), priority));
    }

    /**
     * Pulls values from a KVStore by integer key ({@code MXKVStorePull}).
     *
     * @param handle the KVStore handle
     * @param num number of keys
     * @param keys the keys
     * @param vals destination arrays
     * @param priority scheduling priority
     */
    public static void parameterStorePull(
            Pointer handle, int num, int[] keys, NDList vals, int priority) {
        checkNDArray(handle, "pull from the parameter store with");
        checkCall(LIB.MXKVStorePull(handle, num, keys, toPointerArray(vals), priority));
    }

    /**
     * Pulls values from a KVStore by string key ({@code MXKVStorePullEx}).
     *
     * @param handle the KVStore handle
     * @param num number of keys
     * @param keys the keys
     * @param vals destination arrays
     * @param priority scheduling priority
     */
    public static void parameterStorePull(
            Pointer handle, int num, String[] keys, NDList vals, int priority) {
        checkNDArray(handle, "pull from the parameter store with");
        checkCall(LIB.MXKVStorePullEx(handle, num, keys, toPointerArray(vals), priority));
    }

    /**
     * Pushes inputs and pulls outputs in one call ({@code MXKVStorePushPullEx}).
     *
     * @param handle the KVStore handle
     * @param inputNum number of input keys
     * @param inputKeys input keys
     * @param outputNum number of output keys
     * @param outputKey output keys
     * @param inputs values to push
     * @param outputs destination arrays
     * @param priority scheduling priority
     */
    public static void parameterStorePushPull(
            Pointer handle,
            int inputNum,
            String[] inputKeys,
            int outputNum,
            String[] outputKey,
            NDList inputs,
            NDList outputs,
            int priority) {
        checkNDArray(handle, "push from the parameter store with");
        checkCall(
                LIB.MXKVStorePushPullEx(
                        handle,
                        inputNum,
                        inputKeys,
                        outputNum,
                        outputKey,
                        toPointerArray(inputs),
                        toPointerArray(outputs),
                        priority));
    }

    /**
     * Registers integer- and string-keyed updater callbacks ({@code MXKVStoreSetUpdaterEx}).
     *
     * @param handle the KVStore handle
     * @param updater integer-key updater
     * @param stringUpdater string-key updater
     * @param updaterHandle opaque handle passed back to the callbacks
     */
    public static void parameterStoreSetUpdater(
            Pointer handle,
            MxnetLibrary.MXKVStoreUpdater updater,
            MxnetLibrary.MXKVStoreStrUpdater stringUpdater,
            Pointer updaterHandle) {
        checkCall(LIB.MXKVStoreSetUpdaterEx(handle, updater, stringUpdater, updaterHandle));
    }

    /**
     * Registers an integer-keyed updater callback ({@code MXKVStoreSetUpdater}).
     *
     * @param handle the KVStore handle
     * @param updater integer-key updater
     * @param updaterHandle opaque handle passed back to the callback
     */
    public static void parameterStoreSetUpdater(
            Pointer handle, MxnetLibrary.MXKVStoreUpdater updater, Pointer updaterHandle) {
        checkCall(LIB.MXKVStoreSetUpdater(handle, updater, updaterHandle));
    }

    /**
     * Returns one output of a symbol.
     *
     * @param symbol the symbol handle
     * @param index output index
     * @return the output symbol handle
     */
    public static Pointer getSymbolOutput(Pointer symbol, int index) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolGetOutput(symbol, index, ref));
        return ref.getValue();
    }

    /**
     * Lists a symbol's output names.
     *
     * @param symbol the symbol handle
     * @return the names, decoded as UTF-8
     */
    public static String[] listSymbolOutputs(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolListOutputs(symbol, size, ref));
        return toStringArray(ref, size.get());
    }

    /**
     * Frees a symbol.
     *
     * @param symbol the symbol handle
     */
    public static void freeSymbol(Pointer symbol) {
        checkCall(LIB.MXSymbolFree(symbol));
    }

    /**
     * Lists a symbol's input names ({@code NNSymbolListInputNames} with option 0).
     *
     * @param symbol the symbol handle
     * @return the names, decoded as UTF-8
     */
    public static String[] listSymbolNames(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.NNSymbolListInputNames(symbol, 0, size, ref));
        return toStringArray(ref, size.get());
    }

    /**
     * Lists a symbol's argument names.
     *
     * @param symbol the symbol handle
     * @return the names, decoded as UTF-8
     */
    public static String[] listSymbolArguments(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolListArguments(symbol, size, ref));
        return toStringArray(ref, size.get());
    }

    /**
     * Lists a symbol's auxiliary-state names.
     *
     * @param symbol the symbol handle
     * @return the names, decoded as UTF-8
     */
    public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, size, ref));
        return toStringArray(ref, size.get());
    }

    /**
     * Returns the internals symbol of a symbol ({@code MXSymbolGetInternals}).
     *
     * @param symbol the symbol handle
     * @return the internals symbol handle
     */
    public static Pointer getSymbolInternals(Pointer symbol) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolGetInternals(symbol, ref));
        return ref.getValue();
    }

    /**
     * Loads a symbol from a file.
     *
     * @param path the file path
     * @return the symbol handle
     */
    public static Pointer createSymbolFromFile(String path) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateFromFile(path, ref));
        return ref.getValue();
    }

    /**
     * Converts one group of shapes returned by {@code MXSymbolInferShapeEx64} into {@link Shape}s.
     *
     * @param size number of shapes in the group
     * @param nDim per-shape dimension counts ({@code int} array)
     * @param data per-shape pointers to {@code int64} dimension arrays
     */
    private static List<Shape> recoverShape(
            NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
        int shapeLength = (int) size.getValue().longValue();
        List<Shape> result = new ArrayList<>();
        if (shapeLength == 0) {
            return result;
        }
        int[] dims = nDim.getValue().getIntArray(0L, shapeLength);
        // Read each shape through its own pointer instead of assuming the native buffers are
        // contiguous behind the first one.
        Pointer[] shapePointers = data.getValue().getPointerArray(0L, shapeLength);
        for (int i = 0; i < shapeLength; ++i) {
            int dim = dims[i];
            // A 0-d shape owns no storage, so its pointer is not dereferenced (it may not point
            // at valid memory).
            long[] shape = dim == 0 ? new long[0] : shapePointers[i].getLongArray(0L, dim);
            result.add(new Shape(shape));
        }
        return result;
    }

    /**
     * Infers shapes for a symbol from known argument shapes.
     *
     * <p>The argument shapes are sent in CSR form: {@code indPtr} holds cumulative row offsets
     * into the flattened dimension array.
     *
     * @param symbol the symbol
     * @param args known argument shapes keyed by argument name
     * @return {@code [argumentShapes, outputShapes, auxiliaryShapes]} when inference is complete,
     *     otherwise {@code null}
     */
    public static List<List<Shape>> inferShape(Symbol symbol, PairList<String, Shape> args) {
        Pointer handler = symbol.getHandle();
        int numArgs = args.size();
        String[] keys = args.keys().toArray(new String[0]);
        long[] indPtr = new long[numArgs + 1];
        Shape flattened = new Shape(new long[0]);
        indPtr[0] = 0L;
        for (int i = 0; i < numArgs; ++i) {
            Shape shape = args.valueAt(i);
            // CSR row pointer: offsets must accumulate, not restart per argument.
            indPtr[i + 1] = indPtr[i] + shape.dimension();
            flattened = flattened.addAll(shape);
        }
        long[] flattenedShapeArray = flattened.getShape();
        NativeSizeByReference inShapeSize = new NativeSizeByReference();
        PointerByReference inShapeNDim = new PointerByReference();
        PointerByReference inShapeData = new PointerByReference();
        NativeSizeByReference outShapeSize = new NativeSizeByReference();
        PointerByReference outShapeNDim = new PointerByReference();
        PointerByReference outShapeData = new PointerByReference();
        NativeSizeByReference auxShapeSize = new NativeSizeByReference();
        PointerByReference auxShapeNDim = new PointerByReference();
        PointerByReference auxShapeData = new PointerByReference();
        IntBuffer complete = IntBuffer.allocate(1);
        checkCall(
                LIB.MXSymbolInferShapeEx64(
                        handler,
                        numArgs,
                        keys,
                        indPtr,
                        flattenedShapeArray,
                        inShapeSize,
                        inShapeNDim,
                        inShapeData,
                        outShapeSize,
                        outShapeNDim,
                        outShapeData,
                        auxShapeSize,
                        auxShapeNDim,
                        auxShapeData,
                        complete));
        if (complete.get() != 0) {
            return Arrays.asList(
                    recoverShape(inShapeSize, inShapeNDim, inShapeData),
                    recoverShape(outShapeSize, outShapeNDim, outShapeData),
                    recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
        }
        return null;
    }

    /**
     * Creates a cached op for a symbol block.
     *
     * <p>Initialized parameters are listed as MXNet {@code param_indices}. Uninitialized ones are
     * listed (with their names) as {@code data_indices}. Static allocation and static shapes are
     * always requested.
     *
     * @param block the block whose symbol and parameters are used
     * @param manager manager for arrays produced by the op
     * @param training whether the op is created for training
     * @return the cached op
     * @throws IllegalArgumentException if {@code training} is set while the thread-safe
     *     predictor is enabled
     */
    public static CachedOp createCachedOp(
            MxSymbolBlock block, MxNDManager manager, boolean training) {
        Symbol symbol = block.getSymbol();
        List<Parameter> parameters = block.getAllParameters();
        PairList<String, Integer> dataIndices = new PairList<>();
        ArrayList<Integer> paramIndices = new ArrayList<>();
        int index = 0;
        for (Parameter parameter : parameters) {
            if (parameter.isInitialized()) {
                paramIndices.add(index);
            } else {
                dataIndices.add(parameter.getName(), index);
            }
            ++index;
        }
        Pointer symbolHandle = symbol.getHandle();
        PointerByReference ref = new PointerByReference();
        String[] keys = new String[] {"data_indices", "param_indices", "static_alloc", "static_shape"};
        String[] values =
                new String[] {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
        // Read the system property once so the check and the native flag cannot disagree.
        boolean threadSafe = useThreadSafePredictor();
        if (training) {
            Preconditions.checkArgument(
                    !threadSafe, "thread-safe Predictor doesn't support training.");
        }
        checkCall(
                LIB.MXCreateCachedOpEX(
                        symbolHandle, keys.length, keys, values, ref, (byte) (threadSafe ? 1 : 0)));
        return new CachedOp(ref.getValue(), manager, parameters, paramIndices, dataIndices);
    }

    /**
     * Frees a cached op.
     *
     * @param handle the cached-op handle
     */
    public static void freeCachedOp(Pointer handle) {
        checkCall(LIB.MXFreeCachedOp(handle));
    }

    /**
     * Invokes a cached op.
     *
     * @param manager manager that will own the outputs
     * @param cachedOpHandle the cached-op handle
     * @param inputs input arrays
     * @return the outputs; non-zero storage types are wrapped as sparse arrays
     */
    public static MxNDArray[] cachedOpInvoke(
            MxNDManager manager, Pointer cachedOpHandle, MxNDArray[] inputs) {
        Pointer[] inputHandles = new Pointer[inputs.length];
        for (int i = 0; i < inputs.length; ++i) {
            inputHandles[i] = inputs[i].getHandle();
        }
        PointerArray array = new PointerArray(inputHandles);
        IntBuffer buf = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        PointerByReference outSTypeRef = new PointerByReference();
        checkCall(
                LIB.MXInvokeCachedOpEx(
                        cachedOpHandle, inputs.length, (Pointer) array, buf, ref, outSTypeRef));
        int numOutputs = buf.get();
        Pointer[] ptrArray = ref.getValue().getPointerArray(0L, numOutputs);
        int[] sTypes = outSTypeRef.getValue().getIntArray(0L, numOutputs);
        MxNDArray[] output = new MxNDArray[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            output[i] =
                    sTypes[i] != 0
                            ? manager.create(ptrArray[i], SparseFormat.fromValue(sTypes[i]))
                            : manager.create(ptrArray[i]);
        }
        return output;
    }

    /**
     * Returns whether the thread-safe predictor is enabled.
     *
     * @return {@code true} if {@link #MXNET_THREAD_SAFE_PREDICTOR} is set to {@code "true"}
     */
    public static boolean useThreadSafePredictor() {
        return Boolean.getBoolean(MXNET_THREAD_SAFE_PREDICTOR);
    }

    /**
     * Converts a native status code into an exception.
     *
     * @param ret status returned by an MXNet C API call
     * @throws EngineException if {@code ret} is non-zero; the message includes MXNet's last error
     */
    public static void checkCall(int ret) {
        if (ret != 0) {
            throw new EngineException("MXNet engine call failed: " + getLastError());
        }
    }

    /** Collects the native handles of a list of {@link MxNDArray}s. */
    static PointerArray toPointerArray(NDList vals) {
        Pointer[] valPointers = new Pointer[vals.size()];
        for (int i = 0; i < vals.size(); ++i) {
            valPointers[i] = ((MxNDArray) vals.get(i)).getHandle();
        }
        return new PointerArray(valPointers);
    }

    /** Collects the native handles of an array of {@link MxNDArray}s. */
    static PointerArray toPointerArray(NDArray[] vals) {
        Pointer[] valPointers = new Pointer[vals.length];
        for (int i = 0; i < vals.length; ++i) {
            valPointers[i] = ((MxNDArray) vals[i]).getHandle();
        }
        return new PointerArray(valPointers);
    }

    /** Rejects a null handle, which here means the owning NDArray was already closed. */
    private static void checkNDArray(Pointer pointer, String msg) {
        if (pointer == null) {
            throw new IllegalArgumentException(
                    "Tried to " + msg + " an MXNet NDArray that was already closed");
        }
    }

    /** Returns MXNet's last error message. */
    private static String getLastError() {
        return LIB.MXGetLastError();
    }

    /** Reads {@code size} UTF-8 C strings from a native {@code char*} array. */
    private static String[] toStringArray(PointerByReference ref, int size) {
        if (size == 0) {
            return new String[0];
        }
        Pointer[] pointers = ref.getValue().getPointerArray(0L, size);
        String[] arr = new String[size];
        for (int i = 0; i < size; ++i) {
            arr[i] = pointers[i].getString(0L, StandardCharsets.UTF_8.name());
        }
        return arr;
    }

    /** Removes the first matching entry of {@code OP_NAME_PREFIX} from {@code name}, if any. */
    private static String stripOpNamePrefix(String name) {
        for (String prefix : OP_NAME_PREFIX) {
            if (name.startsWith(prefix)) {
                return name.substring(prefix.length());
            }
        }
        return name;
    }

    /** Numpy-shape modes; the ordinal is sent as the native code, so order is significant. */
    public static enum NumpyMode {
        OFF,
        THREAD_LOCAL_ON,
        GLOBAL_ON;

    }
}

