/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.serde.binary;

import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Binary (de)serialization of {@link INDArray}s to and from {@link ByteBuffer}s, streams and files.
 *
 * <p>Layout written by {@link #toByteBuffer(INDArray)}, all in native byte order:
 * <ol>
 *   <li>{@code int} rank</li>
 *   <li>{@code int} ordinal of the array's {@link DataType}</li>
 *   <li>the raw bytes of the array's shape-information buffer</li>
 *   <li>(compressed arrays only) the serialized {@link CompressionDescriptor}</li>
 *   <li>the raw bytes of the array's data buffer</li>
 * </ol>
 */
public class BinarySerde {
    private static final Logger log = LoggerFactory.getLogger(BinarySerde.class);

    /** Size of the fixed header: rank followed by the data type ordinal, both ints. */
    private static final int HEADER_BYTES = 2 * Integer.BYTES;

    /**
     * Deserializes an array starting at {@code offset} in {@code buffer}.
     *
     * @param buffer buffer holding data written by {@link #toByteBuffer(INDArray)}
     * @param offset byte index (relative to index 0 of {@code buffer}) at which the record starts
     * @return the deserialized array
     */
    public static INDArray toArray(ByteBuffer buffer, int offset) {
        return toArrayAndByteBuffer(buffer, offset).getLeft();
    }

    /**
     * Deserializes an array starting at index 0 of {@code buffer}.
     *
     * @param buffer buffer holding data written by {@link #toByteBuffer(INDArray)}
     * @return the deserialized array
     */
    public static INDArray toArray(ByteBuffer buffer) {
        return toArray(buffer, 0);
    }

    /**
     * Deserializes an array and returns it together with the buffer it was read from.
     *
     * <p>Heap buffers are first copied into a new native-order direct buffer, which is then
     * returned. Direct buffers are used in place: their byte order is switched to native order and
     * their position is moved. In both cases the returned buffer's position ends just past the
     * consumed record.
     *
     * @param buffer buffer holding data written by {@link #toByteBuffer(INDArray)}
     * @param offset byte index at which the record starts
     * @return the array and the (possibly copied) buffer positioned after the record
     * @throws IllegalStateException if the rank or data type ordinal is invalid (corrupt input)
     * @throws ND4JArraySizeException if the array has more than {@code Integer.MAX_VALUE} elements
     */
    protected static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
        ByteBuffer byteBuffer = buffer.hasArray() ? copyToDirect(buffer) : buffer.order(ByteOrder.nativeOrder());
        ((Buffer) byteBuffer).position(offset);

        int rank = byteBuffer.getInt();
        if (rank < 0) {
            throw new IllegalStateException("Found negative integer. Corrupt serialization?");
        }
        int shapeBufferLength = Shape.shapeInfoLength(rank);
        DataBuffer shapeBuff = Nd4j.createBufferDetached(new int[shapeBufferLength]);

        int typeOrdinal = byteBuffer.getInt();
        DataType[] types = DataType.values();
        if (typeOrdinal < 0 || typeOrdinal >= types.length) {
            throw new IllegalStateException("Invalid data type ordinal " + typeOrdinal + ". Corrupt serialization?");
        }
        DataType type = types[typeOrdinal];

        for (int i = 0; i < shapeBufferLength; ++i) {
            shapeBuff.put((long) i, byteBuffer.getLong());
        }

        if (type != DataType.COMPRESSED) {
            long length = Shape.length(shapeBuff);
            if (length > Integer.MAX_VALUE) {
                throw new ND4JArraySizeException();
            }
            ByteBuffer slice = byteBuffer.slice();
            DataBuffer buff = Nd4j.createBuffer(slice, type, (int) length);
            // Compute in long so a large element size * length cannot silently wrap around.
            long dataBytes = (long) buff.getElementSize() * buff.length();
            ((Buffer) byteBuffer).position(Math.toIntExact(byteBuffer.position() + dataBytes));
            INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup());
            return Pair.of(arr, byteBuffer);
        }

        CompressionDescriptor compressionDescriptor = CompressionDescriptor.fromByteBuffer(byteBuffer);
        ByteBuffer slice = byteBuffer.slice();
        BytePointer byteBufferPointer = new BytePointer(slice);
        CompressedDataBuffer compressedDataBuffer =
                new CompressedDataBuffer((Pointer) byteBufferPointer, compressionDescriptor);
        INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup());
        long compressLength = compressionDescriptor.getCompressedLength();
        ((Buffer) byteBuffer).position(Math.toIntExact(byteBuffer.position() + compressLength));
        return Pair.of(arr, byteBuffer);
    }

    /**
     * Copies the full content (index 0 up to capacity) of a heap buffer into a new native-order
     * direct buffer. Unlike copying {@code array()}, this respects {@code arrayOffset()} of sliced
     * buffers. It does not modify the source buffer's position or limit.
     */
    private static ByteBuffer copyToDirect(ByteBuffer heap) {
        ByteBuffer source = heap.duplicate();
        ((Buffer) source).clear();
        ByteBuffer direct = ByteBuffer.allocateDirect(source.remaining());
        direct.put(source);
        return direct.order(ByteOrder.nativeOrder());
    }

    /**
     * Serializes {@code arr} into a new native-order direct buffer, rewound to position 0.
     * Views are duplicated first so that only the viewed elements are written.
     *
     * @param arr array to serialize
     * @return direct buffer holding the serialized array
     */
    public static ByteBuffer toByteBuffer(INDArray arr) {
        INDArray source = arr.isView() ? arr.dup() : arr;
        ByteBuffer allocated = ByteBuffer.allocateDirect(byteBufferSizeFor(source)).order(ByteOrder.nativeOrder());
        if (source.isCompressed()) {
            doByteBufferPutCompressed(source, allocated, true);
        } else {
            doByteBufferPutUnCompressed(source, allocated, true);
        }
        return allocated;
    }

    /**
     * Returns the number of bytes {@link #toByteBuffer(INDArray)} needs for {@code arr}. This is
     * the header, plus the shape-info buffer, plus the data buffer, plus (for compressed arrays)
     * the compression descriptor.
     *
     * @param arr array to size
     * @return serialized size in bytes
     */
    public static int byteBufferSizeFor(INDArray arr) {
        int dataBytes = arr.data().pointer().asByteBuffer().limit();
        int shapeBytes = arr.shapeInfoDataBuffer().pointer().asByteBuffer().limit();
        int size = HEADER_BYTES + dataBytes + shapeBytes;
        if (arr.isCompressed()) {
            CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
            size += compressedDataBuffer.getCompressionDescriptor().toByteBuffer().limit();
        }
        return size;
    }

    /**
     * Writes an uncompressed array into {@code allocated} at its current position.
     *
     * <p>Pending ops are committed and the array is made available on the host before its
     * native memory is read.
     *
     * @param arr uncompressed array to write
     * @param allocated destination with at least {@link #byteBufferSizeFor(INDArray)} bytes remaining
     * @param rewind whether to rewind {@code allocated} to position 0 afterwards
     */
    public static void doByteBufferPutUnCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
        Nd4j.getExecutioner().commit();
        Nd4j.getAffinityManager().ensureLocation(arr, AffinityManager.Location.HOST);
        ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
        putHeaderAndShape(arr, allocated);
        allocated.put(buffer);
        if (rewind) {
            ((Buffer) allocated).rewind();
        }
    }

    /**
     * Writes a compressed array into {@code allocated} at its current position. The compression
     * descriptor is written between the shape info and the compressed data.
     *
     * @param arr array whose data buffer is a {@link CompressedDataBuffer}
     * @param allocated destination with at least {@link #byteBufferSizeFor(INDArray)} bytes remaining
     * @param rewind whether to rewind {@code allocated} to position 0 afterwards
     */
    public static void doByteBufferPutCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
        CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
        CompressionDescriptor descriptor = compressedDataBuffer.getCompressionDescriptor();
        ByteBuffer codecByteBuffer = descriptor.toByteBuffer();
        ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
        putHeaderAndShape(arr, allocated);
        allocated.put(codecByteBuffer);
        allocated.put(buffer);
        if (rewind) {
            ((Buffer) allocated).rewind();
        }
    }

    /** Writes the rank, the data type ordinal and the raw shape-info bytes of {@code arr}. */
    private static void putHeaderAndShape(INDArray arr, ByteBuffer allocated) {
        ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
        allocated.putInt(arr.rank());
        allocated.putInt(arr.data().dataType().ordinal());
        allocated.put(shapeBuffer);
    }

    /**
     * Serializes {@code arr} to {@code outputStream}.
     *
     * <p>Note: the stream is wrapped in a channel that is closed afterwards, and closing that
     * channel also closes {@code outputStream}.
     *
     * @param arr array to write
     * @param outputStream destination stream (closed on return)
     * @throws UncheckedIOException if writing fails
     */
    public static void writeArrayToOutputStream(INDArray arr, OutputStream outputStream) {
        ByteBuffer buffer = toByteBuffer(arr);
        try (WritableByteChannel channel = Channels.newChannel(outputStream)) {
            writeFully(channel, buffer);
        } catch (IOException e) {
            // Previously swallowed via printStackTrace, which left callers with a silently truncated stream.
            throw new UncheckedIOException("Failed to write array to output stream", e);
        }
    }

    /**
     * Serializes {@code arr} to the file {@code toWrite}, replacing any existing content.
     *
     * @param arr array to write
     * @param toWrite destination file
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
        try (FileOutputStream os = new FileOutputStream(toWrite)) {
            FileChannel channel = os.getChannel();
            ByteBuffer buffer = toByteBuffer(arr);
            writeFully(channel, buffer);
        }
    }

    /**
     * Reads a whole file written by {@link #writeArrayToDisk(INDArray, File)} and deserializes it.
     *
     * @param readFrom source file
     * @return the deserialized array
     * @throws IOException if the file cannot be read, is truncated, or is larger than 2 GB
     */
    public static INDArray readFromDisk(File readFrom) throws IOException {
        try (FileInputStream is = new FileInputStream(readFrom)) {
            FileChannel channel = is.getChannel();
            long fileLength = readFrom.length();
            if (fileLength > Integer.MAX_VALUE) {
                throw new IOException("File " + readFrom + " is too large (" + fileLength
                        + " bytes) to read into a single buffer");
            }
            ByteBuffer buffer = ByteBuffer.allocateDirect((int) fileLength);
            // A single read() may return fewer bytes than requested; keep reading until full.
            readFully(channel, buffer);
            return toArray(buffer);
        }
    }

    /**
     * Reads only the shape information of an array serialized to {@code readFrom}, without loading
     * its data.
     *
     * <p>Element 0 of the result is the rank from the header. The remaining elements are read from
     * the serialized shape-info buffer, skipping its first slot.
     *
     * @param readFrom source file
     * @return a long buffer of length {@code Shape.shapeInfoLength(rank)}
     * @throws IOException if the file cannot be read or is truncated
     * @throws IllegalStateException if the stored rank is negative
     */
    public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
        try (FileInputStream is = new FileInputStream(readFrom)) {
            FileChannel channel = is.getChannel();

            ByteBuffer header = ByteBuffer.allocateDirect(HEADER_BYTES).order(ByteOrder.nativeOrder());
            readFully(channel, header);
            ((Buffer) header).flip();
            int rank = header.getInt();
            if (rank < 0) {
                throw new IllegalStateException("Found negative integer. Corrupt serialization?");
            }

            int shapeInfoLength = Shape.shapeInfoLength(rank);
            long shapeBytes = (long) shapeInfoLength * Long.BYTES;
            // Reject corrupt ranks before allocating: the shape info must fit in the file.
            if (shapeInfoLength <= 0 || HEADER_BYTES + shapeBytes > readFrom.length()) {
                throw new EOFException("File " + readFrom + " is too short for a shape of rank " + rank);
            }
            ByteBuffer shape = ByteBuffer.allocateDirect((int) shapeBytes).order(ByteOrder.nativeOrder());
            readFully(channel, shape);
            ((Buffer) shape).flip();

            long[] result = new long[shapeInfoLength];
            result[0] = rank;
            // Skip the serialized slot 0; it is replaced by the header rank above.
            ((Buffer) shape).position(Long.BYTES);
            for (int e = 1; e < shapeInfoLength; ++e) {
                result[e] = shape.getLong();
            }
            return Nd4j.getDataBufferFactory().createLong(result);
        }
    }

    /** Reads from {@code channel} until {@code dst} is full; throws {@link EOFException} on premature end of file. */
    private static void readFully(FileChannel channel, ByteBuffer dst) throws IOException {
        while (dst.hasRemaining()) {
            if (channel.read(dst) < 0) {
                throw new EOFException("Unexpected end of file: " + dst.remaining() + " more bytes expected");
            }
        }
    }

    /** Writes to {@code channel} until {@code src} has no bytes remaining. */
    private static void writeFully(WritableByteChannel channel, ByteBuffer src) throws IOException {
        while (src.hasRemaining()) {
            channel.write(src);
        }
    }
}

