/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.nativeblas;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.LongRawIndexer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

/**
 * Base {@link BaseNDArrayFactory} for backends that go through the native library via
 * {@link NativeOps}. On top of the generic factory behaviour it provides conversion between
 * ND4J arrays and the NumPy {@code .npy} binary layout.
 */
public abstract class BaseNativeNDArrayFactory
extends BaseNDArrayFactory {
    /** Native operations used by the NumPy import helpers in this class. */
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    public BaseNativeNDArrayFactory(DataBuffer.Type dtype, Character order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory(DataBuffer.Type dtype, char order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory() {
    }

    /**
     * Serializes {@code array} into a freshly allocated in-memory NPY image: the header built by
     * the native library, immediately followed by the raw bytes of the array's data buffer.
     *
     * <p>NOTE(review): assumes {@code numpyHeaderForNd4j} returns the complete NPY header,
     * including the leading {@code \x93NUMPY} magic and version bytes, so nothing is prepended
     * here. Confirm against libnd4j.
     *
     * @param array the array to serialize
     * @return a byte pointer of {@code headerSize + dataBytes} bytes, positioned at 0
     */
    public Pointer convertToNumpy(INDArray array) {
        LongPointer size = new LongPointer(1L);
        Pointer header = NativeOpsHolder.getInstance().getDeviceNativeOps().numpyHeaderForNd4j(array.data().pointer(), array.shapeInfoDataBuffer().pointer(), array.data().getElementSize(), size);
        long headerSize = size.get();
        header.capacity(headerSize);
        header.position(0L);
        BytePointer headerBytes = new BytePointer(header);

        // Long arithmetic throughout: an int cast here would truncate arrays larger than 2 GiB.
        long dataBytes = (long) array.data().getElementSize() * array.data().length();
        BytePointer result = new BytePointer(headerSize + dataBytes);

        // The data section must start right after the *complete* header. Copy the header
        // in full and offset by its full length, so no byte of it is overwritten by the data.
        Pointer.memcpy((Pointer) result, (Pointer) headerBytes, headerSize);
        result.position(headerSize);
        Pointer.memcpy((Pointer) result, (Pointer) array.data().pointer(), dataBytes);
        result.position(0L);
        return result;
    }

    /**
     * Builds an array from an in-memory NPY structure handed out by the native library
     * (for example the result of {@code numpyFromFile}). Shape and data are copied into
     * freshly allocated buffers, so the result does not alias {@code pointer}.
     *
     * @param pointer native NPY structure
     * @return a new array holding a copy of the NPY contents
     * @throws UnsupportedOperationException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyPointer(Pointer pointer) {
        return this.createFromNpyParts(
                this.nativeOps.dataPointForNumpy(pointer),
                this.nativeOps.elementSizeForNpyArray(pointer),
                this.nativeOps.shapeBufferForNumpy(pointer));
    }

    /**
     * Builds an array from a pointer to raw NPY bytes (header followed by data). The bytes are
     * interpreted by the native header accessors. Shape and data are copied into freshly
     * allocated buffers.
     *
     * @param pointer pointer to the NPY header
     * @return a new array holding a copy of the NPY contents
     * @throws UnsupportedOperationException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyHeaderPointer(Pointer pointer) {
        return this.createFromNpyParts(
                this.nativeOps.dataPointForNumpyHeader(pointer),
                this.nativeOps.elementSizeForNpyArrayHeader(pointer),
                this.nativeOps.shapeBufferForNumpyHeader(pointer));
    }

    /**
     * Loads a {@code .npy} file through the native reader and converts it to an array. The
     * native NPY structure is always released, even if conversion fails.
     *
     * @param file the {@code .npy} file to read
     * @return a new array holding the file contents
     */
    public INDArray createFromNpyFile(File file) {
        byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
        // One extra byte: allocateDirect zero-fills, so it becomes the NUL terminator. The native
        // reader receives a char* path and presumably treats it as a C string.
        ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length + 1).order(ByteOrder.nativeOrder());
        directBuffer.put(pathBytes);
        directBuffer.rewind();
        Pointer pointer = this.nativeOps.numpyFromFile(new BytePointer(directBuffer));
        try {
            return this.createFromNpyPointer(pointer);
        } finally {
            this.nativeOps.releaseNumpy(pointer);
        }
    }

    /** Shared implementation of the two NPY readers once the native accessors have been queried. */
    private INDArray createFromNpyParts(Pointer dataPointer, int dataBufferElementSize, Pointer shapeBufferPointer) {
        DataBuffer shapeBuffer = this.copyShapeBuffer(shapeBufferPointer);
        long length = Shape.length((DataBuffer) shapeBuffer);
        DataBuffer data = copyDataBuffer(dataPointer, dataBufferElementSize, length);
        return Nd4j.create(data, (long[]) Shape.shape((DataBuffer) shapeBuffer), (long[]) Shape.strideArr((DataBuffer) shapeBuffer), 0L, Shape.order((DataBuffer) shapeBuffer));
    }

    /** Copies the native shape-info buffer (64-bit entries) into a new LONG DataBuffer. */
    private DataBuffer copyShapeBuffer(Pointer shapeBufferPointer) {
        int length = this.nativeOps.lengthForShapeBufferPointer(shapeBufferPointer);
        long byteLength = 8L * length;
        shapeBufferPointer.capacity(byteLength);
        shapeBufferPointer.limit(byteLength);
        shapeBufferPointer.position(0L);
        LongPointer source = new LongPointer(shapeBufferPointer);
        LongPointer copy = new LongPointer((long) length);
        copyHostToHost(copy, source, byteLength);
        return Nd4j.createBuffer((Pointer) copy, DataBuffer.Type.LONG, (long) length, (Indexer) LongRawIndexer.create(copy));
    }

    /**
     * Copies {@code length} elements of {@code elementSize} bytes from the native data pointer
     * into a new FLOAT (4-byte) or DOUBLE (8-byte) DataBuffer.
     *
     * <p>NOTE(review): only the element width is known here, so 4-byte integer NPY data would be
     * reinterpreted as float. Confirm whether callers can ever pass integer arrays.
     */
    private static DataBuffer copyDataBuffer(Pointer dataPointer, int elementSize, long length) {
        long byteLength = (long) elementSize * length;
        dataPointer.position(0L);
        dataPointer.limit(byteLength);
        dataPointer.capacity(byteLength);
        switch (elementSize) {
            case 4: {
                FloatPointer floats = new FloatPointer(length);
                copyHostToHost(floats, dataPointer, byteLength);
                return Nd4j.createBuffer((Pointer) floats, DataBuffer.Type.FLOAT, length, (Indexer) FloatIndexer.create(floats));
            }
            case 8: {
                DoublePointer doubles = new DoublePointer(length);
                copyHostToHost(doubles, dataPointer, byteLength);
                return Nd4j.createBuffer((Pointer) doubles, DataBuffer.Type.DOUBLE, length, (Indexer) DoubleIndexer.create(doubles));
            }
            default:
                // Previously a null buffer reached Nd4j.create and failed with an obscure error.
                throw new UnsupportedOperationException("Unsupported NPY element size: " + elementSize
                        + " bytes (only 4-byte float and 8-byte double are supported)");
        }
    }

    /** memcpy of {@code bytes} bytes, recorded with the PerformanceTracker as a host-to-host copy. */
    private static void copyHostToHost(Pointer dst, Pointer src, long bytes) {
        long perf = PerformanceTracker.getInstance().helperStartTransaction();
        Pointer.memcpy(dst, src, bytes);
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perf, bytes, MemcpyDirection.HOST_TO_HOST);
    }
}

