/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rust;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.engine.rust.RsNDArrayEx;
import ai.djl.engine.rust.RsNDManager;
import ai.djl.engine.rust.RustLibrary;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.NativeResource;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.stream.IntStream;

public class RsNDArray
extends NativeResource<Long>
implements NDArray {
    // Optional name assigned through setName(String).
    private String name;
    // Lazily resolved in getDevice() from the native tensor; null until first queried.
    private Device device;
    // Either supplied at construction or lazily resolved in getDataType() from the native tensor.
    private DataType dataType;
    // Lazily resolved in getShape() from the native tensor; null until first queried.
    private Shape shape;
    // Manager this array is currently attached to (reassigned by attach/detach and friends).
    private RsNDManager manager;
    // Helper object created in the constructor with a reference back to this array.
    private RsNDArrayEx ndArrayEx;
    // Host buffer supplied at construction or by set(Buffer) (non-GPU only); presumably held so the
    // buffer stays reachable while the native tensor may still reference it. NOTE(review): confirm.
    private ByteBuffer dataRef;

    /**
     * Wraps an existing native tensor handle whose data type will be resolved lazily.
     *
     * @param manager the manager that owns the new array
     * @param handle the native tensor handle
     */
    public RsNDArray(RsNDManager manager, long handle) {
        this(manager, handle, null, null);
    }

    /**
     * Wraps an existing native tensor handle with an explicitly known data type.
     *
     * @param manager the manager that owns the new array
     * @param handle the native tensor handle
     * @param dataType the data type to report, or {@code null} to query it lazily
     */
    RsNDArray(RsNDManager manager, long handle, DataType dataType) {
        this(manager, handle, dataType, null);
    }

    /**
     * Wraps a native tensor handle, optionally keeping a reference to a host buffer.
     *
     * <p>The new array attaches itself to {@code manager} and registers with the current
     * {@link NDScope}.
     *
     * @param manager the manager that owns the new array
     * @param handle the native tensor handle
     * @param dataType the data type to report, or {@code null} to query it lazily
     * @param data a host buffer to keep referenced by this array, or {@code null}
     */
    public RsNDArray(RsNDManager manager, long handle, DataType dataType, ByteBuffer data) {
        super(handle);
        this.dataType = dataType;
        this.manager = manager;
        this.ndArrayEx = new RsNDArrayEx(this);
        this.dataRef = data;
        manager.attachInternal(getUid(), new AutoCloseable[] {this});
        NDScope.register(this);
    }

    /**
     * Returns the manager this array is currently attached to.
     *
     * @return the owning manager
     */
    public RsNDManager getManager() {
        RsNDManager owner = manager;
        return owner;
    }

    /**
     * Returns the optional user-assigned name.
     *
     * @return the name, or {@code null} if none was set
     */
    public String getName() {
        String current = name;
        return current;
    }

    /**
     * Assigns a name to this array.
     *
     * @param name the new name (may be {@code null})
     */
    public void setName(String name) {
        String assigned = name;
        this.name = assigned;
    }

    /**
     * Returns the data type, querying the native tensor on first use and caching the result.
     *
     * @return the data type of this array
     */
    public DataType getDataType() {
        DataType cached = dataType;
        if (cached == null) {
            // The native side reports an index into DataType.values().
            cached = DataType.values()[RustLibrary.getDataType(getHandle())];
            dataType = cached;
        }
        return cached;
    }

    /**
     * Returns the device of the native tensor, resolving and caching it on first use.
     *
     * <p>The native call reports a pair {@code [typeCode, deviceId]} where type code 0 is cpu, 1
     * is gpu and 2 is mps.
     *
     * @return the device of this array
     * @throws EngineException if the native side reports an unknown device type code
     */
    public Device getDevice() {
        if (device != null) {
            return device;
        }
        int[] info = RustLibrary.getDevice(getHandle());
        int typeCode = info[0];
        String deviceType;
        if (typeCode == 0) {
            deviceType = "cpu";
        } else if (typeCode == 1) {
            deviceType = "gpu";
        } else if (typeCode == 2) {
            deviceType = "mps";
        } else {
            throw new EngineException("Unknown device type: " + info[0]);
        }
        device = Device.of(deviceType, info[1]);
        return device;
    }

    /**
     * Returns the shape of the native tensor, resolving and caching it on first use.
     *
     * @return the shape of this array
     */
    public Shape getShape() {
        if (shape != null) {
            return shape;
        }
        shape = new Shape(RustLibrary.getShape(getHandle()));
        return shape;
    }

    /**
     * Returns the storage format; arrays of this engine are always dense.
     *
     * @return {@link SparseFormat#DENSE}
     */
    public SparseFormat getSparseFormat() {
        final SparseFormat format = SparseFormat.DENSE;
        return format;
    }

    /**
     * Moves (or copies) this array to {@code device}.
     *
     * @param device the target device
     * @param copy whether to force a new array even when already on {@code device}
     * @return {@code this} when already on the device and no copy was requested, otherwise a
     *     new array
     */
    public RsNDArray toDevice(Device device, boolean copy) {
        boolean alreadyThere = device.equals(getDevice());
        if (alreadyThere && !copy) {
            return this;
        }
        String targetType = device.getDeviceType();
        long moved = RustLibrary.toDevice(getHandle(), targetType, device.getDeviceId());
        return toArray(moved, null, false, true);
    }

    /**
     * Converts this array to {@code dataType}.
     *
     * <p>Conversion to {@link DataType#BOOLEAN} uses a dedicated native call. Every other target
     * type is mapped through {@link RsNDManager#toRustDataType(DataType)}.
     *
     * @param dataType the target data type
     * @param copy whether to force a new array even when the type already matches
     * @return {@code this} when the type already matches and no copy was requested, otherwise
     *     a new array of the requested type
     * @throws UnsupportedOperationException when converting an INT64 array to FLOAT16 on a GPU
     */
    public RsNDArray toType(DataType dataType, boolean copy) {
        DataType current = this.getDataType();
        if (dataType.equals(current) && !copy) {
            return this;
        }
        if (dataType == DataType.BOOLEAN) {
            long newHandle = RustLibrary.toBoolean(this.getHandle());
            return this.toArray(newHandle, dataType, false, true);
        }
        // The guarded conversion is INT64 -> FLOAT16; the message now states that direction.
        if (current == DataType.INT64 && dataType == DataType.FLOAT16 && this.getDevice().isGpu()) {
            throw new UnsupportedOperationException("I64 to FP16 is not supported on GPU.");
        }
        int dType = this.manager.toRustDataType(dataType);
        long newHandle = RustLibrary.toDataType(this.getHandle(), dType);
        return this.toArray(newHandle, dataType, false, true);
    }

    /**
     * Gradient tracking is not supported by this engine.
     *
     * @param requiresGrad ignored
     * @throws UnsupportedOperationException always
     */
    public void setRequiresGradient(boolean requiresGrad) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Gradients are not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray getGradient() {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Reports whether a gradient is attached; always {@code false} for this engine.
     *
     * @return {@code false}
     */
    public boolean hasGradient() {
        final boolean tracked = false;
        return tracked;
    }

    /**
     * Gradient stopping is not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray stopGradient() {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Copies the tensor contents into a heap byte buffer in native byte order.
     *
     * @param tryDirect ignored; the result always wraps a heap array
     * @return a buffer over a copy of the tensor bytes
     */
    public ByteBuffer toByteBuffer(boolean tryDirect) {
        byte[] bytes = RustLibrary.toByteArray(getHandle());
        return ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder());
    }

    /**
     * String arrays are not supported by this engine.
     *
     * @param charset ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public String[] toStringArray(Charset charset) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Overwrites the contents of this array with {@code buffer}.
     *
     * <p>A direct {@link ByteBuffer} is used as-is. Any other buffer is first copied into a
     * freshly allocated direct buffer. A new tensor is created from that data on this array's
     * device and swapped in via {@link #intern(NDArray)}. For non-GPU arrays the host buffer is
     * kept in {@code dataRef}.
     *
     * @param buffer the new contents; validated against this array's size and data type
     */
    public void set(Buffer buffer) {
        int size = Math.toIntExact(this.size());
        DataType type = this.getDataType();
        BaseNDManager.validateBuffer(buffer, type, size);
        this.dataRef = null;
        if (buffer.isDirect() && buffer instanceof ByteBuffer) {
            if (!this.getDevice().isGpu()) {
                this.dataRef = (ByteBuffer) buffer;
            }
            this.internBufferCopy(buffer, type);
            return;
        }
        ByteBuffer buf = this.manager.allocateDirect(size * type.getNumOfBytes());
        BaseNDManager.copyBuffer(buffer, buf);
        if (!this.getDevice().isGpu()) {
            this.dataRef = buf;
        }
        this.internBufferCopy(buf, type);
    }

    /**
     * Creates a tensor from {@code data}, moves it to this array's device and interns it.
     * Closes the staging array when the move produced a separate copy, so it no longer lingers
     * on the manager.
     */
    private void internBufferCopy(Buffer data, DataType type) {
        NDArray staged = this.manager.create(data, this.getShape(), type);
        NDArray placed = staged.toDevice(this.getDevice(), false);
        if (placed != staged) {
            staged.close();
        }
        this.intern(placed);
    }

    /**
     * Not supported by this engine.
     *
     * @param index ignored
     * @param axis ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray gather(NDArray index, int axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this engine.
     *
     * @param index ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray gatherNd(NDArray index) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Gathers elements at {@code index} into a new array owned by {@code manager}.
     *
     * @param manager the manager for the result; must be an {@link RsNDManager}
     * @param index the indices to take
     * @return the new array, unregistered from the temporary scope used here
     */
    public NDArray take(NDManager manager, NDArray index) {
        try (NDScope ignore = new NDScope()) {
            // Note: the parameter shadows the field, so conversion goes through this.manager.
            long indexHandle = this.manager.from(index).getHandle();
            long taken = RustLibrary.take(getHandle(), indexHandle);
            RsNDArray result = new RsNDArray((RsNDManager) manager, taken);
            NDScope.unregister(result);
            return result;
        }
    }

    /**
     * Returns a copy of this array with {@code value} written at {@code index}.
     *
     * @param index the positions to write
     * @param value the values to write
     * @return a new array
     */
    public NDArray put(NDArray index, NDArray value) {
        try (NDScope ignore = new NDScope()) {
            long idx = manager.from(index).getHandle();
            long val = manager.from(value).getHandle();
            long updated = RustLibrary.put(getHandle(), idx, val);
            return toArray(updated, true);
        }
    }

    /**
     * Not supported by this engine.
     *
     * @param index ignored
     * @param value ignored
     * @param axis ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray scatter(NDArray index, NDArray value, int axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Moves ownership of this array to {@code manager}.
     *
     * @param manager the new owner; must be an {@link RsNDManager}
     */
    public void attach(NDManager manager) {
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        manager.attachInternal(getUid(), new AutoCloseable[] {this});
    }

    /**
     * Returns this array to {@code manager} through its uncapped attachment path.
     *
     * @param manager the new owner; must be an {@link RsNDManager}
     */
    public void returnResource(NDManager manager) {
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        manager.attachUncappedInternal(getUid(), this);
    }

    /**
     * Temporarily attaches this array to {@code manager}, recording the previous owner.
     *
     * @param manager the temporary owner; must be an {@link RsNDManager}
     */
    public void tempAttach(NDManager manager) {
        RsNDManager previous = this.manager;
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        manager.tempAttachInternal(previous, getUid(), this);
    }

    /** Detaches from the current manager and falls back to the system manager. */
    public void detach() {
        RsNDManager current = this.manager;
        current.detachInternal(getUid());
        this.manager = RsNDManager.getSystemManager();
    }

    /**
     * Creates a native copy of this array, carrying over the currently cached data type.
     *
     * @return the copy
     */
    public NDArray duplicate() {
        long copied = RustLibrary.duplicate(getHandle());
        // Uses the raw field (possibly null) rather than getDataType(), matching prior behavior.
        DataType knownType = this.dataType;
        return toArray(copied, knownType, false, true);
    }

    /**
     * Not supported by this engine.
     *
     * @param index ignored
     * @param axis ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray booleanMask(NDArray index, int axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not yet supported by this engine.
     *
     * @param sequenceLength ignored
     * @param value ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray sequenceMask(NDArray sequenceLength, float value) {
        String message = "Not implemented yet";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not yet supported by this engine.
     *
     * @param sequenceLength ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray sequenceMask(NDArray sequenceLength) {
        String message = "Not implemented yet";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Compares this array with a scalar.
     *
     * <p>The temporary scalar array is closed afterwards. Previously it was never closed; every
     * other Number overload in this class already closes it.
     *
     * @param number the scalar to compare against
     * @return whether shape, data type and contents all match
     */
    public boolean contentEquals(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return this.contentEquals(scalar);
        }
    }

    /**
     * Compares shape, data type and element contents with {@code other}.
     *
     * @param other the array to compare against; {@code null} yields {@code false}
     * @return whether shape, data type and contents all match
     */
    public boolean contentEquals(NDArray other) {
        if (other == null || !this.shapeEquals(other)) {
            return false;
        }
        if (this.getDataType() != other.getDataType()) {
            return false;
        }
        // Scoped like the other binary ops, so a converted copy produced by manager.from(other)
        // can be released by the scope (NOTE(review): relies on NDScope closing what it registers).
        try (NDScope ignore = new NDScope()) {
            return RustLibrary.contentEqual(this.getHandle(), this.manager.from(other).getHandle());
        }
    }

    /**
     * Element-wise equality against a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray eq(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return eq(scalar);
        }
    }

    /**
     * Element-wise equality against {@code other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray eq(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.eq(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise inequality against a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray neq(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return neq(scalar);
        }
    }

    /**
     * Element-wise inequality against {@code other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray neq(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.neq(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise {@code this > n}; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray gt(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return gt(scalar);
        }
    }

    /**
     * Element-wise {@code this > other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray gt(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.gt(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise {@code this >= n}; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray gte(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return gte(scalar);
        }
    }

    /**
     * Element-wise {@code this >= other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray gte(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.gte(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise {@code this < n}; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray lt(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return lt(scalar);
        }
    }

    /**
     * Element-wise {@code this < other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray lt(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.lt(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Element-wise {@code this <= n}; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray lte(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return lte(scalar);
        }
    }

    /**
     * Element-wise {@code this <= other}.
     *
     * @param other the right-hand operand
     * @return a BOOLEAN-typed result array
     */
    public RsNDArray lte(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.lte(lhs, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /**
     * Adds a scalar element-wise; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray add(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return add(scalar);
        }
    }

    /**
     * Adds {@code other} element-wise.
     *
     * @param other the right-hand operand
     * @return a new array
     */
    public RsNDArray add(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.add(lhs, rhs), true);
        }
    }

    /**
     * Subtracts a scalar element-wise; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray sub(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return sub(scalar);
        }
    }

    /**
     * Subtracts {@code other} element-wise.
     *
     * @param other the right-hand operand
     * @return a new array
     */
    public RsNDArray sub(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.sub(lhs, rhs), true);
        }
    }

    /**
     * Multiplies by a scalar element-wise; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray mul(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return mul(scalar);
        }
    }

    /**
     * Multiplies by {@code other} element-wise.
     *
     * @param other the right-hand operand
     * @return a new array
     */
    public RsNDArray mul(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.mul(lhs, rhs), true);
        }
    }

    /**
     * Divides by a scalar element-wise; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray div(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return div(scalar);
        }
    }

    /**
     * Divides by {@code other} element-wise.
     *
     * @param other the right-hand operand
     * @return a new array
     */
    public RsNDArray div(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.div(lhs, rhs), true);
        }
    }

    /**
     * Element-wise remainder by a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar divisor
     * @return a new array
     */
    public RsNDArray mod(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return mod(scalar);
        }
    }

    /**
     * Element-wise remainder by {@code other} (native {@code remainder}).
     *
     * @param other the divisor
     * @return a new array
     */
    public RsNDArray mod(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long divisor = manager.from(other).getHandle();
            long result = RustLibrary.remainder(getHandle(), divisor);
            return toArray(result, true);
        }
    }

    /**
     * Raises to a scalar power element-wise; the temporary scalar array is closed afterwards.
     *
     * @param n the exponent
     * @return a new array
     */
    public RsNDArray pow(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return pow(scalar);
        }
    }

    /**
     * Raises to the power {@code other} element-wise.
     *
     * @param other the exponent
     * @return a new array
     */
    public RsNDArray pow(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long base = getHandle();
            long exponent = manager.from(other).getHandle();
            return toArray(RustLibrary.pow(base, exponent), true);
        }
    }

    /**
     * Computes the native {@code xlogy} of this array and {@code other}.
     *
     * @param other the second operand
     * @return a new array
     * @throws IllegalArgumentException if either operand is a scalar
     */
    public NDArray xlogy(NDArray other) {
        boolean anyScalar = isScalar() || other.isScalar();
        if (anyScalar) {
            throw new IllegalArgumentException("scalar is not allowed for xlogy()");
        }
        try (NDScope ignore = new NDScope()) {
            long x = getHandle();
            long y = manager.from(other).getHandle();
            return toArray(RustLibrary.xlogy(x, y), true);
        }
    }

    /**
     * In-place add of a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return this array
     */
    public RsNDArray addi(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return addi(scalar);
        }
    }

    /**
     * In-place add: computes {@code add(other)} and swaps its tensor into this array.
     *
     * @param other the right-hand operand
     * @return this array
     */
    public RsNDArray addi(NDArray other) {
        RsNDArray outcome = add(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place subtract of a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return this array
     */
    public RsNDArray subi(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return subi(scalar);
        }
    }

    /**
     * In-place subtract: computes {@code sub(other)} and swaps its tensor into this array.
     *
     * @param other the right-hand operand
     * @return this array
     */
    public RsNDArray subi(NDArray other) {
        RsNDArray outcome = sub(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place multiply by a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return this array
     */
    public RsNDArray muli(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return muli(scalar);
        }
    }

    /**
     * In-place multiply: computes {@code mul(other)} and swaps its tensor into this array.
     *
     * @param other the right-hand operand
     * @return this array
     */
    public RsNDArray muli(NDArray other) {
        RsNDArray outcome = mul(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place divide by a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return this array
     */
    public RsNDArray divi(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return divi(scalar);
        }
    }

    /**
     * In-place divide: computes {@code div(other)} and swaps its tensor into this array.
     *
     * @param other the right-hand operand
     * @return this array
     */
    public RsNDArray divi(NDArray other) {
        RsNDArray outcome = div(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place remainder by a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar divisor
     * @return this array
     */
    public RsNDArray modi(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return modi(scalar);
        }
    }

    /**
     * In-place remainder: computes {@code mod(other)} and swaps its tensor into this array.
     *
     * @param other the divisor
     * @return this array
     */
    public RsNDArray modi(NDArray other) {
        RsNDArray outcome = mod(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place power by a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the exponent
     * @return this array
     */
    public RsNDArray powi(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return powi(scalar);
        }
    }

    /**
     * In-place power: computes {@code pow(other)} and swaps its tensor into this array.
     *
     * @param other the exponent
     * @return this array
     */
    public RsNDArray powi(NDArray other) {
        RsNDArray outcome = pow(other);
        intern(outcome);
        return this;
    }

    /**
     * In-place sign: computes {@code sign()} and swaps its tensor into this array.
     *
     * @return this array
     */
    public RsNDArray signi() {
        RsNDArray outcome = sign();
        intern(outcome);
        return this;
    }

    /**
     * In-place negation: computes {@code neg()} and swaps its tensor into this array.
     *
     * @return this array
     */
    public RsNDArray negi() {
        RsNDArray outcome = neg();
        intern(outcome);
        return this;
    }

    /**
     * Computes the element-wise sign via the native {@code sign} call.
     *
     * @return a new array
     */
    public RsNDArray sign() {
        long signs = RustLibrary.sign(getHandle());
        return toArray(signs);
    }

    /**
     * Element-wise maximum with a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray maximum(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return maximum(scalar);
        }
    }

    /**
     * Element-wise maximum with {@code other}.
     *
     * @param other the other operand
     * @return a new array
     */
    public RsNDArray maximum(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.maximum(lhs, rhs), true);
        }
    }

    /**
     * Element-wise minimum with a scalar; the temporary scalar array is closed afterwards.
     *
     * @param n the scalar
     * @return a new array
     */
    public RsNDArray minimum(Number n) {
        try (NDArray scalar = manager.create(n)) {
            return minimum(scalar);
        }
    }

    /**
     * Element-wise minimum with {@code other}.
     *
     * @param other the other operand
     * @return a new array
     */
    public RsNDArray minimum(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.minimum(lhs, rhs), true);
        }
    }

    /**
     * Returns a scalar boolean array that is true when every element is non-zero.
     *
     * <p>The temporary count array is closed via try-with-resources, so it is released even if
     * reading it or creating the result throws.
     *
     * @return a scalar boolean array
     */
    public RsNDArray all() {
        try (NDArray nonZero = this.countNonzero()) {
            return (RsNDArray) this.manager.create(nonZero.getLong() == this.size());
        }
    }

    /**
     * Returns a scalar boolean array that is true when at least one element is non-zero.
     *
     * @return a scalar boolean array
     */
    public RsNDArray any() {
        try (NDArray nonZero = this.countNonzero()) {
            return (RsNDArray) this.manager.create(nonZero.getLong() > 0L);
        }
    }

    /**
     * Returns a scalar boolean array that is true when no element is non-zero.
     *
     * @return a scalar boolean array
     */
    public RsNDArray none() {
        try (NDArray nonZero = this.countNonzero()) {
            return (RsNDArray) this.manager.create(nonZero.getLong() == 0L);
        }
    }

    /**
     * Counts the non-zero elements over the whole array.
     *
     * @return a new array holding the count
     */
    public NDArray countNonzero() {
        try (NDScope ignore = new NDScope()) {
            long counted = RustLibrary.countNonzero(getHandle());
            return toArray(counted, true);
        }
    }

    /**
     * Counts the non-zero elements along {@code axis}.
     *
     * @param axis the axis to count along
     * @return a new array holding the counts
     */
    public NDArray countNonzero(int axis) {
        try (NDScope ignore = new NDScope()) {
            long counted = RustLibrary.countNonzeroWithAxis(getHandle(), axis);
            return toArray(counted, true);
        }
    }

    /**
     * Element-wise negation.
     *
     * @return a new array
     */
    public RsNDArray neg() {
        long negated = RustLibrary.neg(getHandle());
        return toArray(negated);
    }

    /**
     * Element-wise absolute value.
     *
     * @return a new array
     */
    public RsNDArray abs() {
        long magnitude = RustLibrary.abs(getHandle());
        return toArray(magnitude);
    }

    /**
     * Element-wise square.
     *
     * @return a new array
     */
    public RsNDArray square() {
        long squared = RustLibrary.square(getHandle());
        return toArray(squared);
    }

    /**
     * Element-wise square root.
     *
     * @return a new array
     */
    public NDArray sqrt() {
        long root = RustLibrary.sqrt(getHandle());
        return toArray(root);
    }

    /**
     * Element-wise cube root, defined for negative inputs as well.
     *
     * <p>Raising {@code x} directly to the power {@code 1/3} yields NaN for negative {@code x}
     * (e.g. the cube root of -8 must be -2). This method therefore computes
     * {@code sign(x) * |x|^(1/3)}. All intermediates are closed before returning.
     *
     * @return a new array holding the cube roots
     */
    public RsNDArray cbrt() {
        try (RsNDArray magnitude = this.abs();
                RsNDArray root = magnitude.pow(1.0 / 3.0);
                RsNDArray signs = this.sign()) {
            return signs.mul(root);
        }
    }

    /**
     * Element-wise floor.
     *
     * @return a new array
     */
    public RsNDArray floor() {
        long rounded = RustLibrary.floor(getHandle());
        return toArray(rounded);
    }

    /**
     * Element-wise ceiling.
     *
     * @return a new array
     */
    public RsNDArray ceil() {
        long rounded = RustLibrary.ceil(getHandle());
        return toArray(rounded);
    }

    /**
     * Element-wise rounding via the native {@code round} call.
     *
     * @return a new array
     */
    public RsNDArray round() {
        long rounded = RustLibrary.round(getHandle());
        return toArray(rounded);
    }

    /**
     * Element-wise truncation via the native {@code trunc} call.
     *
     * @return a new array
     */
    public RsNDArray trunc() {
        long truncated = RustLibrary.trunc(getHandle());
        return toArray(truncated);
    }

    /**
     * Element-wise natural exponential.
     *
     * @return a new array
     */
    public RsNDArray exp() {
        long exponentiated = RustLibrary.exp(getHandle());
        return toArray(exponentiated);
    }

    /**
     * Not yet supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray gammaln() {
        String message = "Not implemented yet.";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Element-wise natural logarithm.
     *
     * @return a new array
     */
    public RsNDArray log() {
        long result = RustLibrary.log(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise base-10 logarithm.
     *
     * @return a new array
     */
    public RsNDArray log10() {
        long result = RustLibrary.log10(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise base-2 logarithm.
     *
     * @return a new array
     */
    public RsNDArray log2() {
        long result = RustLibrary.log2(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise sine.
     *
     * @return a new array
     */
    public RsNDArray sin() {
        long result = RustLibrary.sin(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise cosine.
     *
     * @return a new array
     */
    public RsNDArray cos() {
        long result = RustLibrary.cos(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise tangent.
     *
     * @return a new array
     */
    public RsNDArray tan() {
        long result = RustLibrary.tan(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise arcsine.
     *
     * @return a new array
     */
    public RsNDArray asin() {
        long result = RustLibrary.asin(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise arccosine.
     *
     * @return a new array
     */
    public RsNDArray acos() {
        long result = RustLibrary.acos(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise arctangent.
     *
     * @return a new array
     */
    public RsNDArray atan() {
        long result = RustLibrary.atan(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise two-argument arctangent with this array as the first argument.
     *
     * @param other the second argument
     * @return a new array
     */
    public RsNDArray atan2(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long second = manager.from(other).getHandle();
            long result = RustLibrary.atan2(getHandle(), second);
            return toArray(result, true);
        }
    }

    /**
     * Element-wise hyperbolic sine.
     *
     * @return a new array
     */
    public RsNDArray sinh() {
        long result = RustLibrary.sinh(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise hyperbolic cosine.
     *
     * @return a new array
     */
    public RsNDArray cosh() {
        long result = RustLibrary.cosh(getHandle());
        return toArray(result);
    }

    /**
     * Element-wise hyperbolic tangent.
     *
     * @return a new array
     */
    public RsNDArray tanh() {
        long result = RustLibrary.tanh(getHandle());
        return toArray(result);
    }

    /**
     * Not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray asinh() {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray acosh() {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray atanh() {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Converts radians to degrees element-wise as {@code (x * 180) / PI}.
     *
     * <p>The intermediate product is closed after use; previously it was left attached to the
     * manager.
     *
     * @return a new array
     */
    public RsNDArray toDegrees() {
        try (RsNDArray scaled = this.mul(180.0)) {
            return scaled.div(Math.PI);
        }
    }

    /**
     * Converts degrees to radians element-wise as {@code (x * PI) / 180}.
     *
     * <p>The intermediate product is closed after use; previously it was left attached to the
     * manager.
     *
     * @return a new array
     */
    public RsNDArray toRadians() {
        try (RsNDArray scaled = this.mul(Math.PI)) {
            return scaled.div(180.0);
        }
    }

    /**
     * Maximum over all elements; a scalar array is returned unchanged (the same instance).
     *
     * @return the maximum
     */
    public RsNDArray max() {
        return isScalar() ? this : toArray(RustLibrary.max(getHandle()));
    }

    /**
     * Maximum along a single axis.
     *
     * @param axes exactly one axis is supported
     * @param keepDims whether to keep the reduced dimension
     * @return a new array
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public RsNDArray max(int[] axes, boolean keepDims) {
        if (axes.length <= 1) {
            long reduced = RustLibrary.maxWithAxis(getHandle(), axes[0], keepDims);
            return toArray(reduced);
        }
        throw new UnsupportedOperationException("Only 1 axis is support!");
    }

    /**
     * Minimum over all elements; a scalar array is returned unchanged (the same instance).
     *
     * @return the minimum
     */
    public RsNDArray min() {
        return isScalar() ? this : toArray(RustLibrary.min(getHandle()));
    }

    /**
     * Minimum along a single axis.
     *
     * @param axes exactly one axis is supported
     * @param keepDims whether to keep the reduced dimension
     * @return a new array
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public RsNDArray min(int[] axes, boolean keepDims) {
        if (axes.length <= 1) {
            long reduced = RustLibrary.minWithAxis(getHandle(), axes[0], keepDims);
            return toArray(reduced);
        }
        throw new UnsupportedOperationException("Only 1 axis is support!");
    }

    /**
     * Sum over all elements; a scalar array is returned unchanged (the same instance).
     *
     * @return the sum
     */
    public RsNDArray sum() {
        return isScalar() ? this : toArray(RustLibrary.sum(getHandle()));
    }

    /**
     * Sum along the given axes.
     *
     * @param axes the axes to reduce
     * @param keepDims whether to keep the reduced dimensions
     * @return a new array
     */
    public RsNDArray sum(int[] axes, boolean keepDims) {
        long reduced = RustLibrary.sumWithAxis(getHandle(), axes, keepDims);
        return toArray(reduced);
    }

    /**
     * Cumulative product along {@code axis}.
     *
     * @param axis the axis
     * @return a new array
     */
    public NDArray cumProd(int axis) {
        long result = RustLibrary.cumProd(getHandle(), axis);
        return toArray(result);
    }

    /**
     * Cumulative product along {@code axis}, passing the requested type's ordinal to native code.
     *
     * @param axis the axis
     * @param dataType the requested output type
     * @return a new array
     */
    public NDArray cumProd(int axis, DataType dataType) {
        int typeOrdinal = dataType.ordinal();
        long result = RustLibrary.cumProdWithType(getHandle(), axis, typeOrdinal);
        return toArray(result);
    }

    /**
     * Product over all elements.
     *
     * @return a new array
     */
    public RsNDArray prod() {
        long result = RustLibrary.prod(getHandle());
        return toArray(result);
    }

    /**
     * Product along a single axis.
     *
     * <p>NOTE(review): this delegates to {@code RustLibrary.cumProdWithAxis}; confirm that
     * native entry point performs a (non-cumulative) product reduction.
     *
     * @param axes exactly one axis is supported
     * @param keepDims whether to keep the reduced dimension
     * @return a new array
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public RsNDArray prod(int[] axes, boolean keepDims) {
        if (axes.length <= 1) {
            long result = RustLibrary.cumProdWithAxis(getHandle(), axes[0], keepDims);
            return toArray(result);
        }
        throw new UnsupportedOperationException("Only 1 axis is support!");
    }

    /**
     * Mean over all elements.
     *
     * @return a new array
     */
    public RsNDArray mean() {
        long result = RustLibrary.mean(getHandle());
        return toArray(result);
    }

    /**
     * Mean along the given axes.
     *
     * @param axes the axes to reduce
     * @param keepDims whether to keep the reduced dimensions
     * @return a new array
     */
    public RsNDArray mean(int[] axes, boolean keepDims) {
        long result = RustLibrary.meanWithAxis(getHandle(), axes, keepDims);
        return toArray(result);
    }

    /**
     * Normalizes via the native {@code normalize} call.
     *
     * @param p the norm order passed to native code
     * @param dim the dimension passed to native code
     * @param eps the epsilon passed to native code
     * @return a new array
     */
    public RsNDArray normalize(double p, long dim, double eps) {
        long result = RustLibrary.normalize(getHandle(), p, dim, eps);
        return toArray(result);
    }

    /**
     * Rotates by 90 degrees {@code times} times in the plane given by {@code axes}.
     *
     * @param times number of rotations
     * @param axes exactly two axes
     * @return a new array
     * @throws IllegalArgumentException if {@code axes} does not have length 2
     */
    public RsNDArray rotate90(int times, int[] axes) {
        if (axes.length == 2) {
            return toArray(RustLibrary.rot90(getHandle(), times, axes));
        }
        throw new IllegalArgumentException("Axes must be 2");
    }

    /**
     * Not supported by this engine.
     *
     * @param offset ignored
     * @param axis1 ignored
     * @param axis2 ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray trace(int offset, int axis1, int axis2) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Splits this array along {@code axis} at the given indices.
     *
     * <p>If the last index is not the axis length, the axis length is appended as the final
     * boundary before calling native code. The caller's array is left untouched.
     *
     * @param indices split points along {@code axis}
     * @param axis the axis to split
     * @return the pieces; just this array when {@code indices} is empty
     */
    public NDList split(long[] indices, int axis) {
        int count = indices.length;
        if (count == 0) {
            return new NDList(new NDArray[] {this});
        }
        long axisLength = getShape().get(axis);
        long[] bounds = indices;
        if (indices[count - 1] != axisLength) {
            bounds = Arrays.copyOf(indices, count + 1);
            bounds[count] = axisLength;
        }
        return toList(RustLibrary.split(getHandle(), bounds, axis));
    }

    /**
     * Flattens all dimensions into one.
     *
     * @return a new array
     */
    public RsNDArray flatten() {
        long flat = RustLibrary.flatten(getHandle());
        return toArray(flat);
    }

    /**
     * Flattens dimensions {@code startDim} through {@code endDim}.
     *
     * @param startDim first dimension to merge
     * @param endDim last dimension to merge
     * @return a new array
     */
    public NDArray flatten(int startDim, int endDim) {
        long flat = RustLibrary.flattenWithDims(getHandle(), startDim, endDim);
        return toArray(flat);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray fft(long length, long axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray rfft(long length, long axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray ifft(long length, long axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray irfft(long length, long axis) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray stft(long nFft, long hopLength, boolean center, NDArray window, boolean normalize, boolean returnComplex) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray fft2(long[] sizes, long[] axes) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray pad(Shape padding, double value) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported by this engine; always throws {@link UnsupportedOperationException}. */
    public NDArray ifft2(long[] sizes, long[] axes) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Reshapes this array, resolving at most one negative ({@code -1}) dimension.
     *
     * <p>Dimensions are resolved on a private copy, so the caller's {@link Shape} is never
     * mutated. Previously the {@code -1} was written into the array returned by
     * {@code shape.getShape()}.
     *
     * @param shape the target shape; at most one dimension may be negative
     * @return a new array with the resolved shape
     * @throws IllegalArgumentException if more than one dimension is negative, or the negative
     *     dimension cannot be inferred (the known dimensions do not evenly divide the element
     *     count, or their product is zero)
     */
    public RsNDArray reshape(Shape shape) {
        long[] dims = shape.getShape().clone();
        long prod = 1L;
        int neg = -1;
        for (int i = 0; i < dims.length; ++i) {
            if (dims[i] < 0L) {
                if (neg != -1) {
                    throw new IllegalArgumentException("only 1 negative axis is allowed");
                }
                neg = i;
                continue;
            }
            prod *= dims[i];
        }
        if (neg != -1) {
            long total = this.getShape().size();
            // prod == 0 would make the inferred size ambiguous and crash on "% 0".
            if (prod == 0L || total % prod != 0L) {
                throw new IllegalArgumentException("unsupported dimensions");
            }
            dims[neg] = total / prod;
        }
        return this.toArray(RustLibrary.reshape(this.getHandle(), dims));
    }

    /**
     * Inserts a new dimension of size one at {@code axis}.
     *
     * @param axis the insertion position
     * @return a new array
     */
    public RsNDArray expandDims(int axis) {
        long expanded = RustLibrary.expandDims(getHandle(), axis);
        return toArray(expanded);
    }

    /**
     * Removes the given dimensions via the native {@code squeeze} call.
     *
     * @param axes the axes to remove
     * @return a new array
     */
    public RsNDArray squeeze(int[] axes) {
        long squeezed = RustLibrary.squeeze(getHandle(), axes);
        return toArray(squeezed);
    }

    /**
     * Not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDList unique(Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts) {
        String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Element-wise logical AND.
     *
     * @param other the other operand
     * @return a new array
     */
    public RsNDArray logicalAnd(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long result = RustLibrary.logicalAnd(getHandle(), rhs);
            return toArray(result, true);
        }
    }

    /**
     * Element-wise logical OR.
     *
     * @param other the other operand
     * @return a new array
     */
    public RsNDArray logicalOr(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long result = RustLibrary.logicalOr(getHandle(), rhs);
            return toArray(result, true);
        }
    }

    /**
     * Element-wise logical XOR.
     *
     * @param other the other operand
     * @return a new array
     */
    public RsNDArray logicalXor(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long result = RustLibrary.logicalXor(getHandle(), rhs);
            return toArray(result, true);
        }
    }

    /**
     * Element-wise logical NOT.
     *
     * @return a new array
     */
    public RsNDArray logicalNot() {
        long result = RustLibrary.logicalNot(getHandle());
        return toArray(result);
    }

    /**
     * Returns the indices that would sort this array along {@code axis}.
     *
     * @param axis the axis to sort
     * @param ascending sort direction
     * @return a new array of indices
     */
    public RsNDArray argSort(int axis, boolean ascending) {
        long order = RustLibrary.argSort(getHandle(), axis, ascending);
        return toArray(order);
    }

    /**
     * Sorts along the last axis.
     *
     * @return a new array
     */
    public RsNDArray sort() {
        final int lastAxis = -1;
        return sort(lastAxis);
    }

    /**
     * Sorts along {@code axis} (the native call receives {@code false} as its third argument).
     *
     * @param axis the axis to sort
     * @return a new array
     */
    public RsNDArray sort(int axis) {
        long sorted = RustLibrary.sort(getHandle(), axis, false);
        return toArray(sorted);
    }

    /**
     * Softmax along {@code axis}; scalar or empty arrays are duplicated instead.
     *
     * @param axis the axis
     * @return a new array
     */
    public RsNDArray softmax(int axis) {
        Shape current = getShape();
        boolean trivial = current.isScalar() || current.size() == 0L;
        if (!trivial) {
            return toArray(RustLibrary.softmax(getHandle(), axis));
        }
        return (RsNDArray) duplicate();
    }

    /**
     * Log-softmax along {@code axis}.
     *
     * @param axis the axis
     * @return a new array
     */
    public RsNDArray logSoftmax(int axis) {
        long result = RustLibrary.logSoftmax(getHandle(), axis);
        return toArray(result);
    }

    /**
     * Cumulative sum of the flattened view.
     *
     * <p>A scalar is reshaped to {@code [1]} and an empty array to {@code [0]}. Otherwise this
     * delegates to {@code cumSum(0)}.
     *
     * @return a new array
     */
    public RsNDArray cumSum() {
        if (isScalar()) {
            long[] single = {1L};
            return (RsNDArray) reshape(single);
        }
        if (isEmpty()) {
            long[] none = {0L};
            return (RsNDArray) reshape(none);
        }
        return cumSum(0);
    }

    /**
     * Cumulative sum along {@code axis}; limited to arrays with at most three dimensions.
     *
     * @param axis the axis
     * @return a new array
     * @throws UnsupportedOperationException for arrays with more than three dimensions
     */
    public RsNDArray cumSum(int axis) {
        int rank = getShape().dimension();
        if (rank <= 3) {
            return toArray(RustLibrary.cumSum(getHandle(), axis));
        }
        throw new UnsupportedOperationException("Only 3 dimensions or less is supported");
    }

    /**
     * Takes over the native tensor of {@code replaced}, deleting this array's previous tensor.
     *
     * <p>{@code replaced} gives up its handle and is then closed. Shape and device are
     * lazily-cached views of the tensor, so they are taken from {@code replaced} (or reset, to be
     * re-queried). Previously they kept describing the old tensor, which is wrong after, e.g., a
     * broadcasting {@code addi}. An explicit data type on {@code replaced} is adopted; otherwise
     * the existing one is kept. If this array no longer held a handle, nothing is deleted instead
     * of failing on unboxing {@code null}.
     *
     * @param replaced the array whose tensor to adopt; must be an {@code RsNDArray}
     */
    public void intern(NDArray replaced) {
        RsNDArray arr = (RsNDArray) replaced;
        Long oldHandle = this.handle.getAndSet(arr.handle.getAndSet(null));
        this.shape = arr.shape;
        this.device = arr.device;
        if (arr.dataType != null) {
            this.dataType = arr.dataType;
        }
        if (oldHandle != null) {
            RustLibrary.deleteTensor(oldHandle);
        }
        arr.close();
    }

    /**
     * Delegates to native {@code RustLibrary.isInf}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray isInfinite() {
        long maskHandle = RustLibrary.isInf(this.getHandle());
        return toArray(maskHandle);
    }

    /**
     * Delegates to native {@code RustLibrary.isNaN}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray isNaN() {
        long maskHandle = RustLibrary.isNaN(this.getHandle());
        return toArray(maskHandle);
    }

    /**
     * Tiles every axis by the same factor.
     *
     * <p>An empty array is duplicated, and a scalar is treated as having one axis.
     *
     * @param repeats the factor applied to each axis
     * @return the tiled array
     */
    public RsNDArray tile(long repeats) {
        if (this.isEmpty()) {
            return (RsNDArray) this.duplicate();
        }
        int axisCount;
        if (this.isScalar()) {
            axisCount = 1;
        } else {
            axisCount = this.getShape().dimension();
        }
        long[] perAxis = new long[axisCount];
        Arrays.fill(perAxis, repeats);
        return tile(perAxis);
    }

    /**
     * Tiles along a single axis via native {@code RustLibrary.tileWithAxis}.
     *
     * @param axis the axis to tile
     * @param repeat the tiling factor
     * @return a new array wrapping the native result
     */
    public RsNDArray tile(int axis, long repeat) {
        long tiledHandle = RustLibrary.tileWithAxis(this.getHandle(), axis, repeat);
        return toArray(tiledHandle);
    }

    /**
     * Tiles with a per-axis factor via native {@code RustLibrary.tile}.
     *
     * @param repeats the factors, forwarded as-is
     * @return a new array wrapping the native result
     */
    public RsNDArray tile(long[] repeats) {
        long tiledHandle = RustLibrary.tile(this.getHandle(), repeats);
        return toArray(tiledHandle);
    }

    /**
     * Tiles toward {@code desiredShape} via native {@code RustLibrary.tileWithShape}.
     *
     * @param desiredShape the target shape; its dimensions are passed to native code
     * @return a new array wrapping the native result
     */
    public RsNDArray tile(Shape desiredShape) {
        long self = this.getHandle();
        long[] targetDims = desiredShape.getShape();
        return toArray(RustLibrary.tileWithShape(self, targetDims));
    }

    /**
     * Repeats elements on every axis by the same factor.
     *
     * <p>An empty array is duplicated, and a scalar is treated as having one axis.
     *
     * @param repeats the factor applied to each axis
     * @return the repeated array
     */
    public RsNDArray repeat(long repeats) {
        if (this.isEmpty()) {
            return (RsNDArray) this.duplicate();
        }
        int axisCount;
        if (this.isScalar()) {
            axisCount = 1;
        } else {
            axisCount = this.getShape().dimension();
        }
        long[] perAxis = new long[axisCount];
        Arrays.fill(perAxis, repeats);
        return repeat(perAxis);
    }

    /**
     * Repeats elements along one axis via native {@code RustLibrary.repeat}.
     *
     * @param axis the axis to repeat along
     * @param repeat the repetition count
     * @return a new array wrapping the native result
     */
    public RsNDArray repeat(int axis, long repeat) {
        // The native call takes (handle, count, axis), which is the reverse of this method's
        // parameter order.
        long repeatedHandle = RustLibrary.repeat(this.getHandle(), repeat, axis);
        return toArray(repeatedHandle);
    }

    /**
     * Repeats elements per axis by applying {@link #repeat(int, long)} once per entry of
     * {@code repeats}, with axis {@code i} using {@code repeats[i]}.
     *
     * <p>The result is always a new array owned by the caller. An empty {@code repeats}
     * yields a duplicate rather than {@code this}, so closing the result never closes the
     * receiver. Intermediate arrays are closed as soon as they are superseded, and on failure.
     *
     * @param repeats the per-axis repetition counts
     * @return a new array holding the repeated data
     */
    public RsNDArray repeat(long[] repeats) {
        if (repeats.length == 0) {
            // Previously this returned `this`, aliasing the receiver (e.g. a scalar via repeat(Shape)).
            return (RsNDArray) this.duplicate();
        }
        RsNDArray result = this;
        boolean completed = false;
        try {
            for (int dim = 0; dim < repeats.length; ++dim) {
                RsNDArray previous = result;
                result = result.repeat(dim, repeats[dim]);
                if (previous != this) {
                    previous.close();
                }
            }
            completed = true;
            return result;
        } finally {
            // On a native failure mid-loop, release the last intermediate we own.
            if (!completed && result != this) {
                result.close();
            }
        }
    }

    /**
     * Repeats elements so the result matches {@code desiredShape}.
     *
     * @param desiredShape the target shape; each axis must be a multiple of the current size
     * @return the repeated array
     * @throws IllegalArgumentException if the target is not compatible with this shape
     */
    public RsNDArray repeat(Shape desiredShape) {
        long[] perAxis = repeatsToMatchShape(desiredShape);
        return repeat(perAxis);
    }

    /**
     * Computes the per-axis repeat counts that turn this array's shape into
     * {@code desiredShape}.
     *
     * <p>A target with fewer dimensions is left-padded with this array's leading dimensions,
     * so those axes get a count of 1.
     *
     * @param desiredShape the target shape
     * @return one repeat count per axis of this array
     * @throws IllegalArgumentException if the target has more dimensions than this array, or
     *     if any target size is not an exact multiple of the current (non-zero) size
     */
    private long[] repeatsToMatchShape(Shape desiredShape) {
        Shape curShape = this.getShape();
        int dimension = curShape.dimension();
        if (desiredShape.dimension() > dimension) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        if (desiredShape.dimension() < dimension) {
            int additionalDimensions = dimension - desiredShape.dimension();
            desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape);
        }
        long[] repeats = new long[dimension];
        for (int i = 0; i < dimension; ++i) {
            long current = curShape.get(i);
            long target = desiredShape.get(i);
            if (current == 0L || target % current != 0L) {
                throw new IllegalArgumentException("The desired shape is not a multiple of the original shape");
            }
            // Divisibility is guaranteed above, so the exact integer quotient is correct. Going
            // through double/ceil/round lost precision for sizes beyond 2^53.
            repeats[i] = target / current;
        }
        return repeats;
    }

    /**
     * Dot product via native {@code RustLibrary.dot}.
     *
     * <p>Both operands must have the same rank, and that rank must be at most 2. The call runs
     * inside a temporary {@link NDScope}: every {@code RsNDArray} registers itself with the
     * scope on construction, so any array {@code manager.from(other)} creates is presumably
     * released on exit. The result is unregistered so it outlives the scope.
     *
     * @param other the right-hand operand
     * @return a new array wrapping the native result
     * @throws UnsupportedOperationException on a rank mismatch or a rank above 2
     */
    public RsNDArray dot(NDArray other) {
        int selfDim = this.getShape().dimension();
        int otherDim = other.getShape().dimension();
        boolean unsupported = selfDim != otherDim || selfDim > 2;
        if (unsupported) {
            throw new UnsupportedOperationException("Dimension mismatch or dimension is greater than 2.  Dot product is only applied on two 1D vectors. For high dimensions, please use .matMul instead.");
        }
        try (NDScope ignore = new NDScope()) {
            long selfHandle = this.getHandle();
            long otherHandle = this.manager.from(other).getHandle();
            return toArray(RustLibrary.dot(selfHandle, otherHandle), true);
        }
    }

    /**
     * Matrix multiplication via native {@code RustLibrary.matmul}.
     *
     * <p>Both operands must have at least 2 dimensions. The call runs inside a temporary
     * {@link NDScope}, so any array created by {@code manager.from(other)} is presumably
     * released on exit. The result is unregistered and survives.
     *
     * @param other the right-hand operand
     * @return a new array wrapping the native result
     * @throws IllegalArgumentException if either operand has fewer than 2 dimensions
     */
    public NDArray matMul(NDArray other) {
        // The original tested this array's rank twice and never validated `other`.
        if (this.getShape().dimension() < 2 || other.getShape().dimension() < 2) {
            throw new IllegalArgumentException("only 2d tensors are supported for matMul()");
        }
        try (NDScope ignore = new NDScope()) {
            long otherHandle = this.manager.from(other).getHandle();
            return this.toArray(RustLibrary.matmul(this.getHandle(), otherHandle), true);
        }
    }

    /**
     * Batched matrix multiplication via native {@code RustLibrary.batchMatMul}.
     *
     * <p>Both operands must be exactly 3-dimensional. The call runs inside a temporary
     * {@link NDScope}, so any array created by {@code manager.from(other)} is presumably
     * released on exit. The result is unregistered and survives.
     *
     * @param other the right-hand operand
     * @return a new array wrapping the native result
     * @throws IllegalArgumentException if either operand is not 3-dimensional
     */
    public NDArray batchMatMul(NDArray other) {
        // The original tested this array's rank twice and never validated `other`.
        if (this.getShape().dimension() != 3 || other.getShape().dimension() != 3) {
            throw new IllegalArgumentException("only 3d tensors are allowed for batchMatMul()");
        }
        try (NDScope ignore = new NDScope()) {
            long otherHandle = this.manager.from(other).getHandle();
            return this.toArray(RustLibrary.batchMatMul(this.getHandle(), otherHandle), true);
        }
    }

    /**
     * Clips values via native {@code RustLibrary.clip}; both bounds are passed as doubles.
     *
     * @param min the lower bound
     * @param max the upper bound
     * @return a new array wrapping the native result
     */
    public RsNDArray clip(Number min, Number max) {
        long self = this.getHandle();
        double lower = min.doubleValue();
        double upper = max.doubleValue();
        return toArray(RustLibrary.clip(self, lower, upper));
    }

    /**
     * Swaps two axes using the two-axis native {@code RustLibrary.transpose}.
     *
     * @param axis1 the first axis
     * @param axis2 the second axis
     * @return a new array wrapping the native result
     */
    public RsNDArray swapAxes(int axis1, int axis2) {
        long swappedHandle = RustLibrary.transpose(this.getHandle(), axis1, axis2);
        return toArray(swappedHandle);
    }

    /**
     * Flips along the given axes via native {@code RustLibrary.flip}.
     *
     * @param axes the axes to flip, forwarded as-is
     * @return a new array wrapping the native result
     */
    public NDArray flip(int... axes) {
        long flippedHandle = RustLibrary.flip(this.getHandle(), axes);
        return toArray(flippedHandle);
    }

    /**
     * Reverses the order of all axes, permuting with {@code [rank-1, ..., 1, 0]}.
     *
     * @return the transposed array
     */
    public RsNDArray transpose() {
        int rank = this.getShape().dimension();
        int[] reversedAxes = new int[rank];
        for (int i = 0; i < rank; ++i) {
            reversedAxes[i] = rank - 1 - i;
        }
        return this.transpose(reversedAxes);
    }

    /**
     * Permutes axes via native {@code RustLibrary.permute}.
     *
     * @param axes the new axis order
     * @return a new array wrapping the native result
     * @throws IllegalArgumentException if this is a scalar and any axes were given
     */
    public RsNDArray transpose(int... axes) {
        boolean axesOnScalar = this.isScalar() && axes.length > 0;
        if (axesOnScalar) {
            throw new IllegalArgumentException("axes don't match NDArray");
        }
        long permutedHandle = RustLibrary.permute(this.getHandle(), axes);
        return toArray(permutedHandle);
    }

    /**
     * Broadcasts to {@code shape} via native {@code RustLibrary.broadcast}.
     *
     * @param shape the target shape
     * @return a new array wrapping the native result
     */
    public RsNDArray broadcast(Shape shape) {
        long self = this.getHandle();
        long[] targetDims = shape.getShape();
        return toArray(RustLibrary.broadcast(self, targetDims));
    }

    /**
     * Index of the maximum over the whole array.
     *
     * <p>A scalar yields a new scalar {@code 0L} from the manager without a native call.
     *
     * @return the index array
     * @throws IllegalArgumentException if this array is empty
     */
    public RsNDArray argMax() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        return this.isScalar()
                ? (RsNDArray) this.manager.create(0L)
                : toArray(RustLibrary.argMax(this.getHandle()));
    }

    /**
     * Index of the maximum along {@code axis}.
     *
     * <p>A scalar yields a new scalar {@code 0L} from the manager.
     *
     * @param axis the reduction axis
     * @return the index array
     */
    public RsNDArray argMax(int axis) {
        if (!this.isScalar()) {
            // The trailing false is forwarded as-is; presumably it is keepDims. TODO confirm.
            return toArray(RustLibrary.argMaxWithAxis(this.getHandle(), axis, false));
        }
        return (RsNDArray) this.manager.create(0L);
    }

    /**
     * Top-k selection via native {@code RustLibrary.topK}; each returned handle becomes one
     * array in the list.
     *
     * @param k how many elements to select
     * @param axis the axis to select along
     * @param largest flag forwarded to native code
     * @param sorted flag forwarded to native code
     * @return the arrays wrapping the native handles, in the order native code returned them
     */
    public NDList topK(int k, int axis, boolean largest, boolean sorted) {
        long[] handles = RustLibrary.topK(this.getHandle(), k, axis, largest, sorted);
        return toList(handles);
    }

    /**
     * Index of the minimum over the whole array.
     *
     * <p>A scalar yields a new scalar {@code 0L} from the manager without a native call.
     *
     * @return the index array
     * @throws IllegalArgumentException if this array is empty
     */
    public RsNDArray argMin() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        return this.isScalar()
                ? (RsNDArray) this.manager.create(0L)
                : toArray(RustLibrary.argMin(this.getHandle()));
    }

    /**
     * Index of the minimum along {@code axis}.
     *
     * <p>A scalar yields a new scalar {@code 0L} from the manager.
     *
     * @param axis the reduction axis
     * @return the index array
     */
    public RsNDArray argMin(int axis) {
        if (!this.isScalar()) {
            // The trailing false is forwarded as-is; presumably it is keepDims. TODO confirm.
            return toArray(RustLibrary.argMinWithAxis(this.getHandle(), axis, false));
        }
        return (RsNDArray) this.manager.create(0L);
    }

    /**
     * Computes a percentile of this array through the native {@code RustLibrary.percentile} call.
     *
     * <p>NOTE(review): the {@code percentile} argument is never read. Only the handle is
     * forwarded to native code, so the requested percentile cannot influence the result.
     * Compare {@code percentile(Number, int[])}, which does forward
     * {@code percentile.doubleValue()}. Confirm what the native side computes, then either
     * forward the value or route through {@code RustLibrary.percentileWithAxes}.
     *
     * @param percentile the requested percentile (currently ignored; see note)
     * @return a new array wrapping the native result
     */
    public RsNDArray percentile(Number percentile) {
        return this.toArray(RustLibrary.percentile((Long)this.getHandle()));
    }

    /**
     * Percentile over {@code axes} via native {@code RustLibrary.percentileWithAxes}.
     *
     * @param percentile the requested percentile, passed as a double
     * @param axes the axes to reduce over
     * @return a new array wrapping the native result
     */
    public RsNDArray percentile(Number percentile, int[] axes) {
        long self = this.getHandle();
        double q = percentile.doubleValue();
        return toArray(RustLibrary.percentileWithAxes(self, q, axes));
    }

    /**
     * Median along axis {@code -1} only; it is not computed over a flattened view.
     *
     * @return the median array
     */
    public RsNDArray median() {
        int[] lastAxisOnly = {-1};
        return median(lastAxisOnly);
    }

    /**
     * Median along exactly one axis via native {@code RustLibrary.median}.
     *
     * <p>Native code returns two arrays. The second (presumably the median indices) is closed
     * immediately, and the first is returned.
     *
     * @param axes a single-element array naming the axis
     * @return the first array returned by native code
     * @throws UnsupportedOperationException unless exactly one axis is given
     */
    public RsNDArray median(int[] axes) {
        if (axes.length == 1) {
            NDList outputs = toList(RustLibrary.median(this.getHandle(), axes[0], false));
            outputs.get(1).close();
            return (RsNDArray) outputs.get(0);
        }
        throw new UnsupportedOperationException("Not supporting zero or multi-dimension median");
    }

    /**
     * Returns a duplicate of this array; this engine has no sparse storage to convert from.
     *
     * @return a copy of this array
     */
    public RsNDArray toDense() {
        NDArray copy = this.duplicate();
        return (RsNDArray) copy;
    }

    /**
     * Sparse conversion is not available in this engine.
     *
     * @param fmt ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public RsNDArray toSparse(SparseFormat fmt) {
        String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Delegates to native {@code RustLibrary.nonZero}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray nonzero() {
        long resultHandle = RustLibrary.nonZero(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Delegates to native {@code RustLibrary.erfinv}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray erfinv() {
        long resultHandle = RustLibrary.erfinv(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Delegates to native {@code RustLibrary.erf}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray erf() {
        long resultHandle = RustLibrary.erf(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Delegates to native {@code RustLibrary.inverse}.
     *
     * @return a new array wrapping the native result
     */
    public RsNDArray inverse() {
        long resultHandle = RustLibrary.inverse(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Order-2 norm with an empty axis list (presumably meaning "all axes"; confirm against
     * RustLibrary).
     *
     * @param keepDims forwarded to native code
     * @return a new array wrapping the native result
     */
    public NDArray norm(boolean keepDims) {
        final int order = 2;
        int[] noAxes = new int[0];
        return toArray(RustLibrary.norm(this.getHandle(), order, noAxes, keepDims));
    }

    /**
     * Norm of the given order over {@code axes} via native {@code RustLibrary.norm}.
     *
     * @param order the norm order
     * @param axes the axes to reduce over
     * @param keepDims forwarded to native code
     * @return a new array wrapping the native result
     */
    public NDArray norm(int order, int[] axes, boolean keepDims) {
        long normHandle = RustLibrary.norm(this.getHandle(), order, axes, keepDims);
        return toArray(normHandle);
    }

    /**
     * One-hot encoding via native {@code RustLibrary.oneHot}.
     *
     * @param depth the encoding depth
     * @param onValue value for the "hot" position
     * @param offValue value elsewhere
     * @param dataType output type; native code receives its enum ordinal, so the native side
     *     presumably mirrors {@link DataType}'s declaration order
     * @return a new array wrapping the native result
     */
    public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
        int typeCode = dataType.ordinal();
        long encodedHandle = RustLibrary.oneHot(this.getHandle(), depth, onValue, offValue, typeCode);
        return toArray(encodedHandle);
    }

    /**
     * Batched dot product is not implemented for this engine.
     *
     * @param other ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray batchDot(NDArray other) {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Delegates to native {@code RustLibrary.complex}.
     *
     * @return a new array wrapping the native result
     */
    public NDArray complex() {
        long resultHandle = RustLibrary.complex(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Delegates to native {@code RustLibrary.real}.
     *
     * @return a new array wrapping the native result
     */
    public NDArray real() {
        long resultHandle = RustLibrary.real(this.getHandle());
        return toArray(resultHandle);
    }

    /**
     * Complex conjugate is not implemented for this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray conj() {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the engine-specific helper created in the constructor.
     *
     * @return the {@link RsNDArrayEx} bound to this array
     * @throws UnsupportedOperationException if no helper is present
     */
    public RsNDArrayEx getNDArrayInternal() {
        RsNDArrayEx ex = this.ndArrayEx;
        if (ex != null) {
            return ex;
        }
        throw new UnsupportedOperationException("NDArray operation is not supported for String tensor");
    }

    /**
     * Discrete difference is not implemented yet for this engine.
     *
     * @param n ignored
     * @param dim ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray diff(int n, int dim) {
        String reason = "Not implemented yet.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the debug rendering, or a fixed message once the array has been released.
     *
     * @return a human-readable description
     */
    public String toString() {
        return this.isReleased() ? "This array is already closed" : this.toDebugString();
    }

    /**
     * Content-based equality: any {@link NDArray} is compared via {@code contentEquals}.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is an NDArray with equal content
     */
    public boolean equals(Object obj) {
        if (!(obj instanceof NDArray)) {
            return false;
        }
        NDArray other = (NDArray) obj;
        return this.contentEquals(other);
    }

    /**
     * Returns a constant hash.
     *
     * <p>This is trivially consistent with the content-based {@link #equals(Object)} without
     * reading native data. The trade-off is that hash-based collections degrade to linear
     * probing.
     *
     * @return always {@code 0}
     */
    public int hashCode() {
        final int constantHash = 0;
        return constantHash;
    }

    /**
     * Releases this array.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Run {@code onClose()}.</li>
     *   <li>Atomically clear the handle and free the native tensor, unless the handle was
     *       already {@code null} or the {@code -1} sentinel.</li>
     *   <li>Detach from the manager.</li>
     *   <li>Drop the buffer reference.</li>
     * </ol>
     */
    public void close() {
        this.onClose();
        Long nativeHandle = this.handle.getAndSet(null);
        boolean ownsTensor = nativeHandle != null && nativeHandle != -1L;
        if (ownsTensor) {
            RustLibrary.deleteTensor(nativeHandle);
        }
        this.manager.detachInternal(this.getUid());
        this.dataRef = null;
    }

    /**
     * Wraps a native handle in a new array that stays registered with the current scope,
     * with no explicit dtype and no copied name.
     *
     * @param newHandle the native tensor handle
     * @return the wrapping array
     */
    private RsNDArray toArray(long newHandle) {
        return this.toArray(newHandle, null, false, false);
    }

    /**
     * Wraps a native handle in a new array, optionally removing it from the current
     * {@link NDScope}.
     *
     * @param newHandle the native tensor handle
     * @param unregister whether to unregister the new array from the scope
     * @return the wrapping array
     */
    private RsNDArray toArray(long newHandle, boolean unregister) {
        DataType unspecifiedType = null;
        boolean copyName = false;
        return this.toArray(newHandle, unspecifiedType, unregister, copyName);
    }

    /**
     * Wraps a native handle in a new {@link RsNDArray} owned by this array's manager.
     *
     * <p>The constructor attaches the array to the manager and registers it with
     * {@link NDScope}.
     *
     * @param newHandle the native tensor handle
     * @param dataType explicit dtype, or {@code null} to let it be read lazily
     * @param unregister whether to remove the new array from the current scope
     * @param withName whether to copy this array's name onto the result
     * @return the wrapping array
     */
    private RsNDArray toArray(long newHandle, DataType dataType, boolean unregister, boolean withName) {
        RsNDArray wrapped = new RsNDArray(this.manager, newHandle, dataType);
        if (withName) {
            String inheritedName = this.getName();
            wrapped.setName(inheritedName);
        }
        if (unregister) {
            NDScope.unregister(wrapped);
        }
        return wrapped;
    }

    /**
     * Wraps each native handle in a new {@link RsNDArray}, preserving order.
     *
     * @param handles native tensor handles
     * @return a list with one array per handle
     */
    private NDList toList(long[] handles) {
        NDList arrays = new NDList(handles.length);
        for (int i = 0; i < handles.length; ++i) {
            arrays.add(new RsNDArray(this.manager, handles[i]));
        }
        return arrays;
    }
}

