/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.mxnet.engine.MxSparseNDArray;
import ai.djl.mxnet.engine.MxTrainer;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
import java.lang.ref.Reference;
import java.lang.ref.WeakReference;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MXNet implementation of {@link NDManager}.
 *
 * <p>Each manager owns a set of {@link AutoCloseable} resources (arrays and sub-managers) keyed by
 * their uid. Resources are held through {@link WeakReference}s, so the manager does not keep them
 * reachable; closing the manager closes every resource that is still reachable and then detaches
 * the manager from its parent.
 */
public class MxNDManager implements NDManager {

    // Logger category must be this class (previously MxTrainer.class by mistake).
    private static final Logger logger = LoggerFactory.getLogger(MxNDManager.class);
    private static final MxNDManager SYSTEM_MANAGER = new SystemManager();
    private static final NDArray[] EMPTY = new NDArray[0];

    private final NDManager parent;
    private final String uid;
    private final Device device;
    private final Map<String, Reference<AutoCloseable>> resources;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    private MxNDManager(NDManager parent, Device device) {
        this.parent = parent;
        this.device = Device.defaultIfNull(device);
        this.resources = new ConcurrentHashMap<>();
        this.uid = UUID.randomUUID().toString();
    }

    /**
     * Returns the process-wide root manager. Its attach/detach/close operations are no-ops.
     *
     * @return the system manager
     */
    static MxNDManager getSystemManager() {
        return SYSTEM_MANAGER;
    }

    /**
     * Allocates a direct {@link ByteBuffer} in native byte order.
     *
     * @param capacity the buffer capacity in bytes
     * @return a new direct buffer using {@link ByteOrder#nativeOrder()}
     */
    public ByteBuffer allocateDirect(int capacity) {
        return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
    }

    /**
     * Wraps an existing native handle as a dense array owned by this manager.
     *
     * @param handle the native NDArray handle
     * @return the wrapping array, attached to this manager
     */
    public MxNDArray create(Pointer handle) {
        MxNDArray array = new MxNDArray(this, handle);
        attach(array.getUid(), array);
        return array;
    }

    /**
     * Wraps an existing native handle as a sparse array owned by this manager.
     *
     * @param handle the native NDArray handle
     * @param fmt the sparse storage format of the handle
     * @return the wrapping sparse array, attached to this manager
     */
    public MxSparseNDArray create(Pointer handle, SparseFormat fmt) {
        MxSparseNDArray array = new MxSparseNDArray(this, handle, fmt);
        attach(array.getUid(), array);
        return array;
    }

    /**
     * Allocates a new native dense array.
     *
     * @param shape the array shape
     * @param dataType the element type
     * @param dev the target device, or {@code null} for this manager's device
     * @return the new array, attached to this manager
     */
    public MxNDArray create(Shape shape, DataType dataType, Device dev) {
        Device target = Device.defaultIfNull(dev, this.device);
        Pointer handle = JnaUtils.createNdArray(target, shape, dataType, shape.dimension(), false);
        MxNDArray array = new MxNDArray(this, handle, target, shape, dataType);
        attach(array.getUid(), array);
        return array;
    }

    /**
     * Creates a CSR sparse array from its data, row-pointer and column-index components.
     *
     * <p>The components are first staged in temporary dense INT64/data arrays, synchronously copied
     * into the sparse array, and then closed so their native memory is released immediately instead
     * of lingering until this manager is closed. If any step throws, the temporaries remain attached
     * to this manager and are released when it closes.
     *
     * @param data the non-zero values; the element type is derived from the buffer type
     * @param indptr the row pointer array
     * @param indices the column indices
     * @param shape the logical shape of the sparse array
     * @param dev the target device, or {@code null} for this manager's device
     * @return the new CSR sparse array, attached to this manager
     */
    public MxSparseNDArray createCSR(
            Buffer data, long[] indptr, long[] indices, Shape shape, Device dev) {
        dev = Device.defaultIfNull(dev, this.device);
        SparseFormat fmt = SparseFormat.CSR;
        DataType dataType = DataType.fromBuffer(data);

        MxNDArray indptrNd = create(new Shape(new long[] {indptr.length}), DataType.INT64, dev);
        indptrNd.set(indptr);
        MxNDArray indicesNd = create(new Shape(new long[] {indices.length}), DataType.INT64, dev);
        indicesNd.set(indices);

        Pointer handle =
                JnaUtils.createSparseNdArray(
                        fmt,
                        dev,
                        shape,
                        dataType,
                        new DataType[] {indptrNd.getDataType(), indicesNd.getDataType()},
                        new Shape[] {indptrNd.getShape(), indicesNd.getShape()},
                        false);
        MxSparseNDArray sparse = create(handle, fmt);

        MxNDArray dataNd = create(new Shape(new long[] {data.remaining()}), dataType, dev);
        dataNd.set(data);

        // -1 targets the data storage; 0 and 1 target the aux arrays (indptr, indices).
        JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
        JnaUtils.ndArraySyncCopyFromNdArray(sparse, indptrNd, 0);
        JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 1);

