/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.NDIndexBooleans;
import ai.djl.ndarray.index.NDIndexFullSlice;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.pytorch.engine.PtNDArrayEx;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.pytorch.jni.NativeResource;
import ai.djl.pytorch.jni.Pointer;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * {@link NDArray} implementation backed by a native PyTorch tensor.
 *
 * <p>Device, data type, shape and sparse format are either supplied at construction time or
 * queried lazily from the native tensor through {@link JniUtils} on first access and then cached.
 * Most operations delegate directly to {@link JniUtils}. Operations this engine does not provide
 * throw {@link UnsupportedOperationException}. The native handle is released by {@link #close()}.
 */
public class PtNDArray extends NativeResource implements NDArray {

    // Limits handed to toDebugString(...) by toString().
    private static final int MAX_SIZE = 100;
    private static final int MAX_DEPTH = 10;
    private static final int MAX_ROWS = 10;
    private static final int MAX_COLUMNS = 20;

    private String name;
    // Lazily populated caches; null means "not yet queried from the native tensor".
    private Device device;
    private DataType dataType;
    private Shape shape;
    private SparseFormat sparseFormat;
    private PtNDManager manager;
    private PtNDArrayEx ptNDArrayEx;

    /**
     * Creates an array wrapping {@code handle} with already-known metadata.
     *
     * @param manager the manager this array is attached to
     * @param handle the native tensor pointer
     * @param device the device of the tensor
     * @param shape the tensor shape; every dimension must be non-negative
     * @param dataType the tensor data type
     * @throws IllegalArgumentException if any dimension of {@code shape} is negative
     */
    PtNDArray(PtNDManager manager, Pointer handle, Device device, Shape shape, DataType dataType) {
        this(manager, handle);
        this.device = device;
        if (Arrays.stream(shape.getShape()).anyMatch(s -> s < 0L)) {
            throw new IllegalArgumentException("The shape must be >= 0");
        }
        this.shape = shape;
        this.dataType = dataType;
    }

    /**
     * Creates an array wrapping {@code handle}; metadata is fetched lazily from the native side.
     *
     * @param manager the manager this array is attached to
     * @param handle the native tensor pointer
     */
    PtNDArray(PtNDManager manager, Pointer handle) {
        super(handle);
        this.manager = manager;
        this.ptNDArrayEx = new PtNDArrayEx(this);
    }

    /** Returns the manager of this array ({@code null} after {@link #close()}). */
    public PtNDManager getManager() {
        return this.manager;
    }

    /** Returns the user-assigned name, or {@code null} if none was set. */
    public String getName() {
        return this.name;
    }

    /** Sets the user-visible name of this array. */
    public void setName(String name) {
        this.name = name;
    }

    /** Returns the data type, querying the native tensor on first access and caching it. */
    public DataType getDataType() {
        if (this.dataType == null) {
            this.dataType = JniUtils.getDataType(this);
        }
        return this.dataType;
    }

    /** Returns the device, querying the native tensor on first access and caching it. */
    public Device getDevice() {
        if (this.device == null) {
            this.device = JniUtils.getDevice(this);
        }
        return this.device;
    }

    /** Returns the shape, querying the native tensor on first access and caching it. */
    public Shape getShape() {
        if (this.shape == null) {
            this.shape = JniUtils.getShape(this);
        }
        return this.shape;
    }

    /** Returns the sparse format, querying the native tensor on first access and caching it. */
    public SparseFormat getSparseFormat() {
        if (this.sparseFormat == null) {
            this.sparseFormat = JniUtils.getSparseFormat(this);
        }
        return this.sparseFormat;
    }

    /** Moves/copies this array to {@code device}, keeping the current data type. */
    public PtNDArray toDevice(Device device, boolean copy) {
        return JniUtils.to(this, this.getDataType(), device, copy);
    }

    /** Converts this array to {@code dataType}, keeping the current device. */
    public PtNDArray toType(DataType dataType, boolean copy) {
        return JniUtils.to(this, dataType, this.getDevice(), copy);
    }

    /** No-op in this engine. */
    public void attachGradient() {
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray getGradient() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Returns the tensor contents as a {@link ByteBuffer} obtained from the native side. */
    public ByteBuffer toByteBuffer() {
        return JniUtils.getByteBuffer(this);
    }

    // ----- In-place setters: not supported by this engine; all throw UnsupportedOperationException.

    public void set(Buffer data) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void set(NDIndex index, NDArray value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void set(NDIndex index, Number value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void setScalar(NDIndex index, Number value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the sub-array selected by {@code index}.
     *
     * <p>A scalar is returned as a copy. A single boolean index is applied as a boolean mask along
     * axis 0. Otherwise the index must be expressible as a full slice. Only dimensions that are
     * actually narrowed or strided are sliced, and then the dimensions the index marks for removal
     * are squeezed out.
     *
     * @param index the index to apply
     * @return the selected sub-array
     * @throws IllegalArgumentException if the index contains more than one boolean element
     * @throws UnsupportedOperationException if the index cannot be converted to a full slice
     */
    public PtNDArray get(NDIndex index) {
        if (this.isScalar()) {
            return (PtNDArray) this.duplicate();
        }
        List<?> indices = index.getIndices();
        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException("get() currently didn't support more that one boolean NDArray");
            }
            return this.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex(), 0);
        }
        NDIndexFullSlice fullSlice = index.getAsFullSlice(this.getShape()).orElse(null);
        if (fullSlice == null) {
            throw new UnsupportedOperationException("get() currently supports all, fixed, and slices indices");
        }
        long[] min = fullSlice.getMin();
        long[] max = fullSlice.getMax();
        long[] step = fullSlice.getStep();
        PtNDArray afterSlice = this;
        for (int dim = 0; dim < min.length; ++dim) {
            // Skip dimensions that the slice keeps in full; slicing them would only add a copy/view.
            boolean keepsWholeDim = step[dim] == 1L && this.getShape().get(dim) == max[dim] - min[dim];
            if (!keepsWholeDim) {
                afterSlice = JniUtils.slice(afterSlice, dim, min[dim], max[dim], step[dim]);
            }
        }
        return afterSlice.squeeze(fullSlice.getToSqueeze().stream().mapToInt(i -> i).toArray());
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public void copyTo(NDArray array) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Returns a native clone of this array. */
    public NDArray duplicate() {
        return JniUtils.clone(this);
    }

    /**
     * Selects elements where the boolean {@code index} is true.
     *
     * <p>If {@code index} has the same shape as this array, the flat masked result is returned. If
     * it matches this array's shape from {@code axis} onward ({@code Shape.slice(axis)}), the
     * masked result is reshaped to the leading dimensions {@code [0, axis)} plus one dimension
     * holding the selected elements.
     *
     * @param index a boolean array
     * @param axis the first axis covered by {@code index}
     * @return the masked array
     * @throws UnsupportedOperationException if the shapes do not match in either way
     */
    public PtNDArray booleanMask(NDArray index, int axis) {
        Shape indexShape = index.getShape();
        if (indexShape.equals(this.getShape())) {
            return JniUtils.booleanMask(this, (PtNDArray) index);
        }
        if (indexShape.equals(this.getShape().slice(axis))) {
            PtNDArray flattedResult = JniUtils.booleanMask(this, (PtNDArray) index);
            Shape remainder = this.getShape().slice(0, axis);
            long selectedSize = flattedResult.getShape().size() / remainder.size();
            return flattedResult.reshape(remainder.addAll(new Shape(new long[] {selectedSize})));
        }
        throw new UnsupportedOperationException("Not supported for shape not broadcastable " + indexShape.toString() + " vs " + this.getShape().toString());
    }

    /** Returns a dense array of zeros with this array's shape, data type and device. */
    public PtNDArray zerosLike() {
        return JniUtils.zerosLike(this, this.getDataType(), this.getDevice(), SparseFormat.DENSE);
    }

    /** Returns a dense array of ones with this array's shape, data type and device. */
    public PtNDArray onesLike() {
        return JniUtils.onesLike(this, this.getDataType(), this.getDevice(), SparseFormat.DENSE);
    }

    /** Compares every element against {@code number} (wrapped in an array by the manager). */
    public boolean contentEquals(Number number) {
        return JniUtils.contentEqual(this, (PtNDArray) this.manager.create(number));
    }

    /**
     * Returns true if {@code other} has the same shape, the same data type and equal contents.
     *
     * @param other the array to compare against; {@code null} yields false
     */
    public boolean contentEquals(NDArray other) {
        if (other == null || !this.shapeEquals(other)) {
            return false;
        }
        if (this.getDataType() != other.getDataType()) {
            return false;
        }
        return JniUtils.contentEqual(this, (PtNDArray) other);
    }

    // ----- Element-wise comparisons. Number overloads wrap the operand via manager.create(...).

    public PtNDArray eq(Number other) {
        return this.eq(this.manager.create(other));
    }

    public PtNDArray eq(NDArray other) {
        return JniUtils.eq(this, (PtNDArray) other);
    }

    public PtNDArray neq(Number other) {
        return this.neq(this.manager.create(other));
    }

    public PtNDArray neq(NDArray other) {
        return JniUtils.neq(this, (PtNDArray) other);
    }

    public PtNDArray gt(Number other) {
        return this.gt(this.manager.create(other));
    }

    public PtNDArray gt(NDArray other) {
        return JniUtils.gt(this, (PtNDArray) other);
    }

    public PtNDArray gte(Number other) {
        return this.gte(this.manager.create(other));
    }

    public PtNDArray gte(NDArray other) {
        return JniUtils.gte(this, (PtNDArray) other);
    }

    public PtNDArray lt(Number other) {
        return this.lt(this.manager.create(other));
    }

    public PtNDArray lt(NDArray other) {
        return JniUtils.lt(this, (PtNDArray) other);
    }

    public PtNDArray lte(Number other) {
        return this.lte(this.manager.create(other));
    }

    public PtNDArray lte(NDArray other) {
        return JniUtils.lte(this, (PtNDArray) other);
    }

    // ----- Out-of-place arithmetic. Number overloads wrap the operand via manager.create(...).

    public PtNDArray add(Number n) {
        return this.add(this.manager.create(n));
    }

    public PtNDArray add(NDArray other) {
        return JniUtils.add(this, (PtNDArray) other);
    }

    public PtNDArray sub(Number n) {
        return this.sub(this.manager.create(n));
    }

    public PtNDArray sub(NDArray other) {
        return JniUtils.sub(this, (PtNDArray) other);
    }

    public PtNDArray mul(Number n) {
        return this.mul(this.manager.create(n));
    }

    public PtNDArray mul(NDArray other) {
        return JniUtils.mul(this, (PtNDArray) other);
    }

    public PtNDArray div(Number n) {
        return this.div(this.manager.create(n));
    }

    public PtNDArray div(NDArray other) {
        return JniUtils.div(this, (PtNDArray) other);
    }

    /**
     * Returns the element-wise remainder of this array divided by {@code n}.
     *
     * <p>Delegates to {@link #mod(NDArray)} like every other {@code Number} overload (and like
     * {@link #modi(Number)}); previously this overload threw {@code UnsupportedOperationException}
     * even though the array variant is implemented.
     */
    public PtNDArray mod(Number n) {
        return this.mod(this.manager.create(n));
    }

    /** Element-wise remainder, computed by {@code JniUtils.remainder}. */
    public PtNDArray mod(NDArray other) {
        return JniUtils.remainder(this, (PtNDArray) other);
    }

    public PtNDArray pow(Number n) {
        return this.pow(this.manager.create(n));
    }

    public PtNDArray pow(NDArray other) {
        return JniUtils.pow(this, (PtNDArray) other);
    }

    // ----- In-place arithmetic: mutate this array through JniUtils and return this.

    public PtNDArray addi(Number n) {
        return this.addi(this.manager.create(n));
    }

    public PtNDArray addi(NDArray other) {
        JniUtils.addi(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray subi(Number n) {
        return this.subi(this.manager.create(n));
    }

    public PtNDArray subi(NDArray other) {
        JniUtils.subi(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray muli(Number n) {
        return this.muli(this.manager.create(n));
    }

    public PtNDArray muli(NDArray other) {
        JniUtils.muli(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray divi(Number n) {
        return this.divi(this.manager.create(n));
    }

    public PtNDArray divi(NDArray other) {
        JniUtils.divi(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray modi(Number n) {
        return this.modi(this.manager.create(n));
    }

    public PtNDArray modi(NDArray other) {
        JniUtils.remainderi(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray powi(Number n) {
        return this.powi(this.manager.create(n));
    }

    public PtNDArray powi(NDArray other) {
        JniUtils.powi(this, (PtNDArray) other);
        return this;
    }

    public PtNDArray maximum(Number n) {
        return this.maximum(this.manager.create(n));
    }

    /**
     * Element-wise maximum.
     *
     * @throws IllegalArgumentException if the data types of the two arrays differ
     */
    public PtNDArray maximum(NDArray other) {
        if (!other.getDataType().equals(this.getDataType())) {
            throw new IllegalArgumentException("DataType mismatch, expected " + this.getDataType() + " Actual " + other.getDataType());
        }
        return JniUtils.max(this, (PtNDArray) other);
    }

    public PtNDArray minimum(Number n) {
        return this.minimum(this.manager.create(n));
    }

    /**
     * Element-wise minimum.
     *
     * @throws IllegalArgumentException if the data types of the two arrays differ
     */
    public PtNDArray minimum(NDArray other) {
        if (!other.getDataType().equals(this.getDataType())) {
            throw new IllegalArgumentException("DataType mismatch, expected " + this.getDataType() + " Actual " + other.getDataType());
        }
        return JniUtils.min(this, (PtNDArray) other);
    }

    // ----- Boolean reductions: operate on a BOOLEAN copy of this array.

    public PtNDArray all() {
        return JniUtils.all(this.toType(DataType.BOOLEAN, true));
    }

    public PtNDArray any() {
        return JniUtils.any(this.toType(DataType.BOOLEAN, true));
    }

    public PtNDArray none() {
        return JniUtils.none(this.toType(DataType.BOOLEAN, true));
    }

    // ----- Unary element-wise math.

    public PtNDArray neg() {
        return JniUtils.neg(this);
    }

    public PtNDArray negi() {
        JniUtils.negi(this);
        return this;
    }

    public PtNDArray abs() {
        return JniUtils.abs(this);
    }

    public PtNDArray square() {
        return this.pow(2);
    }

    public NDArray sqrt() {
        return JniUtils.sqrt(this);
    }

    /**
     * Cube root computed as {@code pow(x, 1/3)}.
     *
     * <p>NOTE(review): a fractional power of a negative base is typically NaN in floating point, so
     * negative inputs may not yield the real cube root; confirm against {@code JniUtils.pow}.
     */
    public PtNDArray cbrt() {
        return JniUtils.pow(this, (PtNDArray) this.manager.create(0.3333333333333333));
    }

    public PtNDArray floor() {
        return JniUtils.floor(this);
    }

    public PtNDArray ceil() {
        return JniUtils.ceil(this);
    }

    public PtNDArray round() {
        return JniUtils.round(this);
    }

    public PtNDArray trunc() {
        return JniUtils.trunc(this);
    }

    public PtNDArray exp() {
        return JniUtils.exp(this);
    }

    public PtNDArray log() {
        return JniUtils.log(this);
    }

    public PtNDArray log10() {
        return JniUtils.log10(this);
    }

    public PtNDArray log2() {
        return JniUtils.log2(this);
    }

    public PtNDArray sin() {
        return JniUtils.sin(this);
    }

    public PtNDArray cos() {
        return JniUtils.cos(this);
    }

    public PtNDArray tan() {
        return JniUtils.tan(this);
    }

    public PtNDArray asin() {
        return JniUtils.asin(this);
    }

    public PtNDArray acos() {
        return JniUtils.acos(this);
    }

    public PtNDArray atan() {
        return JniUtils.atan(this);
    }

    public PtNDArray sinh() {
        return JniUtils.sinh(this);
    }

    public PtNDArray cosh() {
        return JniUtils.cosh(this);
    }

    public PtNDArray tanh() {
        return JniUtils.tanh(this);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray asinh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray acosh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray atanh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Converts radians to degrees: {@code x * 180 / PI}. */
    public PtNDArray toDegrees() {
        return this.mul(180.0).div(Math.PI);
    }

    /** Converts degrees to radians: {@code x * PI / 180}. */
    public PtNDArray toRadians() {
        return this.mul(Math.PI).div(this.manager.create(180.0));
    }

    // ----- Reductions. Axis-based variants below accept only a single axis unless noted.

    public PtNDArray max() {
        return JniUtils.max(this);
    }

    /**
     * Maximum along {@code axes[0]}.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public PtNDArray max(int[] axes, boolean keepDims) {
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.max(this, axes[0], keepDims);
    }

    public PtNDArray min() {
        return JniUtils.min(this);
    }

    /**
     * Minimum along {@code axes[0]}.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public PtNDArray min(int[] axes, boolean keepDims) {
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.min(this, axes[0], keepDims);
    }

    public PtNDArray sum() {
        return JniUtils.sum(this);
    }

    /** Sum over all of {@code axes} (multiple axes are supported here). */
    public PtNDArray sum(int[] axes, boolean keepDims) {
        return JniUtils.sum(this, Arrays.stream(axes).mapToLong(i -> i).toArray(), keepDims);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray prod() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray prod(int[] axes, boolean keepDims) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray mean() {
        return JniUtils.mean(this);
    }

    /**
     * Mean along {@code axes[0]}.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    public PtNDArray mean(int[] axes, boolean keepDims) {
        if (axes.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.mean(this, axes[0], keepDims);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray trace(int offset, int axis1, int axis2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Splits into {@code sections} pieces along {@code axis} via {@code JniUtils.split}. */
    public NDList split(long sections, int axis) {
        return JniUtils.split(this, sections, (long) axis);
    }

    /**
     * Splits this array along {@code axis} at the given split points.
     *
     * <p>The split points are converted into piece sizes: {@code indices[0]}, then the successive
     * differences, then the remainder up to {@code size(axis)}. An empty {@code indices} yields a
     * single piece covering the whole axis (previously this threw
     * {@code ArrayIndexOutOfBoundsException}).
     *
     * @param indices ascending split points along {@code axis}
     * @param axis the axis to split
     * @return the pieces, in order
     */
    public NDList split(long[] indices, int axis) {
        long[] sizes = new long[indices.length + 1];
        long previous = 0L;
        for (int i = 0; i < indices.length; ++i) {
            sizes[i] = indices[i] - previous;
            previous = indices[i];
        }
        sizes[indices.length] = this.size(axis) - previous;
        return JniUtils.split(this, sizes, (long) axis);
    }

    /** Flattens all dimensions (from 0 to the last) into one. */
    public PtNDArray flatten() {
        return JniUtils.flatten(this, 0L, -1L);
    }

    public PtNDArray reshape(Shape shape) {
        return JniUtils.reshape(this, shape.getShape());
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public NDArray reshapeLike(NDArray array) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Inserts a size-1 dimension at {@code axis} ({@code JniUtils.unsqueeze}). */
    public PtNDArray expandDims(int axis) {
        return JniUtils.unsqueeze(this, axis);
    }

    public PtNDArray squeeze() {
        return JniUtils.squeeze(this);
    }

    public PtNDArray squeeze(int axis) {
        return JniUtils.squeeze(this, axis);
    }

    /**
     * Removes the given size-1 axes by reshaping.
     *
     * <p>For a scalar, only an empty {@code axes} or the single axis {@code 0} is accepted, and a
     * copy is returned. For other arrays, negative axes count from the end.
     *
     * @param axes the axes to remove; each must have size one
     * @return the reshaped array
     * @throws IllegalArgumentException if an axis is out of range or does not have size one
     */
    public PtNDArray squeeze(int[] axes) {
        if (this.isScalar()) {
            if (axes.length > 1 || (axes.length == 1 && axes[0] != 0)) {
                throw new IllegalArgumentException("axis " + axes[0] + " is out of bounds for array of dimension 0");
            }
            return (PtNDArray) this.duplicate();
        }
        long[] shapeArr = this.getShape().getShape();
        int dim = shapeArr.length;
        Set<Integer> toRemove = new HashSet<>();
        for (int axis : axes) {
            int normalized = axis < 0 ? axis + dim : axis;
            if (normalized < 0 || normalized >= dim) {
                throw new IllegalArgumentException("axis " + axis + " is out of bounds for array of dimension " + dim);
            }
            if (shapeArr[normalized] != 1L) {
                throw new IllegalArgumentException("cannot select an axis to squeeze out which has size not equal to one");
            }
            toRemove.add(normalized);
        }
        long[] newShape = IntStream.range(0, dim)
                .filter(i -> !toRemove.contains(i))
                .mapToLong(i -> shapeArr[i])
                .toArray();
        return this.reshape(new Shape(newShape));
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray logicalAnd(NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public PtNDArray logicalOr(NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray logicalXor(NDArray other) {
        return JniUtils.logicalXor(this, (PtNDArray) other);
    }

    public PtNDArray logicalNot() {
        return JniUtils.logicalNot(this);
    }

    /**
     * Returns the indices that sort this array along {@code axis}.
     *
     * @throws UnsupportedOperationException if {@code ascending} is false
     */
    public PtNDArray argSort(int axis, boolean ascending) {
        if (!ascending) {
            throw new UnsupportedOperationException("Only support ascending!");
        }
        // NOTE(review): meaning of the trailing boolean is defined by JniUtils.argSort — confirm.
        return JniUtils.argSort(this, axis, false);
    }

    /** Sorts along the last axis. */
    public PtNDArray sort() {
        return this.sort(-1);
    }

    public PtNDArray sort(int axis) {
        return JniUtils.sort(this, axis, false);
    }

    /**
     * Softmax along {@code axes[0]}; any further axes are ignored.
     *
     * @throws UnsupportedOperationException if {@code temperature} is not 1
     */
    public PtNDArray softmax(int[] axes, float temperature) {
        if ((double) temperature != 1.0) {
            throw new UnsupportedOperationException("PyTorch softmax didn't suuport temperature");
        }
        return JniUtils.softmax(this, axes[0], this.getDataType());
    }

    // ----- Operations not supported by this engine: all throw UnsupportedOperationException.

    public PtNDArray logSoftmax(int[] axes, float temperature) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray cumSum() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray cumSum(int axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray isInfinite() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray isNaN() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray createMask(NDIndex index) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray createMask(Predicate<Number> predicate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray tile(long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray tile(int axis, long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray tile(long[] repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray tile(Shape desiredShape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray repeat(long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray repeat(int axis, long repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray repeat(long[] repeats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray repeat(Shape desiredShape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Dot product, computed with {@code JniUtils.matmul}. */
    public PtNDArray dot(NDArray other) {
        return JniUtils.matmul(this, (PtNDArray) other);
    }

    public PtNDArray clip(Number min, Number max) {
        return JniUtils.clip(this, min, max);
    }

    /** Swaps two axes ({@code JniUtils.transpose}). */
    public PtNDArray swapAxes(int axis1, int axis2) {
        return JniUtils.transpose(this, axis1, axis2);
    }

    /** Reverses the order of all axes. */
    public PtNDArray transpose() {
        int dim = this.getShape().dimension();
        int[] reversedShape = IntStream.range(0, dim).map(i -> dim - i - 1).toArray();
        return this.transpose(reversedShape);
    }

    /**
     * Permutes the axes into the given order.
     *
     * @throws IllegalArgumentException if this is a scalar and any axes are given
     */
    public PtNDArray transpose(int... axes) {
        if (this.isScalar() && axes.length > 0) {
            throw new IllegalArgumentException("axes don't match NDArray");
        }
        return JniUtils.permute(this, Arrays.stream(axes).mapToLong(i -> i).toArray());
    }

    public PtNDArray broadcast(Shape shape) {
        return JniUtils.broadcast(this, shape);
    }

    /**
     * Index of the maximum over the whole array.
     *
     * @throws IllegalArgumentException if the array is empty
     */
    public PtNDArray argMax() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        return JniUtils.argMax(this);
    }

    /** Index of the maximum along {@code axis}; an empty array yields an empty INT64 result. */
    public PtNDArray argMax(int axis) {
        if (this.isEmpty()) {
            Shape newShape = NDUtils.getShapeFromEmptyNDArrayForReductionOp(this.getShape(), axis);
            return (PtNDArray) this.manager.create(newShape, DataType.INT64);
        }
        return JniUtils.argMax(this, axis, false);
    }

    /**
     * Index of the minimum over the whole array.
     *
     * @throws IllegalArgumentException if the array is empty
     */
    public PtNDArray argMin() {
        if (this.isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        return JniUtils.argMin(this);
    }

    /** Index of the minimum along {@code axis}; an empty array yields an empty INT64 result. */
    public PtNDArray argMin(int axis) {
        if (this.isEmpty()) {
            Shape newShape = NDUtils.getShapeFromEmptyNDArrayForReductionOp(this.getShape(), axis);
            return (PtNDArray) this.manager.create(newShape, DataType.INT64);
        }
        return JniUtils.argMin(this, axis, false);
    }

    // ----- Statistics / format conversions not supported by this engine.

    public PtNDArray percentile(Number percentile) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray percentile(Number percentile, int[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray median() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray median(int[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray toDense() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray toSparse(SparseFormat fmt) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray nonzero() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Returns the engine-specific extension object created in the constructor. */
    public PtNDArrayEx getNDArrayInternal() {
        return this.ptNDArrayEx;
    }

    /** Returns a truncated debug rendering, or a fixed message once the array has been released. */
    public String toString() {
        if (this.isReleased()) {
            return "This array is already closed";
        }
        return this.toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
    }

    /** Two {@code PtNDArray}s are equal when {@link #contentEquals(NDArray)} holds. */
    public boolean equals(Object obj) {
        if (obj instanceof PtNDArray) {
            return this.contentEquals((PtNDArray) obj);
        }
        return false;
    }

    /** Constant hash code; trivially consistent with the content-based {@link #equals(Object)}. */
    public int hashCode() {
        return 0;
    }

    /**
     * Releases the native tensor and detaches this array from its manager.
     *
     * <p>The handle is cleared atomically, so only the first call frees native memory; later
     * calls are no-ops.
     */
    @Override
    public void close() {
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JniUtils.deleteNdArray(pointer);
            this.manager.detach(this.getUid());
            this.manager = null;
        }
    }
}

