/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine.javacpp;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.tensorflow.engine.SavedModelBundle;
import ai.djl.tensorflow.engine.TfDataType;
import ai.djl.util.Pair;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.AbstractTFE_Context;
import org.tensorflow.internal.c_api.AbstractTFE_TensorHandle;
import org.tensorflow.internal.c_api.AbstractTF_Graph;
import org.tensorflow.internal.c_api.AbstractTF_Tensor;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_TString;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GPUOptions;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.RunOptions;

public final class JavacppUtils {
    private static final Pattern DEVICE_PATTERN = Pattern.compile(".*device:([A-Z]PU):(\\d+)");

    private JavacppUtils() {
    }

    public static SavedModelBundle loadSavedModelBundle(String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
        Throwable throwable = null;
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            TF_SessionOptions opts = TF_SessionOptions.newSessionOptions();
            if (config != null) {
                BytePointer configBytes = new BytePointer(config.toByteArray());
                tensorflow.TF_SetConfig((TF_SessionOptions)opts, (Pointer)configBytes, (long)configBytes.capacity(), (TF_Status)status);
                status.throwExceptionIfNotOK();
            }
            TF_Buffer runOpts = TF_Buffer.newBufferFromString((Message)runOptions);
            TF_Graph graphHandle = (TF_Graph)AbstractTF_Graph.newGraph().retainReference();
            TF_Buffer metaGraphDef = TF_Buffer.newBuffer();
            TF_Session sessionHandle = tensorflow.TF_LoadSessionFromSavedModel((TF_SessionOptions)opts, (TF_Buffer)runOpts, (BytePointer)new BytePointer(exportDir), (PointerPointer)new PointerPointer(tags), (int)tags.length, (TF_Graph)graphHandle, (TF_Buffer)metaGraphDef, (TF_Status)status);
            status.throwExceptionIfNotOK();
            try {
                SavedModelBundle savedModelBundle = new SavedModelBundle(graphHandle, sessionHandle, MetaGraphDef.parseFrom((ByteBuffer)metaGraphDef.dataAsByteBuffer()));
                return savedModelBundle;
            }
            catch (InvalidProtocolBufferException e) {
                try {
                    throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", (Throwable)e);
                }
                catch (Throwable throwable2) {
                    throwable = throwable2;
                    throw throwable2;
                }
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static TF_Operation getGraphOpByName(TF_Graph graphHandle, String operation) {
        TF_Operation opHandle;
        TF_Graph tF_Graph = graphHandle;
        synchronized (tF_Graph) {
            opHandle = tensorflow.TF_GraphOperationByName((TF_Graph)graphHandle, (String)operation);
        }
        if (opHandle == null || opHandle.isNull()) {
            throw new IllegalArgumentException("No Operation named [" + operation + "] in the Graph");
        }
        return opHandle;
    }

    public static Pair<TF_Operation, Integer> getGraphOperationByName(TF_Graph graphHandle, String operation) {
        int colon = operation.lastIndexOf(58);
        if (colon == -1 || colon == operation.length() - 1) {
            return new Pair((Object)JavacppUtils.getGraphOpByName(graphHandle, operation), (Object)0);
        }
        try {
            String op = operation.substring(0, colon);
            int index = Integer.parseInt(operation.substring(colon + 1));
            return new Pair((Object)JavacppUtils.getGraphOpByName(graphHandle, op), (Object)index);
        }
        catch (NumberFormatException e) {
            return new Pair((Object)JavacppUtils.getGraphOpByName(graphHandle, operation), (Object)0);
        }
    }

    public static TF_Tensor[] runSession(TF_Session handle, RunOptions runOptions, TF_Tensor[] inputTensorHandles, TF_Operation[] inputOpHandles, int[] inputOpIndices, TF_Operation[] outputOpHandles, int[] outputOpIndices, TF_Operation[] targetOpHandles) {
        int numInputs = inputTensorHandles.length;
        int numOutputs = outputOpHandles.length;
        int numTargets = targetOpHandles.length;
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            int i;
            TF_Output inputs = new TF_Output((long)numInputs);
            PointerPointer inputValues = new PointerPointer((long)numInputs);
            TF_Output outputs = new TF_Output((long)numOutputs);
            PointerPointer outputValues = new PointerPointer((long)numOutputs);
            PointerPointer targets = new PointerPointer((long)numTargets);
            for (i = 0; i < numInputs; ++i) {
                inputValues.put((long)i, (Pointer)inputTensorHandles[i]);
            }
            for (i = 0; i < numInputs; ++i) {
                inputs.position((long)i).oper(inputOpHandles[i]).index(inputOpIndices[i]);
            }
            inputs.position(0L);
            for (i = 0; i < numOutputs; ++i) {
                outputs.position((long)i).oper(outputOpHandles[i]).index(outputOpIndices[i]);
            }
            outputs.position(0L);
            for (i = 0; i < numTargets; ++i) {
                targets.put((long)i, (Pointer)targetOpHandles[i]);
            }
            TF_Status status = TF_Status.newStatus();
            TF_Buffer runOpts = TF_Buffer.newBufferFromString((Message)runOptions);
            tensorflow.TF_SessionRun((TF_Session)handle, (TF_Buffer)runOpts, (TF_Output)inputs, (PointerPointer)inputValues, (int)numInputs, (TF_Output)outputs, (PointerPointer)outputValues, (int)numOutputs, (PointerPointer)targets, (int)numTargets, null, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TF_Tensor[] ret = new TF_Tensor[numOutputs];
            for (int i2 = 0; i2 < numOutputs; ++i2) {
                ret[i2] = (TF_Tensor)((TF_Tensor)outputValues.get(TF_Tensor.class, (long)i2)).withDeallocator().retainReference();
            }
            TF_Tensor[] tF_TensorArray = ret;
            return tF_TensorArray;
        }
    }

    public static TFE_Context createEagerSession(boolean async, int devicePlacementPolicy, ConfigProto config) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
            TF_Status status = TF_Status.newStatus();
            if (config != null) {
                BytePointer configBytes = new BytePointer(config.toByteArray());
                tensorflow.TFE_ContextOptionsSetConfig((TFE_ContextOptions)opts, (Pointer)configBytes, (long)configBytes.capacity(), (TF_Status)status);
                status.throwExceptionIfNotOK();
            }
            tensorflow.TFE_ContextOptionsSetAsync((TFE_ContextOptions)opts, (byte)((byte)(async ? 1 : 0)));
            tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy((TFE_ContextOptions)opts, (int)devicePlacementPolicy);
            TFE_Context context = AbstractTFE_Context.newContext((TFE_ContextOptions)opts, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TFE_Context tFE_Context = (TFE_Context)context.retainReference();
            return tFE_Context;
        }
    }

    public static Device getDevice(TFE_TensorHandle handle) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName((TFE_TensorHandle)handle, (TF_Status)status);
            String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
            Device device2 = JavacppUtils.fromTfDevice(device);
            return device2;
        }
    }

    public static DataType getDataType(TFE_TensorHandle handle) {
        return TfDataType.fromTf(tensorflow.TFE_TensorHandleDataType((TFE_TensorHandle)handle));
    }

    public static Shape getShape(TFE_TensorHandle handle) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            int numDims = tensorflow.TFE_TensorHandleNumDims((TFE_TensorHandle)handle, (TF_Status)status);
            status.throwExceptionIfNotOK();
            long[] shapeArr = new long[numDims];
            for (int i = 0; i < numDims; ++i) {
                shapeArr[i] = tensorflow.TFE_TensorHandleDim((TFE_TensorHandle)handle, (int)i, (TF_Status)status);
                status.throwExceptionIfNotOK();
            }
            Shape shape = new Shape(shapeArr);
            return shape;
        }
    }

    public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
        long numBytes;
        long[] dims;
        int dType = TfDataType.toTf(dataType);
        TF_Tensor tensor = AbstractTF_Tensor.allocateTensor((int)dType, (long[])(dims = shape.getShape()), (long)(numBytes = (long)dataType.getNumOfBytes() * shape.size()));
        if (tensor == null || tensor.isNull()) {
            throw new IllegalStateException("unable to allocate memory for the Tensor");
        }
        return tensor;
    }

    public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Tensor tensor = JavacppUtils.createEmptyTFTensor(shape, dataType);
            TF_Status status = TF_Status.newStatus();
            TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor((TF_Tensor)tensor, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TFE_TensorHandle tFE_TensorHandle = (TFE_TensorHandle)handle.retainReference();
            return tFE_TensorHandle;
        }
    }

    public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(String src) {
        int dType = TfDataType.toTf(DataType.STRING);
        long[] dims = new long[]{};
        long numBytes = Loader.sizeof(TF_TString.class);
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Tensor tensor = AbstractTF_Tensor.allocateTensor((int)dType, (long[])dims, (long)numBytes);
            Pointer pointer = tensorflow.TF_TensorData((TF_Tensor)tensor).capacity(numBytes);
            TF_TString data = (TF_TString)new TF_TString(pointer).capacity(pointer.position() + 1L);
            byte[] buf = src.getBytes(StandardCharsets.UTF_8);
            tensorflow.TF_TString_Copy((TF_TString)data, (BytePointer)new BytePointer(buf), (long)buf.length);
            TF_Status status = TF_Status.newStatus();
            TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor((TF_Tensor)tensor, (TF_Status)status);
            status.throwExceptionIfNotOK();
            handle.retainReference();
            tensor.retainReference();
            Pair pair = new Pair((Object)tensor, (Object)handle);
            return pair;
        }
    }

    public static TFE_TensorHandle createTFETensorFromByteBuffer(ByteBuffer buf, Shape shape, DataType dataType) {
        int dType = TfDataType.toTf(dataType);
        long[] dims = shape.getShape();
        long numBytes = dataType == DataType.STRING ? (long)(buf.remaining() + 1) : shape.size() * (long)dataType.getNumOfBytes();
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Tensor tensor = AbstractTF_Tensor.allocateTensor((int)dType, (long[])dims, (long)numBytes);
            Pointer pointer = tensorflow.TF_TensorData((TF_Tensor)tensor).capacity(numBytes);
            pointer.asByteBuffer().put(buf);
            TF_Status status = TF_Status.newStatus();
            TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor((TF_Tensor)tensor, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TFE_TensorHandle tFE_TensorHandle = (TFE_TensorHandle)handle.retainReference();
            return tFE_TensorHandle;
        }
    }

    public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve((TFE_TensorHandle)handle, (TF_Status)status).withDeallocator();
            status.throwExceptionIfNotOK();
            TF_Tensor tF_Tensor = (TF_Tensor)tensor.retainReference();
            return tF_Tensor;
        }
    }

    public static TFE_TensorHandle createTFETensor(TF_Tensor handle) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            TFE_TensorHandle tensor = AbstractTFE_TensorHandle.newTensor((TF_Tensor)handle, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TFE_TensorHandle tFE_TensorHandle = (TFE_TensorHandle)tensor.retainReference();
            return tFE_TensorHandle;
        }
    }

    public static String[] getString(TFE_TensorHandle handle, int count, Charset charset) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve((TFE_TensorHandle)handle, (TF_Status)status);
            status.throwExceptionIfNotOK();
            long tensorSize = tensorflow.TF_TensorByteSize((TF_Tensor)tensor);
            Pointer pointer = tensorflow.TF_TensorData((TF_Tensor)tensor).capacity(tensorSize);
            TF_TString data = (TF_TString)new TF_TString(pointer).capacity(pointer.position() + (long)count);
            String[] ret = new String[count];
            for (int i = 0; i < count; ++i) {
                TF_TString tstring = data.getPointer((long)i);
                long size = tensorflow.TF_TString_GetSize((TF_TString)tstring);
                BytePointer bp = tensorflow.TF_TString_GetDataPointer((TF_TString)tstring).capacity(size);
                ret[i] = bp.getString(charset);
            }
            tensorflow.TF_DeleteTensor((TF_Tensor)tensor);
            String[] stringArray = ret;
            return stringArray;
        }
    }

    public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            TF_Status status = TF_Status.newStatus();
            TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve((TFE_TensorHandle)handle, (TF_Status)status).withDeallocator();
            status.throwExceptionIfNotOK();
            Pointer pointer = tensorflow.TF_TensorData((TF_Tensor)tensor).capacity(tensorflow.TF_TensorByteSize((TF_Tensor)tensor));
            ByteBuffer buf = pointer.asByteBuffer();
            ByteBuffer ret = ByteBuffer.allocate(buf.capacity());
            buf.rewind();
            ret.put(buf);
            ret.flip();
            ByteBuffer byteBuffer = ret.order(ByteOrder.nativeOrder());
            return byteBuffer;
        }
    }

    public static TFE_TensorHandle toDevice(TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) {
        try (PointerScope ignored = new PointerScope(new Class[0]);){
            String deviceName = JavacppUtils.toTfDevice(device);
            TF_Status status = TF_Status.newStatus();
            TFE_TensorHandle newHandle = tensorflow.TFE_TensorHandleCopyToDevice((TFE_TensorHandle)handle, (TFE_Context)eagerSessionHandle, (String)deviceName, (TF_Status)status);
            status.throwExceptionIfNotOK();
            TFE_TensorHandle tFE_TensorHandle = newHandle;
            return tFE_TensorHandle;
        }
    }

    public static ConfigProto getSessionConfig() {
        Integer interop = Integer.getInteger("ai.djl.tensorflow.num_interop_threads");
        Integer intraop = Integer.getInteger("ai.djl.tensorflow.num_intraop_threads");
        ConfigProto.Builder configBuilder = ConfigProto.newBuilder();
        if (interop != null) {
            configBuilder.setInterOpParallelismThreads(interop.intValue());
        }
        if (intraop != null) {
            configBuilder.setIntraOpParallelismThreads(intraop.intValue());
        }
        GPUOptions gpuOptions = GPUOptions.newBuilder().setVisibleDeviceList("0").build();
        configBuilder.setGpuOptions(gpuOptions);
        return configBuilder.build();
    }

    public static Device fromTfDevice(String device) {
        Matcher m = DEVICE_PATTERN.matcher(device);
        if (m.matches()) {
            if ("CPU".equals(m.group(1))) {
                return Device.cpu();
            }
            if ("GPU".equals(m.group(1))) {
                return Device.of((String)"gpu", (int)Integer.parseInt(m.group(2)));
            }
        }
        throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
    }

    public static String toTfDevice(Device device) {
        if (device.getDeviceType().equals("cpu")) {
            return "/device:CPU:0";
        }
        if (device.getDeviceType().equals("gpu")) {
            return "/device:GPU:" + device.getDeviceId();
        }
        throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
    }
}