        // The copies above are synchronous, so the staging arrays are no longer needed.
        dataNd.close();
        indptrNd.close();
        indicesNd.close();
        return sparse;
    }

    /**
     * Creates a row-sparse array from its dense row data and row indices.
     *
     * <p>The components are staged in temporary dense arrays, synchronously copied into the sparse
     * array, and then closed to release their native memory early. If any step throws, the
     * temporaries remain attached to this manager and are released when it closes.
     *
     * @param data the row values; the element type is derived from the buffer type
     * @param dataShape the shape of the dense row data
     * @param indices the row indices
     * @param shape the logical shape of the sparse array
     * @param dev the target device, or {@code null} for this manager's device
     * @return the new row-sparse array, attached to this manager
     */
    public MxSparseNDArray createRowSparse(
            Buffer data, Shape dataShape, long[] indices, Shape shape, Device dev) {
        dev = Device.defaultIfNull(dev, this.device);
        SparseFormat fmt = SparseFormat.ROW_SPARSE;
        DataType dataType = DataType.fromBuffer(data);

        MxNDArray indicesNd = create(new Shape(new long[] {indices.length}), DataType.INT64, dev);
        indicesNd.set(indices);

        Pointer handle =
                JnaUtils.createSparseNdArray(
                        fmt,
                        dev,
                        shape,
                        dataType,
                        new DataType[] {indicesNd.getDataType()},
                        new Shape[] {indicesNd.getShape()},
                        false);
        MxSparseNDArray sparse = create(handle, fmt);

        MxNDArray dataNd = create(dataShape, dataType, dev);
        dataNd.set(data);

        // -1 targets the data storage; 0 targets the row-index aux array.
        JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
        JnaUtils.ndArraySyncCopyFromNdArray(sparse, indicesNd, 0);

        // The copies above are synchronous, so the staging arrays are no longer needed.
        dataNd.close();
        indicesNd.close();
        return sparse;
    }

    /** Creates an array filled with zeros via the {@code _npi_zeros} operator. */
    public NDArray zeros(Shape shape, DataType dataType, Device dev) {
        return fill("_npi_zeros", dev, shape, dataType);
    }

    /** Creates an array filled with ones via the {@code _npi_ones} operator. */
    public NDArray ones(Shape shape, DataType dataType, Device dev) {
        return fill("_npi_ones", dev, shape, dataType);
    }

    /**
     * Invokes {@code _npi_arange} with the given start/stop/step.
     *
     * @param dataType the element type, or {@code null} to let the operator choose
     * @param dev the target device, or {@code null} for this manager's device
     */
    public NDArray arange(Number start, Number stop, Number step, DataType dataType, Device dev) {
        MxOpParams params = new MxOpParams();
        params.addParam("start", start);
        params.addParam("stop", stop);
        params.addParam("step", step);
        if (dataType != null) {
            params.setDataType(dataType);
        }
        params.setDevice(Device.defaultIfNull(dev, this.device));
        return invoke("_npi_arange", params);
    }

    /**
     * Invokes {@code _npi_eye} with {@code N=rows}, {@code M=cols} and diagonal offset {@code k}.
     *
     * @param dataType the element type, or {@code null} to let the operator choose (consistent with
     *     {@link #arange})
     * @param dev the target device, or {@code null} for this manager's device
     */
    public NDArray eye(int rows, int cols, int k, DataType dataType, Device dev) {
        MxOpParams params = new MxOpParams();
        params.addParam("N", rows);
        params.addParam("M", cols);
        params.addParam("k", k);
        if (dataType != null) {
            params.setDataType(dataType);
        }
        params.setDevice(Device.defaultIfNull(dev, this.device));
        return invoke("_npi_eye", params);
    }

    /**
     * Invokes {@code _npi_linspace}.
     *
     * @throws IllegalArgumentException if {@code num} is negative
     */
    public NDArray linspace(Number start, Number stop, int num, boolean endpoint, Device dev) {
        if (num < 0) {
            throw new IllegalArgumentException("Num argument must be non-negative");
        }
        MxOpParams params = new MxOpParams();
        params.addParam("start", start);
        params.addParam("stop", stop);
        params.addParam("num", num);
        params.addParam("endpoint", endpoint);
        params.setDevice(Device.defaultIfNull(dev, this.device));
        return invoke("_npi_linspace", params);
    }

    /** Invokes {@code _npi_uniform}; a {@code null} {@code dataType} is not forwarded. */
    public NDArray randomUniform(
            Number low, Number high, Shape shape, DataType dataType, Device dev) {
        MxOpParams params = new MxOpParams();
        params.addParam("low", low);
        params.addParam("high", high);
        params.addParam("size", shape);
        params.setDevice(Device.defaultIfNull(dev, this.device));
        if (dataType != null) {
            params.setDataType(dataType);
        }
        return invoke("_npi_uniform", params);
    }

    /** Invokes {@code _npi_normal}; a {@code null} {@code dataType} is not forwarded. */
    public NDArray randomNormal(
            Number loc, Number scale, Shape shape, DataType dataType, Device dev) {
        MxOpParams params = new MxOpParams();
        params.addParam("loc", loc);
        params.addParam("scale", scale);
        params.addParam("size", shape);
        params.setDevice(Device.defaultIfNull(dev, this.device));
        if (dataType != null) {
            params.setDataType(dataType);
        }
        return invoke("_npi_normal", params);
    }

    /** Invokes {@code _npi_multinomial} with {@code n} draws and an explicit output size. */
    public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
        MxOpParams params = new MxOpParams();
        params.addParam("n", n);
        params.addParam("size", shape);
        return invoke("_npi_multinomial", pValues, params);
    }

    /** Invokes {@code _npi_multinomial} with {@code n} draws and no explicit output size. */
    public NDArray randomMultinomial(int n, NDArray pValues) {
        MxOpParams params = new MxOpParams();
        params.addParam("n", n);
        return invoke("_npi_multinomial", pValues, params);
    }

    /** Returns the parent manager, or {@code null} for the system manager. */
    public NDManager getParentManager() {
        return parent;
    }

    /** Creates a child manager on this manager's device. */
    public MxNDManager newSubManager() {
        return newSubManager(device);
    }

    /**
     * Creates a child manager on the given device and attaches it to this manager.
     *
     * @throws IllegalStateException if this manager has already been closed
     */
    public MxNDManager newSubManager(Device dev) {
        MxNDManager manager = new MxNDManager(this, dev);
        attach(manager.uid, manager);
        return manager;
    }

    /** Returns the default device for arrays created by this manager. */
    public Device getDevice() {
        return device;
    }

    /**
     * Registers a resource (held weakly) to be closed together with this manager.
     *
     * @throws IllegalStateException if this manager has already been closed
     */
    public synchronized void attach(String resourceId, AutoCloseable resource) {
        if (closed.get()) {
            throw new IllegalStateException("NDManager has been closed already.");
        }
        resources.put(resourceId, new WeakReference<AutoCloseable>(resource));
    }

    /**
     * Unregisters a resource. Ignored once the manager is closed, so children detaching during
     * {@link #close()} do not modify the resource map while it is being iterated.
     */
    public synchronized void detach(String resourceId) {
        if (closed.get()) {
            return;
        }
        resources.remove(resourceId);
    }

    /** Invokes an operator writing into caller-supplied destination arrays. */
    public void invoke(String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
        JnaUtils.op(operation).invoke(this, src, dest, params);
    }

    /** Invokes an operator on a list of inputs and returns all outputs. */
    public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
        NDArray[] inputs = src.toArray(EMPTY);
        return new NDList(JnaUtils.op(operation).invoke((NDManager) this, inputs, params));
    }

    /** Invokes an operator writing into caller-supplied destination list. */
    public void invoke(String operation, NDList src, NDList dest, PairList<String, ?> params) {
        invoke(operation, src.toArray(EMPTY), dest.toArray(EMPTY), params);
    }

    /** Invokes an operator and returns its first output. */
    public NDArray invoke(String operation, NDArray[] src, PairList<String, ?> params) {
        return JnaUtils.op(operation).invoke((NDManager) this, src, params)[0];
    }

    /** Invokes an operator on a single input and returns its first output. */
    public NDArray invoke(String operation, NDArray src, PairList<String, ?> params) {
        return invoke(operation, new NDArray[] {src}, params);
    }

    /** Invokes an operator with no inputs and returns its first output. */
    public NDArray invoke(String operation, PairList<String, ?> params) {
        return invoke(operation, EMPTY, params);
    }

    /**
     * Describes this manager. Parents that are not {@code MxNDManager}s are reported by class name
     * instead of being cast (which previously could throw {@link ClassCastException}).
     */
    public String toString() {
        String parentUID;
        if (parent == null) {
            parentUID = "No Parent";
        } else if (parent instanceof MxNDManager) {
            parentUID = ((MxNDManager) parent).uid;
        } else {
            parentUID = parent.getClass().getName();
        }
        return "UID: " + uid + " Parent UID: " + parentUID + " isOpen: " + isOpen()
                + " Resource size: " + resources.size();
    }

    /**
     * Closes every still-reachable attached resource, detaches from the parent (if any) and clears
     * the resource map. Idempotent: only the first call has any effect. Failures closing an
     * individual resource are logged and do not stop the remaining resources from closing.
     */
    public synchronized void close() {
        if (closed.getAndSet(true)) {
            return;
        }
        for (Reference<AutoCloseable> ref : resources.values()) {
            AutoCloseable closeable = ref.get();
            if (closeable == null) {
                continue; // already garbage collected
            }
            try {
                closeable.close();
            } catch (Exception e) {
                logger.error("Resource close failed.", e);
            }
        }
        if (parent != null) {
            parent.detach(uid);
        }
        resources.clear();
    }

    /**
     * Prints this manager and, recursively, its child managers to standard output, indented by
     * {@code level}.
     */
    public void debugDump(int level) {
        StringBuilder sb = new StringBuilder(100);
        for (int i = 0; i < level; ++i) {
            sb.append("    ");
        }
        sb.append("\\--- NDManager(")
                .append(uid.substring(24))
                .append(") resource count: ")
                .append(resources.size());
        System.out.println(sb.toString());
        for (Reference<AutoCloseable> ref : resources.values()) {
            AutoCloseable c = ref.get();
            if (c instanceof MxNDManager) {
                ((MxNDManager) c).debugDump(level + 1);
            }
        }
    }

    /** Returns {@code true} until {@link #close()} has been called. */
    boolean isOpen() {
        return !closed.get();
    }

    /**
     * Runs a shape-based fill operator such as {@code _npi_zeros}/{@code _npi_ones}.
     *
     * @throws IllegalArgumentException if {@code shape} is {@code null}
     */
    private NDArray fill(String opName, Device dev, Shape shape, DataType dataType) {
        if (shape == null) {
            throw new IllegalArgumentException("Shape is required for " + opName.substring(1));
        }
        MxOpParams params = new MxOpParams();
        params.addParam("shape", shape);
        params.setDevice(Device.defaultIfNull(dev, this.device));
        // Same null handling as arange/randomUniform/randomNormal.
        if (dataType != null) {
            params.setDataType(dataType);
        }
        return invoke(opName, params);
    }

    /**
     * Root manager with no parent. Attach, detach and close are no-ops, so resources created
     * directly through it are not tracked and it can never be closed.
     */
    private static final class SystemManager extends MxNDManager {

        SystemManager() {
            super(null, Device.defaultDevice());
        }

        @Override
        public void attach(String resourceId, AutoCloseable resource) {}

        @Override
        public void detach(String resourceId) {}

        @Override
        public void close() {}
    }
}

