/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.shape;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

public class Shape {
    /**
     * Private no-op constructor; every member visible in this class is static,
     * so instances are not meant to be created.
     */
    private Shape() {
    }

    /**
     * Returns the shape of the input array that holds the most elements.
     * <p>
     * The first input is the initial candidate. Each later non-null input replaces it only
     * when its {@code length()} is strictly greater than the element count of the current
     * candidate shape, so ties keep the earliest array.
     *
     * @param inputs arrays to compare; {@code null} entries after the first are skipped
     * @return the shape of the largest input, or {@code null} when {@code inputs} is
     *         {@code null} or empty
     */
    public static long[] getMaxShape(INDArray ... inputs) {
        // An empty varargs call has no first element to fall back on; treat it like null input
        // instead of failing with an ArrayIndexOutOfBoundsException.
        if (inputs == null || inputs.length == 0) {
            return null;
        }
        long[] currMax = inputs[0].shape();
        for (int i = 1; i < inputs.length; ++i) {
            INDArray candidate = inputs[i];
            if (candidate == null) {
                continue;
            }
            // Count elements in 64-bit arithmetic so that large shapes cannot wrap around.
            if (ArrayUtil.prodLong(currMax) < candidate.length()) {
                currMax = candidate.shape();
            }
        }
        return currMax;
    }

    /**
     * Tells whether an int shape describes a single element.
     *
     * @param shape the shape to inspect (must not be {@code null})
     * @return {@code true} when the shape has no dimensions or the product of its
     *         dimensions, as computed by {@code ArrayUtil.prodLong}, equals 1
     */
    public static boolean shapeIsScalar(int[] shape) {
        if (shape.length == 0) {
            return true;
        }
        long elementCount = ArrayUtil.prodLong(shape);
        return elementCount == 1L;
    }

    /**
     * Tells whether a long shape describes a single element.
     *
     * @param shape the shape to inspect (must not be {@code null})
     * @return {@code true} when the shape has no dimensions or the product of its
     *         dimensions, as computed by {@code ArrayUtil.prodLong}, equals 1
     */
    public static boolean shapeIsScalar(long[] shape) {
        if (shape.length == 0) {
            return true;
        }
        long elementCount = ArrayUtil.prodLong(shape);
        return elementCount == 1L;
    }

    /**
     * Tells whether an int shape is a placeholder, i.e. is missing or has a negative dimension.
     *
     * @param shape the shape to inspect; may be {@code null}
     * @return {@code true} when {@code shape} is {@code null} or contains any value below zero
     */
    public static boolean isPlaceholderShape(int[] shape) {
        if (shape == null) {
            return true;
        }
        for (int dim : shape) {
            if (dim < 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether a long shape is a placeholder, i.e. is missing or has a negative dimension.
     * <p>
     * A one-element shape whose only value is {@link Long#MIN_VALUE} is explicitly treated
     * as not being a placeholder.
     *
     * @param shape the shape to inspect; may be {@code null}
     * @return {@code true} when {@code shape} is {@code null} or contains any value below zero,
     *         except for the single-element {@code Long.MIN_VALUE} case described above
     */
    public static boolean isPlaceholderShape(long[] shape) {
        if (shape == null) {
            return true;
        }
        boolean minValueMarker = shape.length == 1 && shape[0] == Long.MIN_VALUE;
        if (minValueMarker) {
            return false;
        }
        for (long dim : shape) {
            if (dim < 0L) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the dimensions along which two int shapes have to be broadcast.
     * <p>
     * The shapes are aligned at their trailing dimension and only the overlapping part
     * (the last {@code min(left.length, right.length)} entries) is examined. The returned
     * indices are positions within that overlapping part, in ascending order.
     *
     * @param left  the left shape
     * @param right the right shape
     * @return {@code null} when both shapes are equal, otherwise the positions where the left
     *         size is 1 or the right size is 1 and differs from the left size
     * @throws IllegalArgumentException if two aligned sizes differ and neither rule above applies
     */
    public static int[] getBroadcastDimensions(int[] left, int[] right) {
        if (Arrays.equals(left, right)) {
            return null;
        }
        int overlap = Math.min(left.length, right.length);
        int leftShift = left.length - overlap;
        int rightShift = right.length - overlap;
        ArrayList<Integer> broadcastDims = new ArrayList<Integer>();
        // Walk from the last dimension backwards; prepending keeps the result ascending.
        for (int d = overlap - 1; d >= 0; --d) {
            int l = left[leftShift + d];
            int r = right[rightShift + d];
            boolean needsBroadcast = l == 1 || (l != r && r == 1);
            if (needsBroadcast) {
                broadcastDims.add(0, d);
            } else if (l != r) {
                throw new IllegalArgumentException("Unable to broadcast dimension " + d + " due to shape mismatch. Right shape must be 1. Left array shape: " + Arrays.toString(left) + ", right array shape: " + Arrays.toString(right));
            }
        }
        return Ints.toArray(broadcastDims);
    }

    /**
     * Finds the dimensions along which two long shapes have to be broadcast.
     * <p>
     * The shapes are aligned at their trailing dimension and only the overlapping part
     * (the last {@code min(left.length, right.length)} entries) is examined. The returned
     * indices are positions within that overlapping part, in ascending order.
     *
     * @param left  the left shape
     * @param right the right shape
     * @return {@code null} when both shapes are equal, otherwise the positions where the left
     *         size is 1 or the right size is 1 and differs from the left size
     * @throws IllegalArgumentException if two aligned sizes differ and neither rule above applies
     */
    public static int[] getBroadcastDimensions(long[] left, long[] right) {
        if (Arrays.equals(left, right)) {
            return null;
        }
        int overlap = Math.min(left.length, right.length);
        int leftShift = left.length - overlap;
        int rightShift = right.length - overlap;
        ArrayList<Integer> broadcastDims = new ArrayList<Integer>();
        // Walk from the last dimension backwards; prepending keeps the result ascending.
        for (int d = overlap - 1; d >= 0; --d) {
            long l = left[leftShift + d];
            long r = right[rightShift + d];
            boolean needsBroadcast = l == 1L || (l != r && r == 1L);
            if (needsBroadcast) {
                broadcastDims.add(0, d);
            } else if (l != r) {
                throw new IllegalArgumentException("Unable to broadcast dimension " + d + " due to shape mismatch. Right shape must be 1. Left array shape: " + Arrays.toString(left) + ", right array shape: " + Arrays.toString(right));
            }
        }
        return Ints.toArray(broadcastDims);
    }

    /**
     * Computes the shape that results from broadcasting two int shapes against each other.
     * <p>
     * After {@code Shape.assertBroadcastable} has accepted the pair, the shapes are aligned at
     * their trailing dimension. Where only one shape has a dimension, that size is used; where
     * one of the sizes is 1, the larger of the two is used; equal sizes are kept.
     *
     * @param left  the left shape
     * @param right the right shape
     * @return {@code left} itself when both shapes are equal, otherwise a new array whose length
     *         is the larger of the two ranks
     * @throws IllegalArgumentException if two aligned sizes differ and neither of them is 1
     */
    public static int[] broadcastOutputShape(int[] left, int[] right) {
        Shape.assertBroadcastable(left, right);
        if (Arrays.equals(left, right)) {
            return left;
        }
        int outRank = Math.max(left.length, right.length);
        int[] out = new int[outRank];
        int l = left.length - 1;
        int r = right.length - 1;
        // Fill the output from the last dimension backwards, moving both cursors in lockstep.
        for (int pos = outRank - 1; pos >= 0; --pos, --l, --r) {
            if (l < 0) {
                out[pos] = right[r];
                continue;
            }
            if (r < 0) {
                out[pos] = left[l];
                continue;
            }
            int lv = left[l];
            int rv = right[r];
            if (lv == 1 || (lv != rv && rv == 1)) {
                out[pos] = Math.max(lv, rv);
            } else if (lv == rv) {
                out[pos] = lv;
            } else {
                throw new IllegalArgumentException("Unable to broadcast dimension " + pos + " due to shape mismatch. Right shape must be 1.");
            }
        }
        return out;
    }

    /**
     * Computes the shape that results from broadcasting two long shapes against each other.
     * <p>
     * After {@code Shape.assertBroadcastable} has accepted the pair, the shapes are aligned at
     * their trailing dimension. Where only one shape has a dimension, that size is used; where
     * one of the sizes is 1, the larger of the two is used; equal sizes are kept.
     *
     * @param left  the left shape
     * @param right the right shape
     * @return {@code left} itself when both shapes are equal, otherwise a new array whose length
     *         is the larger of the two ranks
     * @throws IllegalArgumentException if two aligned sizes differ and neither of them is 1
     */
    public static long[] broadcastOutputShape(long[] left, long[] right) {
        Shape.assertBroadcastable(left, right);
        if (Arrays.equals(left, right)) {
            return left;
        }
        int outRank = Math.max(left.length, right.length);
        long[] out = new long[outRank];
        int l = left.length - 1;
        int r = right.length - 1;
        // Fill the output from the last dimension backwards, moving both cursors in lockstep.
        for (int pos = outRank - 1; pos >= 0; --pos, --l, --r) {
            if (l < 0) {
                out[pos] = right[r];
                continue;
            }
            if (r < 0) {
                out[pos] = left[l];
                continue;
            }
            long lv = left[l];
            long rv = right[r];
            if (lv == 1L || (lv != rv && rv == 1L)) {
                out[pos] = Math.max(lv, rv);
            } else if (lv == rv) {
                out[pos] = lv;
            } else {
                throw new IllegalArgumentException("Unable to broadcast dimension " + pos + " due to shape mismatch. Right shape must be 1.");
            }
        }
        return out;
    }

    /**
     * Replaces a single negative entry of {@code shape} with the size implied by the total
     * element count of {@code newShape}, then replaces any zero entries with 1.
     * <p>
     * The inferred size is {@code |prod(newShape) / prod(positive entries of shape)|}.
     * When {@code shape} has no negative entry it is not copied: zero entries are overwritten
     * with 1 directly in the caller's array, which is then returned.
     *
     * @param newShape shape whose element count (via {@code ArrayUtil.prod}) is the target size
     * @param shape    shape that may contain one negative entry to be inferred
     * @return a new array with the negative entry resolved, or {@code shape} itself when it
     *         had no negative entry
     * @throws IllegalArgumentException if {@code shape} contains more than one negative entry
     */
    public static int[] resolveNegativeShapeIfNeccessary(int[] newShape, int[] shape) {
        // Locate the negative entry first so that a second one is actually detected;
        // stopping at the first hit would let later negatives through unresolved.
        int negativeIdx = -1;
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] >= 0) {
                continue;
            }
            if (negativeIdx >= 0) {
                throw new IllegalArgumentException("Only one dimension can be negative ones");
            }
            negativeIdx = i;
        }
        if (negativeIdx >= 0) {
            // Product of the known (>= 1) dimensions; zeros and the negative entry are ignored.
            int shapeLength = 1;
            for (int dim : shape) {
                if (dim >= 1) {
                    shapeLength *= dim;
                }
            }
            int realShape = Math.abs(ArrayUtil.prod(newShape) / shapeLength);
            int[] resolved = Arrays.copyOf(shape, shape.length);
            resolved[negativeIdx] = realShape;
            shape = resolved;
        }
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] == 0) {
                shape[i] = 1;
            }
        }
        return shape;
    }

    /**
     * Tells whether the given dimensions cover a whole array of the given int shape.
     * Only the rank ({@code shape.length}) is used; the decision is delegated to
     * {@code isWholeArray(int, int...)}.
     *
     * @param shape     the array shape (must not be {@code null})
     * @param dimension the dimensions to test
     * @return the result of the rank-based check
     */
    public static boolean isWholeArray(int[] shape, int ... dimension) {
        int rank = shape.length;
        return Shape.isWholeArray(rank, dimension);
    }

    /**
     * Tells whether the given dimensions cover a whole array of the given long shape.
     * Only the rank ({@code shape.length}) is used; the decision is delegated to
     * {@code isWholeArray(int, int...)}.
     *
     * @param shape     the array shape (must not be {@code null})
     * @param dimension the dimensions to test
     * @return the result of the rank-based check
     */
    public static boolean isWholeArray(long[] shape, int ... dimension) {
        int rank = shape.length;
        return Shape.isWholeArray(rank, dimension);
    }

    /**
     * Tells whether the given dimensions cover a whole array of the given rank.
     *
     * @param rank      the array rank
     * @param dimension the dimensions to test; may be {@code null}
     * @return {@code true} when the rank is 0, when no dimensions are given ({@code null} or
     *         empty), when the single dimension is {@link Integer#MAX_VALUE}, or when the number
     *         of dimensions equals the rank
     */
    public static boolean isWholeArray(int rank, int ... dimension) {
        if (rank == 0 || dimension == null || dimension.length == 0) {
            return true;
        }
        if (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) {
            return true;
        }
        return dimension.length == rank;
    }

    /**
     * Computes the shape left after reducing an int shape along the given dimensions.
     *
     * @param wholeShape the full shape
     * @param dimensions the dimensions being reduced
     * @return an empty array when {@code Shape.isWholeArray} reports the whole array is covered;
     *         for a rank-2 shape reduced along one dimension, a length-2 shape with that
     *         dimension set to 1 (the result stays all zeros if the dimension is neither 0
     *         nor 1); otherwise {@code wholeShape} with the given indices removed via
     *         {@code ArrayUtil.removeIndex}
     */
    public static long[] getReducedShape(int[] wholeShape, int[] dimensions) {
        if (Shape.isWholeArray(wholeShape, dimensions)) {
            return new long[0];
        }
        boolean singleAxisOnMatrix = dimensions.length == 1 && wholeShape.length == 2;
        if (!singleAxisOnMatrix) {
            int[] remaining = ArrayUtil.removeIndex(wholeShape, dimensions);
            return ArrayUtil.toLongArray(remaining);
        }
        long[] ret = new long[2];
        switch (dimensions[0]) {
            case 1:
                ret[0] = wholeShape[0];
                ret[1] = 1L;
                break;
            case 0:
                ret[0] = 1L;
                ret[1] = wholeShape[1];
                break;
            default:
                // Any other axis value leaves the freshly allocated zeros untouched.
                break;
        }
        return ret;
    }

    /**
     * Computes the shape left after reducing a long shape along the given dimensions.
     *
     * @param wholeShape the full shape
     * @param dimensions the dimensions being reduced
     * @return an empty array when {@code Shape.isWholeArray} reports the whole array is covered;
     *         for a rank-2 shape reduced along one dimension, a length-2 shape with that
     *         dimension set to 1 (the result stays all zeros if the dimension is neither 0
     *         nor 1); otherwise {@code wholeShape} with the given indices removed via
     *         {@code ArrayUtil.removeIndex}
     */
    public static long[] getReducedShape(long[] wholeShape, int[] dimensions) {
        if (Shape.isWholeArray(wholeShape, dimensions)) {
            return new long[0];
        }
        boolean singleAxisOnMatrix = dimensions.length == 1 && wholeShape.length == 2;
        if (!singleAxisOnMatrix) {
            return ArrayUtil.removeIndex(wholeShape, dimensions);
        }
        long[] ret = new long[2];
        switch (dimensions[0]) {
            case 1:
                ret[0] = wholeShape[0];
                ret[1] = 1L;
                break;
            case 0:
                ret[0] = 1L;
                ret[1] = wholeShape[1];
                break;
            default:
                // Any other axis value leaves the freshly allocated zeros untouched.
                break;
        }
        return ret;
    }

    /**
     * Computes the shape produced by reducing an int shape along the given dimensions.
     * The dimensions are first passed through {@code Shape.normalizeAxis}.
     *
     * @param wholeShape the full shape
     * @param dimensions the dimensions being reduced
     * @param keepDims   when {@code true} the rank is preserved and reduced dimensions become 1
     *                   (all dimensions become 1 for a whole-array reduction)
     * @param newFormat  only consulted when {@code keepDims} is {@code false}: when
     *                   {@code false} the two-argument {@code getReducedShape} is used; when
     *                   {@code true} a rank-2 shape reduced along one dimension yields a
     *                   length-1 shape
     * @return the reduced shape as a long array
     */
    public static long[] getReducedShape(int[] wholeShape, int[] dimensions, boolean keepDims, boolean newFormat) {
        int[] axes = Shape.normalizeAxis(wholeShape.length, dimensions);
        if (keepDims) {
            if (Shape.isWholeArray(wholeShape, axes)) {
                long[] ones = new long[wholeShape.length];
                Arrays.fill(ones, 1L);
                return ones;
            }
            long[] kept = ArrayUtil.toLongArray(Arrays.copyOf(wholeShape, wholeShape.length));
            for (int k = 0; k < axes.length; ++k) {
                kept[axes[k]] = 1L;
            }
            return kept;
        }
        if (!newFormat) {
            return Shape.getReducedShape(wholeShape, axes);
        }
        if (Shape.isWholeArray(wholeShape, axes)) {
            return new long[0];
        }
        if (axes.length != 1 || wholeShape.length != 2) {
            return ArrayUtil.toLongArray(ArrayUtil.removeIndex(wholeShape, axes));
        }
        // Matrix reduced along one axis: keep the size of the other axis.
        // An axis other than 0 or 1 leaves the single entry at zero.
        long[] ret = new long[1];
        int axis = axes[0];
        if (axis == 1) {
            ret[0] = wholeShape[0];
        } else if (axis == 0) {
            ret[0] = wholeShape[1];
        }
        return ret;
    }

    /**
     * Convenience overload that computes the reduced shape with {@code newFormat}
     * fixed to {@code true}.
     *
     * @param wholeShape the full shape
     * @param dimensions the dimensions being reduced
     * @param keepDims   whether reduced dimensions are kept with size 1
     * @return the result of the four-argument {@code getReducedShape}
     */
    public static long[] getReducedShape(long[] wholeShape, int[] dimensions, boolean keepDims) {
        final boolean newFormat = true;
        return Shape.getReducedShape(wholeShape, dimensions, keepDims, newFormat);
    }

    /**
     * Computes the shape produced by reducing a long shape along the given dimensions.
     * The dimensions are first passed through {@code Shape.normalizeAxis}.
     *
     * @param wholeShape the full shape
     * @param dimensions the dimensions being reduced
     * @param keepDims   when {@code true} the rank is preserved and reduced dimensions become 1
     *                   (all dimensions become 1 for a whole-array reduction)
     * @param newFormat  only consulted when {@code keepDims} is {@code false}: when
     *                   {@code false} the two-argument {@code getReducedShape} is used; when
     *                   {@code true} a rank-2 shape reduced along one dimension yields a
     *                   length-1 shape
     * @return the reduced shape
     */
    public static long[] getReducedShape(long[] wholeShape, int[] dimensions, boolean keepDims, boolean newFormat) {
        int[] axes = Shape.normalizeAxis(wholeShape.length, dimensions);
        if (keepDims) {
            if (Shape.isWholeArray(wholeShape, axes)) {
                long[] ones = new long[wholeShape.length];
                Arrays.fill(ones, 1L);
                return ones;
            }
            long[] kept = Arrays.copyOf(wholeShape, wholeShape.length);
            for (int k = 0; k < axes.length; ++k) {
                kept[axes[k]] = 1L;
            }
            return kept;
        }
        if (!newFormat) {
            return Shape.getReducedShape(wholeShape, axes);
        }
        if (Shape.isWholeArray(wholeShape, axes)) {
            return new long[0];
        }
        if (axes.length != 1 || wholeShape.length != 2) {
            return ArrayUtil.removeIndex(wholeShape, axes);
        }
        // Matrix reduced along one axis: keep the size of the other axis.
        // An axis other than 0 or 1 leaves the single entry at zero.
        long[] ret = new long[1];
        int axis = axes[0];
        if (axis == 1) {
            ret[0] = wholeShape[0];
        } else if (axis == 0) {
            ret[0] = wholeShape[1];
        }
        return ret;
    }

    /**
     * Computes the output shape of multiplying two arrays with the given int shapes.
     * <p>
     * If either shape is scalar (per {@code Shape.shapeIsScalar}) the other shape is returned
     * unchanged. Otherwise every dimension must be at least 1 and, when the left shape has
     * more than one dimension, its second size must equal the first size of the right shape.
     *
     * @param left  shape of the left operand
     * @param right shape of the right operand
     * @return {@code {1, right[1]}} when the left shape has fewer dimensions than the right one
     *         and their first sizes match, otherwise {@code {left[0], right[1]}}
     * @throws IllegalArgumentException  if neither shape has length 2, or the inner sizes differ
     * @throws ND4JIllegalStateException if any dimension is smaller than 1
     */
    public static int[] getMatrixMultiplyShape(int[] left, int[] right) {
        if (Shape.shapeIsScalar(left)) {
            return right;
        }
        if (Shape.shapeIsScalar(right)) {
            return left;
        }
        if (left.length != 2 && right.length != 2) {
            throw new IllegalArgumentException("Illegal shapes for matrix multiply. Must be of length 2. Left shape: " + Arrays.toString(left) + ", right shape: " + Arrays.toString(right));
        }
        // The guards reject zero as well as negative sizes, so the messages say "< 1".
        for (int i = 0; i < left.length; ++i) {
            if (left[i] < 1) {
                throw new ND4JIllegalStateException("Left shape contained value < 1 at index " + i + " - left shape " + Arrays.toString(left));
            }
        }
        for (int i = 0; i < right.length; ++i) {
            if (right[i] < 1) {
                throw new ND4JIllegalStateException("Right shape contained value < 1 at index " + i + " - right shape " + Arrays.toString(right));
            }
        }
        if (left.length > 1 && left[1] != right[0]) {
            throw new IllegalArgumentException("Columns of left not equal to rows of right: left shape " + Arrays.toString(left) + ", right shape " + Arrays.toString(right));
        }
        if (left.length < right.length && left[0] == right[0]) {
            return new int[]{1, right[1]};
        }
        return new int[]{left[0], right[1]};
    }

    /**
     * Computes the output shape of multiplying two arrays with the given long shapes,
     * including the rank-3 (batch-wise) case.
     * <p>
     * If either shape is scalar (per {@code Shape.shapeIsScalar}) the other shape is returned
     * unchanged. Otherwise every dimension must be at least 1 and the inner sizes must agree:
     * {@code left[1] == right[0]} for a rank-2 left shape, {@code left[2] == right[1]} for a
     * rank-3 left shape.
     *
     * @param left  shape of the left operand
     * @param right shape of the right operand
     * @return {@code {1, right[1]}} when the left shape has fewer dimensions than the right one
     *         and their first sizes match; {@code {left[0], right[1]}} for a rank-2 left shape;
     *         otherwise {@code {left[0], left[1], right[2]}}
     * @throws IllegalArgumentException  if neither shape has length 2 or 3, the inner sizes
     *                                   differ, or the leading sizes differ for a rank-3 left shape
     * @throws ND4JIllegalStateException if any dimension is smaller than 1
     */
    public static long[] getMatrixMultiplyShape(long[] left, long[] right) {
        if (Shape.shapeIsScalar(left)) {
            return right;
        }
        if (Shape.shapeIsScalar(right)) {
            return left;
        }
        if (left.length != 2 && right.length != 2 && left.length != 3 && right.length != 3) {
            throw new IllegalArgumentException("Illegal shapes for matrix multiply. Must be both of length 2 or both of length 3 (batch-wise matrix multiply). Left shape: " + Arrays.toString(left) + ", right shape: " + Arrays.toString(right));
        }
        // The guards reject zero as well as negative sizes, so the messages say "< 1".
        for (int i = 0; i < left.length; ++i) {
            if (left[i] < 1L) {
                throw new ND4JIllegalStateException("Left shape contained value < 1 at index " + i + " - left shape " + Arrays.toString(left));
            }
        }
        for (int i = 0; i < right.length; ++i) {
            if (right[i] < 1L) {
                throw new ND4JIllegalStateException("Right shape contained value < 1 at index " + i + " - right shape " + Arrays.toString(right));
            }
        }
        if (left.length == 2 && left[1] != right[0] || left.length == 3 && left[2] != right[1]) {
            throw new IllegalArgumentException("Columns of left not equal to rows of right: left shape " + Arrays.toString(left) + ", right shape " + Arrays.toString(right));
        }
        if (left.length < right.length && left[0] == right[0]) {
            return new long[]{1L, right[1]};
        }
        if (left.length == 3 && left[0] != right[0]) {
            throw new IllegalArgumentException("For batch matrix multiplication the leading dimension of both arguments has to match, got left leading dimension " + left[0] + " and right " + right[0]);
        }
        if (left.length == 2) {
            return new long[]{left[0], right[1]};
        }
        return new long[]{left[0], left[1], right[2]};
    }

    /**
     * Returns the given array unchanged when it meets the conditions checked below,
     * otherwise returns a newly created array with the same contents.
     * <p>
     * The input is returned as-is when its offset is below 1, its backing buffer length equals
     * its own length, and either it is 'f'-ordered with {@code stride(-1) != 1} or 'c'-ordered
     * with {@code stride(0) != 1}. Row vectors are copied element by element into an array
     * created from the shape alone; all other arrays are copied with {@code assign} into an
     * array created with the same shape and ordering.
     *
     * @param arr the array to inspect
     * @return {@code arr} itself or a copy as described above
     */
    public static INDArray toOffsetZero(INDArray arr) {
        if (arr.offset() < 1L && arr.data().length() == arr.length()) {
            boolean fOrderNonUnit = arr.ordering() == 'f' && arr.stride(-1) != 1;
            if (fOrderNonUnit || (arr.ordering() == 'c' && arr.stride(0) != 1)) {
                return arr;
            }
        }
        if (!arr.isRowVector()) {
            INDArray copy = Nd4j.create(arr.shape(), arr.ordering());
            copy.assign(arr);
            return copy;
        }
        INDArray copy = Nd4j.create(arr.shape());
        for (int pos = 0; (long)pos < copy.length(); ++pos) {
            copy.putScalar((long)pos, arr.getDouble((long)pos));
        }
        return copy;
    }

    /**
     * Copies the array through {@code toOffsetZeroCopyHelper}, requesting the ordering
     * reported by {@code Nd4j.order()} and not preserving the input's own ordering.
     *
     * @param arr the array to copy
     * @return the helper's result
     */
    public static INDArray toOffsetZeroCopy(INDArray arr) {
        char requestedOrder = Nd4j.order().charValue();
        return Shape.toOffsetZeroCopyHelper(arr, requestedOrder, false);
    }

    /**
     * Copies the array through {@code toOffsetZeroCopyHelper} using the requested ordering.
     *
     * @param arr   the array to copy
     * @param order the ordering character for the copy
     * @return the helper's result
     */
    public static INDArray toOffsetZeroCopy(INDArray arr, char order) {
        final boolean keepInputOrder = false;
        return Shape.toOffsetZeroCopyHelper(arr, order, keepInputOrder);
    }

    /**
     * Copies the array through {@code toOffsetZeroCopyHelper} with the "any order" flag set,
     * so the helper uses the input array's own ordering for the copy.
     *
     * @param arr the array to copy
     * @return the helper's result
     */
    public static INDArray toOffsetZeroCopyAnyOrder(INDArray arr) {
        char fallbackOrder = Nd4j.order().charValue();
        return Shape.toOffsetZeroCopyHelper(arr, fallbackOrder, true);
    }

    /**
     * Creates an uninitialized array with the input's data type and shape and assigns the
     * input into it. Empty arrays are returned unchanged.
     *
     * @param arr      the array to copy
     * @param order    the ordering to use when {@code anyOrder} is {@code false}
     * @param anyOrder when {@code true} the input's own ordering is used instead of {@code order}
     * @return {@code arr} itself when it is empty, otherwise the new copy
     */
    private static INDArray toOffsetZeroCopyHelper(INDArray arr, char order, boolean anyOrder) {
        if (arr.isEmpty()) {
            return arr;
        }
        char targetOrder;
        if (anyOrder) {
            targetOrder = arr.ordering();
        } else {
            targetOrder = order;
        }
        // An ordering of 'a' is replaced by whatever Nd4j.order() reports.
        if (targetOrder == 'a') {
            targetOrder = Nd4j.order().charValue();
        }
        INDArray copy = Nd4j.createUninitialized(arr.dataType(), arr.shape(), targetOrder);
        copy.assign(arr);
        return copy;
    }

    /**
     * Reads a double from the array's backing buffer at the position addressed by int indices.
     * The indices are widened to long and turned into a buffer offset with
     * {@code Shape.getOffset} applied to the array's shape info.
     *
     * @param arr     the array to read from
     * @param indices one index per dimension
     * @return the value stored at the computed offset
     */
    public static double getDouble(INDArray arr, int[] indices) {
        long bufferPos = Shape.getOffset(arr.shapeInfo(), ArrayUtil.toLongArray(indices));
        return arr.data().getDouble(bufferPos);
    }

    /**
     * Reads a double from the array's backing buffer at the position addressed by long indices.
     * The buffer offset is computed with {@code Shape.getOffset} applied to the array's
     * shape info.
     *
     * @param arr     the array to read from
     * @param indices one index per dimension
     * @return the value stored at the computed offset
     */
    public static double getDouble(INDArray arr, long ... indices) {
        long bufferPos = Shape.getOffset(arr.shapeInfo(), indices);
        return arr.data().getDouble(bufferPos);
    }

    /**
     * Reads a long from the array's backing buffer at the position addressed by the indices.
     * The buffer offset is computed with {@code Shape.getOffset} applied to the array's
     * shape info.
     *
     * @param arr     the array to read from
     * @param indices one index per dimension
     * @return the value stored at the computed offset
     */
    public static long getLong(INDArray arr, long ... indices) {
        long bufferPos = Shape.getOffset(arr.shapeInfo(), indices);
        return arr.data().getLong(bufferPos);
    }

    /**
     * Visits every coordinate of the array's shape, starting the recursive
     * {@code iterate} at dimension 0 with a fresh all-zero coordinate buffer.
     *
     * @param arr                the array whose rank and shape define the index space
     * @param coordinateFunction callback invoked by the recursive overload
     */
    public static void iterate(INDArray arr, CoordinateFunction coordinateFunction) {
        int rank = arr.rank();
        long[] shape = arr.shape();
        long[] coords = new long[rank];
        Shape.iterate(0, rank, shape, coords, coordinateFunction);
    }

    /**
     * Visits coordinate pairs of two arrays, starting the recursive two-space
     * {@code iterate} at dimension 0 for both with fresh all-zero coordinate buffers.
     *
     * @param arr                the first array (rank and shape define the first index space)
     * @param arr2               the second array (rank and shape define the second index space)
     * @param coordinateFunction callback invoked by the recursive overload
     */
    public static void iterate(INDArray arr, INDArray arr2, CoordinateFunction coordinateFunction) {
        int firstRank = arr.rank();
        long[] firstShape = arr.shape();
        long[] firstCoords = new long[firstRank];
        int secondRank = arr2.rank();
        long[] secondShape = arr2.shape();
        long[] secondCoords = new long[secondRank];
        Shape.iterate(0, firstRank, firstShape, firstCoords, 0, secondRank, secondShape, secondCoords, coordinateFunction);
    }

    /**
     * Recursively enumerates coordinate pairs over two int index spaces.
     * <p>
     * At each level the current dimension of the first space and the current dimension of the
     * second space are looped over together (every {@code i} paired with every {@code j}),
     * the values are written into {@code res} and {@code res2}, and the method recurses with
     * both dimension counters advanced by one. Once either counter reaches its limit the
     * callback receives long copies of both coordinate buffers.
     *
     * @param dimension  current dimension in the first space
     * @param n          dimension limit for the first space
     * @param size       sizes of the first space
     * @param res        coordinate buffer of the first space; overwritten in place
     * @param dimension2 current dimension in the second space
     * @param n2         dimension limit for the second space
     * @param size2      sizes of the second space
     * @param res2       coordinate buffer of the second space; overwritten in place
     * @param func       callback invoked with the two coordinate arrays
     */
    public static void iterate(int dimension, int n, int[] size, int[] res, int dimension2, int n2, int[] size2, int[] res2, CoordinateFunction func) {
        // Base case: one of the spaces has no dimensions left, so report the current pair.
        if (dimension >= n || dimension2 >= n2) {
            func.process(ArrayUtil.toLongArray((int[])res), ArrayUtil.toLongArray((int[])res2));
            return;
        }
        if (size2.length != size.length) {
            // Ranks differ: stop silently if the first space ran out of sizes.
            if (dimension >= size.length) {
                return;
            }
            // The bound on dimension2 is checked in the outer loop, before size2 is indexed.
            for (int i = 0; i < size[dimension] && dimension2 < size2.length; ++i) {
                int j = 0;
                while (j < size2[dimension2]) {
                    res[dimension] = i;
                    res2[dimension2] = j++;
                    Shape.iterate(dimension + 1, n, size, res, dimension2 + 1, n2, size2, res2, func);
                }
            }
        } else {
            // Ranks match: same stop condition as above.
            if (dimension >= size.length) {
                return;
            }
            for (int i = 0; i < size[dimension]; ++i) {
                int j = 0;
                // NOTE(review): size2[dimension2] is read before the bound check on dimension2,
                // so the check cannot prevent an out-of-range access here - confirm intent.
                while (j < size2[dimension2] && dimension2 < size2.length) {
                    res[dimension] = i;
                    res2[dimension2] = j++;
                    Shape.iterate(dimension + 1, n, size, res, dimension2 + 1, n2, size2, res2, func);
                }
            }
        }
    }

    /**
     * Recursively enumerates coordinate pairs over two long index spaces.
     * <p>
     * At each level the current dimension of the first space and the current dimension of the
     * second space are looped over together (every {@code i} paired with every {@code j}),
     * the values are written into {@code res} and {@code res2}, and the method recurses with
     * both dimension counters advanced by one. Once either counter reaches its limit the
     * callback receives the two coordinate buffers themselves (not copies).
     *
     * @param dimension  current dimension in the first space
     * @param n          dimension limit for the first space
     * @param size       sizes of the first space
     * @param res        coordinate buffer of the first space; overwritten in place
     * @param dimension2 current dimension in the second space
     * @param n2         dimension limit for the second space
     * @param size2      sizes of the second space
     * @param res2       coordinate buffer of the second space; overwritten in place
     * @param func       callback invoked with the two coordinate arrays
     */
    public static void iterate(int dimension, int n, long[] size, long[] res, int dimension2, int n2, long[] size2, long[] res2, CoordinateFunction func) {
        // Base case: one of the spaces has no dimensions left, so report the current pair.
        if (dimension >= n || dimension2 >= n2) {
            func.process(res, res2);
            return;
        }
        if (size2.length != size.length) {
            // Ranks differ: stop silently if the first space ran out of sizes.
            if (dimension >= size.length) {
                return;
            }
            int i = 0;
            // The bound on dimension2 is checked in the outer loop, before size2 is indexed.
            while ((long)i < size[dimension] && dimension2 < size2.length) {
                int j = 0;
                while ((long)j < size2[dimension2]) {
                    res[dimension] = i;
                    res2[dimension2] = j;
                    Shape.iterate(dimension + 1, n, size, res, dimension2 + 1, n2, size2, res2, func);
                    ++j;
                }
                ++i;
            }
        } else {
            // Ranks match: same stop condition as above.
            if (dimension >= size.length) {
                return;
            }
            int i = 0;
            while ((long)i < size[dimension]) {
                int j = 0;
                // NOTE(review): size2[dimension2] is read before the bound check on dimension2,
                // so the check cannot prevent an out-of-range access here - confirm intent.
                while ((long)j < size2[dimension2] && dimension2 < size2.length) {
                    res[dimension] = i;
                    res2[dimension2] = j;
                    Shape.iterate(dimension + 1, n, size, res, dimension2 + 1, n2, size2, res2, func);
                    ++j;
                }
                ++i;
            }
        }
    }

    /**
     * Enumerates coordinates of an int index space, starting at {@code dimension}.
     * <p>
     * For each index of the current dimension the value is written into {@code res} and the
     * remaining dimensions are handled by the long overload, which receives long copies of
     * {@code size} and {@code res}. When {@code dimension >= n} the callback is invoked once
     * with a long copy of {@code res}.
     *
     * @param dimension current dimension
     * @param n         dimension limit
     * @param size      size of each dimension
     * @param res       coordinate buffer; the entry for {@code dimension} is overwritten
     * @param func      callback receiving the coordinates
     */
    public static void iterate(int dimension, int n, int[] size, int[] res, CoordinateFunction func) {
        if (dimension < n) {
            for (int idx = 0; idx < size[dimension]; ++idx) {
                res[dimension] = idx;
                Shape.iterate(dimension + 1, n, ArrayUtil.toLongArray(size), ArrayUtil.toLongArray(res), func);
            }
            return;
        }
        func.process(new long[][]{ArrayUtil.toLongArray(res)});
    }

    /**
     * Recursively enumerates coordinates of a long index space, starting at {@code dimension}.
     * <p>
     * For each index of the current dimension the value is written into {@code res} and the
     * method recurses into the next dimension. When {@code dimension >= n} the callback is
     * invoked with {@code res} itself, so the same buffer instance is handed out on every call.
     *
     * @param dimension current dimension
     * @param n         dimension limit
     * @param size      size of each dimension
     * @param res       coordinate buffer; overwritten in place
     * @param func      callback receiving the coordinates
     */
    public static void iterate(int dimension, int n, long[] size, long[] res, CoordinateFunction func) {
        if (dimension < n) {
            for (int idx = 0; (long)idx < size[dimension]; ++idx) {
                res[dimension] = idx;
                Shape.iterate(dimension + 1, n, size, res, func);
            }
            return;
        }
        func.process(new long[][]{res});
    }

    /**
     * Computes a linear buffer offset from a base offset, a shape, strides and indices.
     * <p>
     * The result is {@code baseOffset + sum(indices[i] * stride[i])}, where dimensions of
     * size 1 contribute nothing.
     *
     * @param baseOffset offset to start from
     * @param shape      size of each dimension
     * @param stride     stride of each dimension
     * @param indices    one index per dimension
     * @return the computed offset
     * @throws IllegalArgumentException if the three arrays differ in length, or an index is
     *                                  greater than or equal to the size of its dimension
     */
    public static long getOffset(long baseOffset, int[] shape, int[] stride, int ... indices) {
        if (shape.length != stride.length || indices.length != shape.length) {
            throw new IllegalArgumentException("Indexes, shape, and stride must be the same length");
        }
        long offset = baseOffset;
        for (int i = 0; i < shape.length; ++i) {
            if (indices[i] >= shape[i]) {
                throw new IllegalArgumentException(String.format("J: Index [%d] must not be >= shape[%d]=%d.", i, i, shape[i]));
            }
            if (shape[i] == 1) {
                continue;
            }
            // Widen before multiplying: index * stride can exceed the int range.
            offset += (long) indices[i] * stride[i];
        }
        return offset;
    }

    /**
     * Computes the buffer offset for int indices by widening them to long and delegating to
     * the long-index overload for {@link IntBuffer} shape information.
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the offset computed by the delegate
     */
    public static long getOffset(IntBuffer shapeInformation, int[] indices) {
        long[] widened = ArrayUtil.toLongArray(indices);
        return Shape.getOffset(shapeInformation, widened);
    }

    /**
     * Computes the buffer offset for int indices by widening them to long and delegating to
     * the long-index overload for {@link LongBuffer} shape information.
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the offset computed by the delegate
     */
    public static long getOffset(LongBuffer shapeInformation, int[] indices) {
        long[] widened = ArrayUtil.toLongArray(indices);
        return Shape.getOffset(shapeInformation, widened);
    }

    /**
     * Computes the buffer offset addressed by the indices, using rank, sizes and strides read
     * from a {@link LongBuffer} through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}. Dimensions of size 1 contribute nothing.
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the sum of {@code indices[i] * stride(i)} over all dimensions whose size is not 1
     * @throws IllegalStateException presumably raised by {@code Preconditions.checkState} when
     *                               the number of indices differs from the rank - TODO confirm
     */
    public static long getOffset(LongBuffer shapeInformation, long ... indices) {
        int rank = Shape.rank(shapeInformation);
        Preconditions.checkState(indices.length == rank, "Number of indices (got %s) must be same as array rank (%s) - indices %s", (Object)indices.length, (Object)rank, (Object)indices);
        long offset = 0L;
        for (int i = 0; i < rank; ++i) {
            // Keep the size in 64 bits: narrowing it to int could turn a huge dimension
            // into a spurious 1 and silently drop its contribution.
            long sizeAtDim = Shape.size(shapeInformation, i);
            if (sizeAtDim == 1L) {
                continue;
            }
            offset += indices[i] * Shape.stride(shapeInformation, i);
        }
        return offset;
    }

    /**
     * Computes the buffer offset addressed by the indices, using rank, sizes and strides read
     * from an {@link IntBuffer} through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}. Dimensions of size 1 contribute nothing.
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the sum of {@code indices[i] * stride(i)} over all dimensions whose size is not 1
     * @throws IllegalArgumentException if the number of indices differs from the rank
     */
    public static long getOffset(IntBuffer shapeInformation, long ... indices) {
        int rank = Shape.rank(shapeInformation);
        if (rank != indices.length) {
            throw new IllegalArgumentException("Indexes must be same length as array rank");
        }
        long total = 0L;
        int dim = 0;
        while (dim < rank) {
            if (Shape.size(shapeInformation, dim) != 1) {
                total += indices[dim] * (long)Shape.stride(shapeInformation, dim);
            }
            ++dim;
        }
        return total;
    }

    /**
     * Computes the buffer offset for int indices by widening them to long and delegating to
     * the long-index overload for {@code DataBuffer} shape information.
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the offset computed by the delegate
     */
    public static long getOffset(DataBuffer shapeInformation, int[] indices) {
        long[] widened = ArrayUtil.toLongArray(indices);
        return Shape.getOffset(shapeInformation, widened);
    }

    /**
     * Computes the buffer offset addressed by the indices, using rank, sizes and strides read
     * from a {@code DataBuffer} through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}. Dimensions of size 1 contribute nothing.
     * <p>
     * Note that the bound check only rejects an index that is strictly greater than the
     * dimension size; an index equal to the size passes, although the message reads "&gt;=".
     *
     * @param shapeInformation the shape information buffer
     * @param indices          one index per dimension
     * @return the sum of {@code indices[i] * stride(i)} over all dimensions whose size is not 1
     * @throws IllegalArgumentException if the number of indices differs from the rank, or an
     *                                  index exceeds the size of its dimension
     */
    public static long getOffset(DataBuffer shapeInformation, long ... indices) {
        int rank = Shape.rank(shapeInformation);
        if (rank != indices.length) {
            throw new IllegalArgumentException("Indexes must be same length as array rank");
        }
        long total = 0L;
        int dim = 0;
        while (dim < rank) {
            int dimSize = Shape.size(shapeInformation, dim);
            if (indices[dim] > (long)dimSize) {
                throw new IllegalArgumentException(String.format("J: Index [%d] must not be >= shape[%d]=%d.", dim, dim, dimSize));
            }
            if (dimSize != 1) {
                total += indices[dim] * (long)Shape.stride(shapeInformation, dim);
            }
            ++dim;
        }
        return total;
    }

    /**
     * Computes the buffer offset addressed by the indices, using rank, sizes and strides read
     * from an int shape-information array through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}.
     * <p>
     * Only the first {@code min(rank, indices.length)} dimensions are considered, and
     * dimensions of size 1 contribute nothing. The bound check only rejects an index that is
     * strictly greater than the dimension size; an index equal to the size passes.
     *
     * @param shapeInformation the shape information array
     * @param indices          indices, one per dimension
     * @return the sum of {@code indices[i] * stride(i)} over the considered dimensions
     * @throws IllegalArgumentException if an index exceeds the size of its dimension
     */
    public static long getOffset(int[] shapeInformation, int ... indices) {
        int rank = Shape.rank(shapeInformation);
        long offset = 0L;
        for (int i = 0; i < Math.min(rank, indices.length); ++i) {
            int size_dimi = Shape.size(shapeInformation, i);
            if (indices[i] > size_dimi) {
                throw new IllegalArgumentException(String.format("J: Index [%d] must not be >= shape[%d]=%d.", i, i, size_dimi));
            }
            if (size_dimi == 1) {
                continue;
            }
            // Widen before multiplying: index * stride can exceed the int range.
            offset += (long) indices[i] * Shape.stride(shapeInformation, i);
        }
        return offset;
    }

    /**
     * Computes the buffer offset addressed by int indices, using rank, sizes and strides read
     * from a long shape-information array through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}.
     * <p>
     * Only the first {@code min(rank, indices.length)} dimensions are considered, and
     * dimensions of size 1 contribute nothing. The bound check only rejects an index that is
     * strictly greater than the dimension size; an index equal to the size passes.
     *
     * @param shapeInformation the shape information array
     * @param indices          indices, one per dimension
     * @return the sum of {@code indices[i] * stride(i)} over the considered dimensions
     * @throws IllegalArgumentException if an index exceeds the size of its dimension
     */
    public static long getOffset(long[] shapeInformation, int ... indices) {
        int rank = Shape.rank(shapeInformation);
        int considered = Math.min(rank, indices.length);
        long total = 0L;
        for (int dim = 0; dim < considered; ++dim) {
            long dimSize = Shape.size(shapeInformation, dim);
            if ((long)indices[dim] > dimSize) {
                throw new IllegalArgumentException(String.format("J: Index [%d] must not be >= shape[%d]=%d.", dim, dim, dimSize));
            }
            if (dimSize != 1L) {
                total += (long)indices[dim] * Shape.stride(shapeInformation, dim);
            }
        }
        return total;
    }

    /**
     * Computes the buffer offset addressed by long indices, using rank, sizes and strides read
     * from a long shape-information array through {@code Shape.rank}, {@code Shape.size} and
     * {@code Shape.stride}.
     * <p>
     * Only the first {@code min(rank, indices.length)} dimensions are considered, and
     * dimensions of size 1 contribute nothing. The bound check only rejects an index that is
     * strictly greater than the dimension size; an index equal to the size passes.
     *
     * @param shapeInformation the shape information array
     * @param indices          indices, one per dimension
     * @return the sum of {@code indices[i] * stride(i)} over the considered dimensions
     * @throws IllegalArgumentException if an index exceeds the size of its dimension
     */
    public static long getOffset(long[] shapeInformation, long ... indices) {
        int rank = Shape.rank(shapeInformation);
        int considered = Math.min(rank, indices.length);
        long total = 0L;
        for (int dim = 0; dim < considered; ++dim) {
            long dimSize = Shape.size(shapeInformation, dim);
            if (indices[dim] > dimSize) {
                throw new IllegalArgumentException(String.format("J: Index [%d] must not be >= shape[%d]=%d.", dim, dim, dimSize));
            }
            if (dimSize != 1L) {
                total += indices[dim] * Shape.stride(shapeInformation, dim);
            }
        }
        return total;
    }

    /**
     * Computes the buffer offset of element {@code [row, col]} after verifying that the shape
     * information describes a rank-2 array; the computation itself is done by
     * {@code getOffsetUnsafe}.
     *
     * @param shapeInformation the shape information buffer
     * @param row              index along dimension 0
     * @param col              index along dimension 1
     * @return the offset computed by {@code getOffsetUnsafe}
     * @throws IllegalArgumentException if the rank is not 2
     */
    public static long getOffset(DataBuffer shapeInformation, int row, int col) {
        int rank = Shape.rank(shapeInformation);
        if (rank == 2) {
            return Shape.getOffsetUnsafe(shapeInformation, row, col);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 2 (rank is: " + rank + ")");
    }

    /**
     * Computes the buffer offset of element {@code [row, col]} without checking the rank.
     * Sizes and strides are read with {@code Shape.sizeUnsafe} and {@code Shape.strideUnsafe}
     * (the latter called with a rank argument of 2). Dimensions of size 1 contribute nothing.
     *
     * @param shapeInformation the shape information buffer
     * @param row              index along dimension 0
     * @param col              index along dimension 1
     * @return {@code row * stride0 + col * stride1}, skipping dimensions of size 1
     * @throws IllegalArgumentException if {@code row} or {@code col} is not below the
     *                                  corresponding size
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int row, int col) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        if (row >= size_0 || col >= size_1) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying: index * stride can exceed the int range.
        if (size_0 != 1) {
            offset += (long) row * Shape.strideUnsafe(shapeInformation, 0, 2);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.strideUnsafe(shapeInformation, 1, 2);
        }
        return offset;
    }

    /**
     * Computes the buffer offset of element {@code [row, col]} without checking the rank.
     * Sizes and strides are read with {@code Shape.sizeUnsafe} and {@code Shape.strideUnsafe}
     * (the latter called with a rank argument of 2). Dimensions of size 1 contribute nothing.
     * <p>
     * The column bound is only enforced when the shape is neither a vector
     * ({@code Shape.isVector}) nor a scalar ({@code Shape.shapeIsScalar}).
     *
     * @param shapeInformation the shape information array
     * @param row              index along dimension 0
     * @param col              index along dimension 1
     * @return {@code row * stride0 + col * stride1}, skipping dimensions of size 1
     * @throws IllegalArgumentException if {@code row} is out of range, or {@code col} is out
     *                                  of range for a shape that is neither vector nor scalar
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int row, int col) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        if (row >= size_0 || (col >= size_1 && !Shape.isVector(Shape.shape(shapeInformation)) && !Shape.shapeIsScalar(Shape.shape(shapeInformation)))) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying: index * stride can exceed the int range.
        if (size_0 != 1) {
            offset += (long) row * Shape.strideUnsafe(shapeInformation, 0, 2);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.strideUnsafe(shapeInformation, 1, 2);
        }
        return offset;
    }

    /**
     * Computes the buffer offset of element {@code [row, col]} without checking the rank.
     * Sizes and strides are read with {@code Shape.sizeUnsafe} and {@code Shape.strideUnsafe}
     * (the latter called with a rank argument of 2). Dimensions of size 1 contribute nothing.
     * <p>
     * The column bound is only enforced when the shape is neither a vector
     * ({@code Shape.isVector}) nor a scalar ({@code Shape.shapeIsScalar}).
     *
     * @param shapeInformation the shape information array
     * @param row              index along dimension 0
     * @param col              index along dimension 1
     * @return {@code row * stride0 + col * stride1}, skipping dimensions of size 1
     * @throws IllegalArgumentException if {@code row} is out of range, or {@code col} is out
     *                                  of range for a shape that is neither vector nor scalar
     */
    public static long getOffsetUnsafe(long[] shapeInformation, long row, long col) {
        long rows = Shape.sizeUnsafe(shapeInformation, 0);
        long cols = Shape.sizeUnsafe(shapeInformation, 1);
        if (row >= rows || (col >= cols && !Shape.isVector(Shape.shape(shapeInformation)) && !Shape.shapeIsScalar(Shape.shape(shapeInformation)))) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long total = 0L;
        if (rows != 1L) {
            total += row * Shape.strideUnsafe(shapeInformation, 0, 2);
        }
        if (cols != 1L) {
            total += col * Shape.strideUnsafe(shapeInformation, 1, 2);
        }
        return total;
    }

    /**
     * Computes the buffer offset of element {@code [row, col]} for rank-2 shape information
     * held in an {@link IntBuffer}. Sizes and strides are read with {@code Shape.size} and
     * {@code Shape.stride}; dimensions of size 1 contribute nothing.
     *
     * @param shapeInformation the shape information buffer
     * @param row              index along dimension 0
     * @param col              index along dimension 1
     * @return {@code row * stride0 + col * stride1}, skipping dimensions of size 1
     * @throws IllegalArgumentException if the rank is not 2, or {@code row} or {@code col} is
     *                                  not below the corresponding size
     */
    public static long getOffset(IntBuffer shapeInformation, int row, int col) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 2) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 2 (rank is: " + rank + ")");
        }
        long offset = 0L;
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        if (row >= size_0 || col >= size_1) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying: index * stride can exceed the int range.
        if (size_0 != 1) {
            offset += (long) row * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.stride(shapeInformation, 1);
        }
        return offset;
    }

    /**
     * Computes the buffer offset of element {@code [dim0, dim1, dim2]} for rank-3 shape
     * information held in an {@link IntBuffer}. Sizes and strides are read with
     * {@code Shape.size} and {@code Shape.stride}; dimensions of size 1 contribute nothing.
     *
     * @param shapeInformation the shape information buffer
     * @param dim0             index along dimension 0
     * @param dim1             index along dimension 1
     * @param dim2             index along dimension 2
     * @return the sum of index times stride over all dimensions whose size is not 1
     * @throws IllegalArgumentException if the rank is not 3, or any index is not below the
     *                                  corresponding size
     */
    public static long getOffset(IntBuffer shapeInformation, int dim0, int dim1, int dim2) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 3) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 3 (rank is: " + rank + ")");
        }
        long offset = 0L;
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        int size_2 = Shape.size(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying: index * stride can exceed the int range.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.stride(shapeInformation, 1);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.stride(shapeInformation, 2);
        }
        return offset;
    }

    /**
     * Computes the buffer offset of element {@code [dim0, dim1, dim2]} after verifying that
     * the shape information describes a rank-3 array; the computation itself is done by
     * {@code getOffsetUnsafe}.
     *
     * @param shapeInformation the shape information buffer
     * @param dim0             index along dimension 0
     * @param dim1             index along dimension 1
     * @param dim2             index along dimension 2
     * @return the offset computed by {@code getOffsetUnsafe}
     * @throws IllegalArgumentException if the rank is not 3
     */
    public static long getOffset(DataBuffer shapeInformation, int dim0, int dim1, int dim2) {
        int rank = Shape.rank(shapeInformation);
        if (rank == 3) {
            return Shape.getOffsetUnsafe(shapeInformation, dim0, dim1, dim2);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 3 (rank is: " + rank + ")");
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2]} without verifying the rank.
     * <p>
     * The caller is responsible for passing shape information of a rank 3 array; the strides are
     * read with a hard-coded rank of 3. Dimensions of size 1 do not contribute to the offset.
     *
     * @param shapeInformation shape information buffer describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return sum of {@code index * stride} over every dimension whose size is not 1
     * @throws IllegalArgumentException if any index is >= the size of its dimension
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int dim0, int dim1, int dim2) {
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long offset = 0L;
        // Widen the index BEFORE multiplying so the product is computed in 64 bits
        // instead of wrapping around as an int.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 3);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 3);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 3);
        }
        return offset;
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2]} without verifying the rank.
     * <p>
     * The strides are read with a hard-coded rank of 3. Dimensions of size 1 do not contribute
     * to the offset. The offset is accumulated as a {@code long} so large arrays do not overflow.
     *
     * @param shapeInformation shape information array describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return sum of {@code index * stride} over every dimension whose size is not 1
     * @throws IllegalArgumentException if any index is >= the size of its dimension
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int dim0, int dim1, int dim2) {
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(shapeInformation) + " NDArray");
        }
        // The method returns long, so accumulate in long and widen each index before
        // multiplying; an int accumulator would silently wrap for large arrays.
        long offset = 0L;
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 3);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 3);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 3);
        }
        return offset;
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2, dim3]} of a rank 4 array.
     * <p>
     * Dimensions whose size is 1 contribute nothing to the offset, so their stride is ignored.
     *
     * @param shapeInformation shape information buffer describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return sum of {@code index * stride} over every dimension whose size is not 1
     * @throws IllegalArgumentException if the rank is not 4, or any index is >= the size of its dimension
     */
    public static long getOffset(IntBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 4) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 4 (rank is: " + rank + ")");
        }
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        int size_2 = Shape.size(shapeInformation, 2);
        int size_3 = Shape.size(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long offset = 0L;
        // Widen the index BEFORE multiplying: an int * int product can wrap around
        // before it ever reaches the long accumulator.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.stride(shapeInformation, 1);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.stride(shapeInformation, 2);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.stride(shapeInformation, 3);
        }
        return offset;
    }

    /**
     * Rank-checked offset lookup for a rank 4 array described by a {@link DataBuffer}.
     *
     * @param shapeInformation shape information buffer describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return the value produced by {@code getOffsetUnsafe} for the same arguments
     * @throws IllegalArgumentException if the rank stored in the buffer is not 4
     */
    public static long getOffset(DataBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        final int actualRank = Shape.rank(shapeInformation);
        if (actualRank == 4) {
            return Shape.getOffsetUnsafe(shapeInformation, dim0, dim1, dim2, dim3);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 4 (rank is: " + actualRank + ")");
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2, dim3]} without verifying the rank.
     * <p>
     * The strides are read with a hard-coded rank of 4. Dimensions of size 1 do not contribute
     * to the offset.
     *
     * @param shapeInformation shape information buffer describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return sum of {@code index * stride} over every dimension whose size is not 1
     * @throws IllegalArgumentException if any index is >= the size of its dimension
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        int size_3 = Shape.sizeUnsafe(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long offset = 0L;
        // Widen the index BEFORE multiplying so the product is computed in 64 bits
        // instead of wrapping around as an int.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 4);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 4);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 4);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.strideUnsafe(shapeInformation, 3, 4);
        }
        return offset;
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2, dim3]} without verifying the rank.
     * <p>
     * The strides are read with a hard-coded rank of 4. Dimensions of size 1 do not contribute
     * to the offset.
     *
     * @param shapeInformation shape information array describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return sum of {@code index * stride} over every dimension whose size is not 1
     * @throws IllegalArgumentException if any index is >= the size of its dimension
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        int size_3 = Shape.sizeUnsafe(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long offset = 0L;
        // Widen the index BEFORE multiplying so the product is computed in 64 bits
        // instead of wrapping around as an int.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 4);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 4);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 4);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.strideUnsafe(shapeInformation, 3, 4);
        }
        return offset;
    }

    /**
     * Computes the linear buffer offset of element {@code [dim0, dim1, dim2, dim3]} without verifying the rank.
     * <p>
     * All four sizes are read first and every index is checked against its size; afterwards
     * each dimension whose size is not 1 adds {@code index * stride} (stride read with rank 4).
     *
     * @param shapeInformation shape information array describing the array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return the accumulated offset
     * @throws IllegalArgumentException if any index is >= the size of its dimension
     */
    public static long getOffsetUnsafe(long[] shapeInformation, long dim0, long dim1, long dim2, long dim3) {
        final long[] indices = {dim0, dim1, dim2, dim3};
        final long[] extents = new long[indices.length];
        for (int d = 0; d < extents.length; d++) {
            extents[d] = Shape.sizeUnsafe(shapeInformation, d);
        }
        boolean outOfRange = false;
        for (int d = 0; d < indices.length && !outOfRange; d++) {
            outOfRange = indices[d] >= extents[d];
        }
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        long result = 0L;
        for (int d = 0; d < indices.length; d++) {
            // Singleton dimensions never move the offset, so their stride is not even read.
            if (extents[d] != 1L) {
                result += indices[d] * Shape.strideUnsafe(shapeInformation, d, 4);
            }
        }
        return result;
    }

    /**
     * Looks up the size of each requested axis.
     * <p>
     * The returned array has {@code shape.length} entries (not {@code axes.length}); entry
     * {@code i} holds {@code shape[axes[i]]} and any entries beyond {@code axes.length} stay 0.
     *
     * @param axes the axes to look up, in output order
     * @param shape the shape to read the sizes from
     * @return array of length {@code shape.length} holding the sizes of the requested axes
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        final int[] sizes = new int[shape.length];
        int slot = 0;
        for (int axis : axes) {
            sizes[slot++] = shape[axis];
        }
        return sizes;
    }

    /**
     * Tells whether the shape information describes a vector.
     * <p>
     * Only rank 1 and rank 2 can qualify; the array is a vector when its first or second
     * shape entry equals the total element count.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} if one of the first two shape entries equals the array length
     */
    public static boolean isVector(IntBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        if (rank < 1 || rank > 2) {
            return false;
        }
        final int total = Shape.length(shapeInfo);
        final IntBuffer dims = Shape.shapeOf(shapeInfo);
        if (dims.get(0) == total) {
            return true;
        }
        return dims.get(1) == total;
    }

    /**
     * Tells whether the shape information describes a vector.
     * <p>
     * Only rank 1 and rank 2 can qualify; the array is a vector when its first or second
     * shape entry equals the total element count.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} if one of the first two shape entries equals the array length
     */
    public static boolean isVector(DataBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        if (rank < 1 || rank > 2) {
            return false;
        }
        final long total = Shape.length(shapeInfo);
        final DataBuffer dims = Shape.shapeOf(shapeInfo);
        if ((long) dims.getInt(0L) == total) {
            return true;
        }
        return (long) dims.getInt(1L) == total;
    }

    /**
     * Tells whether a shape describes a vector.
     *
     * @param shape the shape to inspect (rank 1 or 2 to qualify)
     * @return {@code true} if the rank is 1 or 2 and the first or second dimension equals
     *         the product of all dimensions
     */
    public static boolean isVector(int[] shape) {
        final int rank = shape.length;
        if (rank < 1 || rank > 2) {
            return false;
        }
        final long total = ArrayUtil.prodLong(shape);
        if ((long) shape[0] == total) {
            return true;
        }
        return (long) shape[1] == total;
    }

    /**
     * Tells whether a shape describes a vector.
     *
     * @param shape the shape to inspect (rank 1 or 2 to qualify)
     * @return {@code true} if the rank is 1 or 2 and the first or second dimension equals
     *         the product of all dimensions
     */
    public static boolean isVector(long[] shape) {
        final int rank = shape.length;
        if (rank < 1 || rank > 2) {
            return false;
        }
        final long total = ArrayUtil.prodLong(shape);
        if (shape[0] == total) {
            return true;
        }
        return shape[1] == total;
    }

    /**
     * Tells whether the shape information describes a matrix: rank 2 and not a vector.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} for rank 2 arrays that {@code isVector} rejects
     */
    public static boolean isMatrix(IntBuffer shapeInfo) {
        final boolean rankTwo = Shape.rank(shapeInfo) == 2;
        return rankTwo && !Shape.isVector(shapeInfo);
    }

    /**
     * Tells whether the shape information describes a matrix: rank 2 and not a vector.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} for rank 2 arrays that {@code isVector} rejects
     */
    public static boolean isMatrix(DataBuffer shapeInfo) {
        final boolean rankTwo = Shape.rank(shapeInfo) == 2;
        return rankTwo && !Shape.isVector(shapeInfo);
    }

    /**
     * Tells whether a shape describes a matrix: exactly two dimensions and not a vector.
     *
     * @param shape the shape to inspect
     * @return {@code true} for two-dimensional shapes that {@code isVector} rejects
     */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !Shape.isVector(shape);
    }

    /**
     * Tells whether a shape describes a matrix: exactly two dimensions and not a vector.
     *
     * @param shape the shape to inspect
     * @return {@code true} for two-dimensional shapes that {@code isVector} rejects
     */
    public static boolean isMatrix(long[] shape) {
        return shape.length == 2 && !Shape.isVector(shape);
    }

    /**
     * Removes all dimensions of size 1 from a shape.
     * <p>
     * Column vector shapes ({@code [n, 1]}) are returned unchanged, as the very same array.
     *
     * @param shape the shape to squeeze
     * @return the input itself for column vectors, otherwise a new array without the 1-sized entries
     */
    public static int[] squeeze(int[] shape) {
        if (Shape.isColumnVectorShape(shape)) {
            return shape;
        }
        ArrayList<Integer> kept = new ArrayList<Integer>();
        for (int dim : shape) {
            if (dim != 1) {
                kept.add(dim);
            }
        }
        return ArrayUtil.toArray(kept);
    }

    /**
     * Removes all dimensions of size 1 from a shape.
     * <p>
     * Column vector shapes ({@code [n, 1]}) are returned unchanged, as the very same array.
     *
     * @param shape the shape to squeeze
     * @return the input itself for column vectors, otherwise a new array without the 1-sized entries
     */
    public static long[] squeeze(long[] shape) {
        if (Shape.isColumnVectorShape(shape)) {
            return shape;
        }
        ArrayList<Long> kept = new ArrayList<Long>();
        for (long dim : shape) {
            if (dim != 1L) {
                kept.add(dim);
            }
        }
        return ArrayUtil.toArrayLong(kept);
    }

    /**
     * Compares two shapes while ignoring every dimension of size 1.
     * <p>
     * Two {@code null} shapes are equal; a {@code null} shape never equals a non-null one.
     *
     * @param shape1 first shape, may be {@code null}
     * @param shape2 second shape, may be {@code null}
     * @return {@code true} if both shapes list the same non-1 dimensions in the same order
     */
    public static boolean shapeEqualWithSqueeze(long[] shape1, long[] shape2) {
        if (shape1 == null || shape2 == null) {
            return shape1 == null && shape2 == null;
        }
        int i = 0;
        int j = 0;
        for (;;) {
            // Advance both cursors past any run of singleton dimensions.
            while (i < shape1.length && shape1[i] == 1L) {
                i++;
            }
            while (j < shape2.length && shape2[j] == 1L) {
                j++;
            }
            if (i >= shape1.length || j >= shape2.length) {
                break;
            }
            if (shape1[i] != shape2[j]) {
                return false;
            }
            i++;
            j++;
        }
        // Equal only if neither shape has an unmatched non-1 dimension left.
        return i == shape1.length && j == shape2.length;
    }

    /**
     * Compares two shapes with special handling of vectors and scalars.
     * <ul>
     *   <li>two column vector shapes are compared entry by entry;</li>
     *   <li>two row vector shapes are compared after squeezing;</li>
     *   <li>an empty shape equals any shape that {@code shapeIsScalar} accepts;</li>
     *   <li>otherwise the squeezed shapes must be equal (or pass {@code scalarEquals}).</li>
     * </ul>
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if the shapes are considered equal under the rules above
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        final boolean bothColumns = Shape.isColumnVectorShape(shape1) && Shape.isColumnVectorShape(shape2);
        if (bothColumns) {
            return Arrays.equals(shape1, shape2);
        }
        final boolean bothRows = Shape.isRowVectorShape(shape1) && Shape.isRowVectorShape(shape2);
        if (bothRows) {
            return Arrays.equals(Shape.squeeze(shape1), Shape.squeeze(shape2));
        }
        if (shape1.length == 0 && Shape.shapeIsScalar(shape2)) {
            return true;
        }
        if (shape2.length == 0 && Shape.shapeIsScalar(shape1)) {
            return true;
        }
        final int[] squeezed1 = Shape.squeeze(shape1);
        final int[] squeezed2 = Shape.squeeze(shape2);
        return Shape.scalarEquals(squeezed1, squeezed2) || Arrays.equals(squeezed1, squeezed2);
    }

    /**
     * Compares two shapes with special handling of vectors and scalars.
     * <ul>
     *   <li>two column vector shapes are compared entry by entry;</li>
     *   <li>two row vector shapes are compared after squeezing;</li>
     *   <li>an empty shape equals any shape that {@code shapeIsScalar} accepts;</li>
     *   <li>otherwise the squeezed shapes must be equal (or pass {@code scalarEquals}).</li>
     * </ul>
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if the shapes are considered equal under the rules above
     */
    public static boolean shapeEquals(long[] shape1, long[] shape2) {
        final boolean bothColumns = Shape.isColumnVectorShape(shape1) && Shape.isColumnVectorShape(shape2);
        if (bothColumns) {
            return Arrays.equals(shape1, shape2);
        }
        final boolean bothRows = Shape.isRowVectorShape(shape1) && Shape.isRowVectorShape(shape2);
        if (bothRows) {
            return Arrays.equals(Shape.squeeze(shape1), Shape.squeeze(shape2));
        }
        if (shape1.length == 0 && Shape.shapeIsScalar(shape2)) {
            return true;
        }
        if (shape2.length == 0 && Shape.shapeIsScalar(shape1)) {
            return true;
        }
        final long[] squeezed1 = Shape.squeeze(shape1);
        final long[] squeezed2 = Shape.squeeze(shape2);
        return Shape.scalarEquals(squeezed1, squeezed2) || Arrays.equals(squeezed1, squeezed2);
    }

    /**
     * Tells whether one shape is empty while the other is exactly {@code [1]}.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if one argument has length 0 and the other is the single entry 1
     */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        if (shape1.length == 0) {
            return shape2.length == 1 && shape2[0] == 1;
        }
        return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1;
    }

    /**
     * Tells whether one shape is empty while the other is exactly {@code [1]}.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if one argument has length 0 and the other is the single entry 1
     */
    public static boolean scalarEquals(long[] shape1, long[] shape2) {
        if (shape1.length == 0) {
            return shape2.length == 1 && shape2[0] == 1L;
        }
        return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1L;
    }

    /**
     * Tells whether the shape information describes a row vector: rank 1, or rank 2 with a
     * first dimension of 1.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} if the array is rank 1, or rank 2 with first shape entry 1
     */
    public static boolean isRowVectorShape(DataBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        final DataBuffer dims = Shape.shapeOf(shapeInfo);
        if (rank == 1) {
            return true;
        }
        return rank == 2 && dims.getInt(0L) == 1;
    }

    /**
     * Tells whether the shape information describes a row vector: rank 1, or rank 2 with a
     * first dimension of 1.
     *
     * @param shapeInfo shape information buffer
     * @return {@code true} if the array is rank 1, or rank 2 with first shape entry 1
     */
    public static boolean isRowVectorShape(IntBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        final IntBuffer dims = Shape.shapeOf(shapeInfo);
        if (rank == 1) {
            return true;
        }
        return rank == 2 && dims.get(0) == 1;
    }

    /**
     * Tells whether a shape is a row vector: one-dimensional, or two-dimensional with a
     * first dimension of 1.
     *
     * @param shape the shape to inspect
     * @return {@code true} for {@code [n]} and {@code [1, n]} shapes
     */
    public static boolean isRowVectorShape(int[] shape) {
        if (shape.length == 1) {
            return true;
        }
        return shape.length == 2 && shape[0] == 1;
    }

    /**
     * Tells whether a shape is a row vector: one-dimensional, or two-dimensional with a
     * first dimension of 1.
     *
     * @param shape the shape to inspect
     * @return {@code true} for {@code [n]} and {@code [1, n]} shapes
     */
    public static boolean isRowVectorShape(long[] shape) {
        if (shape.length == 1) {
            return true;
        }
        return shape.length == 2 && shape[0] == 1L;
    }

    /**
     * Tells whether a shape is a column vector, i.e. two-dimensional with a second dimension of 1.
     *
     * @param shape the shape to inspect
     * @return {@code true} for {@code [n, 1]} shapes
     */
    public static boolean isColumnVectorShape(int[] shape) {
        if (shape.length != 2) {
            return false;
        }
        return shape[1] == 1;
    }

    /**
     * Tells whether a shape is a column vector, i.e. two-dimensional with a second dimension of 1.
     *
     * @param shape the shape to inspect
     * @return {@code true} for {@code [n, 1]} shapes
     */
    public static boolean isColumnVectorShape(long[] shape) {
        if (shape.length != 2) {
            return false;
        }
        return shape[1] == 1L;
    }

    /**
     * Promotes a one-dimensional shape to a row vector shape.
     *
     * @param shape the shape; must contain at least one entry
     * @return the input array itself if it already has two or more dimensions,
     *         otherwise a new array {@code [1, shape[0]]}
     */
    public static int[] ensureAtMinRowVector(int ... shape) {
        if (shape.length < 2) {
            return new int[]{1, shape[0]};
        }
        return shape;
    }

    /**
     * Computes the length of a tensor along the given dimensions, i.e. the product of the
     * sizes of those dimensions.
     * <p>
     * The product is accumulated as a {@code long} so that it does not wrap around when the
     * result exceeds {@code Integer.MAX_VALUE}.
     *
     * @param shape the shape of the array
     * @param dimensions indices into {@code shape} of the dimensions to multiply
     * @return the product of {@code shape[d]} for every {@code d} in {@code dimensions};
     *         1 when no dimensions are given
     */
    public static long getTADLength(int[] shape, int ... dimensions) {
        long tadLength = 1L;
        for (int dimension : dimensions) {
            tadLength *= shape[dimension];
        }
        return tadLength;
    }

    /**
     * Computes the length of a tensor along the given dimensions, i.e. the product of the
     * sizes of those dimensions.
     * <p>
     * The product is accumulated as a {@code long}; it is no longer truncated to 32 bits
     * after every multiplication.
     *
     * @param shape the shape of the array
     * @param dimensions indices into {@code shape} of the dimensions to multiply
     * @return the product of {@code shape[d]} for every {@code d} in {@code dimensions};
     *         1 when no dimensions are given
     */
    public static long getTADLength(long[] shape, int ... dimensions) {
        long tadLength = 1L;
        for (int dimension : dimensions) {
            tadLength *= shape[dimension];
        }
        return tadLength;
    }

    /**
     * Computes a single stride with which all elements of the array can be visited in order,
     * or 0 if the given strides do not allow that.
     * <p>
     * The method tries to derive strides for a two-dimensional target shape
     * {@code [1, prod(shape)]} from the existing strides, without moving data, and returns the
     * stride of the last target dimension.
     * <p>
     * NOTE(review): for rank 1 input this overload returns 1 regardless of {@code stride[0]},
     * while the {@code long[]} overload returns {@code stride[0]} — confirm which is intended.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array, one entry per dimension
     * @param isFOrder {@code true} to require/produce column-major stride chaining, {@code false} for row-major
     * @return 1 for rank 0 and rank 1 input; 0 if the array has no elements or the strides of the
     *         merged dimensions do not chain in the requested order; otherwise the derived stride
     * @throws IllegalArgumentException if the derived stride is >= {@code Integer.MAX_VALUE}
     */
    public static int elementWiseStride(int[] shape, int[] stride, boolean isFOrder) {
        int nk;
        int ni;
        int oi;
        if (shape.length == 0 && stride.length == 0) {
            return 1;
        }
        if (shape.length == 1 && stride.length == 1) {
            return 1;
        }
        int[] olddims = ArrayUtil.copy((int[])shape);
        int[] oldstrides = ArrayUtil.copy((int[])stride);
        long[] newStrides = new long[stride.length];
        int oldnd = 0;
        int newShapeRank = 2;
        // Only the first two slots of newShape are used: the target is [1, prod(shape)].
        long[] newShape = new long[shape.length];
        newShape[0] = 1L;
        newShape[1] = ArrayUtil.prodLong((int[])shape);
        // Drop size-1 dimensions: compact the remaining extents and strides to the front.
        for (oi = 0; oi < shape.length; ++oi) {
            if (shape[oi] == 1) continue;
            olddims[oldnd] = shape[oi];
            oldstrides[oldnd] = stride[oi];
            ++oldnd;
        }
        // np / op: element counts of the target shape and of the compacted source shape.
        long np = 1L;
        for (ni = 0; ni < newShapeRank; ++ni) {
            np *= newShape[ni];
        }
        long op = 1L;
        for (oi = 0; oi < oldnd; ++oi) {
            op *= (long)olddims[oi];
        }
        if (np != op) {
            return 0;
        }
        if (np == 0L) {
            return 0;
        }
        // [oi, oj) and [ni, nj) delimit the current group of old / new dimensions being matched.
        oi = 0;
        int oj = 1;
        ni = 0;
        int nj = 1;
        while (ni < newShapeRank && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Grow whichever group is smaller until both cover the same number of elements.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                    continue;
                }
                op *= (long)olddims[oj++];
            }
            // Old dimensions merged into one group must have strides that chain in the
            // requested order; otherwise no single stride exists and 0 is returned.
            for (int ok = oi; ok < oj - 1; ++ok) {
                if (!(isFOrder ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok] : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1])) continue;
                return 0;
            }
            // Derive the strides of the new group from the first (f order) or last (c order) old stride.
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }
        // Target dimensions not reached by the loop above inherit the last computed stride.
        long last_stride = ni >= 1 ? newStrides[ni - 1] : (long)stride[shape.length - 1];
        if (isFOrder && ni >= 1) {
            last_stride *= newShape[ni - 1];
        }
        for (nk = ni; nk < newShapeRank; ++nk) {
            newStrides[nk] = last_stride;
        }
        if (newStrides[newShapeRank - 1] >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Element size can not be >= Integer.MAX_VALUE");
        }
        return (int)newStrides[newShapeRank - 1];
    }

    /**
     * Computes a single stride with which all elements of the array can be visited in order,
     * or 0 if the given strides do not allow that.
     * <p>
     * The method tries to derive strides for a two-dimensional target shape
     * {@code [1, prod(shape)]} from the existing strides, without moving data, and returns the
     * stride of the last target dimension.
     * <p>
     * NOTE(review): the result is a {@code long}, yet values >= {@code Integer.MAX_VALUE} are
     * still rejected — confirm whether that limit is still required.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array, one entry per dimension
     * @param isFOrder {@code true} to require/produce column-major stride chaining, {@code false} for row-major
     * @return 1 for rank 0 input; {@code stride[0]} for rank 1 input; 0 if the array has no
     *         elements or the strides of the merged dimensions do not chain in the requested
     *         order; otherwise the derived stride
     * @throws IllegalArgumentException if the derived stride is >= {@code Integer.MAX_VALUE}
     */
    public static long elementWiseStride(long[] shape, long[] stride, boolean isFOrder) {
        int nk;
        int ni;
        int oi;
        if (shape.length == 0 && stride.length == 0) {
            return 1L;
        }
        if (shape.length == 1 && stride.length == 1) {
            return stride[0];
        }
        long[] olddims = ArrayUtil.copy((long[])shape);
        long[] oldstrides = ArrayUtil.copy((long[])stride);
        long[] newStrides = new long[stride.length];
        int oldnd = 0;
        int newShapeRank = 2;
        // Only the first two slots of newShape are used: the target is [1, prod(shape)].
        long[] newShape = new long[shape.length];
        newShape[0] = 1L;
        newShape[1] = ArrayUtil.prodLong((long[])shape);
        // Drop size-1 dimensions: compact the remaining extents and strides to the front.
        for (oi = 0; oi < shape.length; ++oi) {
            if (shape[oi] == 1L) continue;
            olddims[oldnd] = shape[oi];
            oldstrides[oldnd] = stride[oi];
            ++oldnd;
        }
        // np / op: element counts of the target shape and of the compacted source shape.
        long np = 1L;
        for (ni = 0; ni < newShapeRank; ++ni) {
            np *= newShape[ni];
        }
        long op = 1L;
        for (oi = 0; oi < oldnd; ++oi) {
            op *= olddims[oi];
        }
        if (np != op) {
            return 0L;
        }
        if (np == 0L) {
            return 0L;
        }
        // [oi, oj) and [ni, nj) delimit the current group of old / new dimensions being matched.
        oi = 0;
        int oj = 1;
        ni = 0;
        int nj = 1;
        while (ni < newShapeRank && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Grow whichever group is smaller until both cover the same number of elements.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                    continue;
                }
                op *= olddims[oj++];
            }
            // Old dimensions merged into one group must have strides that chain in the
            // requested order; otherwise no single stride exists and 0 is returned.
            for (int ok = oi; ok < oj - 1; ++ok) {
                if (!(isFOrder ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok] : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1])) continue;
                return 0L;
            }
            // Derive the strides of the new group from the first (f order) or last (c order) old stride.
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }
        // Target dimensions not reached by the loop above inherit the last computed stride.
        long last_stride = ni >= 1 ? newStrides[ni - 1] : stride[shape.length - 1];
        if (isFOrder && ni >= 1) {
            last_stride *= newShape[ni - 1];
        }
        for (nk = ni; nk < newShapeRank; ++nk) {
            newStrides[nk] = last_stride;
        }
        if (newStrides[newShapeRank - 1] >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Element size can not be >= Integer.MAX_VALUE");
        }
        return newStrides[newShapeRank - 1];
    }

    /**
     * {@code int[]} convenience overload: converts the requested shape to {@code long[]} and
     * delegates to the {@code long[]} variant.
     *
     * @param arr the array to reshape
     * @param newShape the requested shape
     * @param isFOrder {@code true} for column-major ('f') ordering, {@code false} for row-major ('c')
     * @return whatever the {@code long[]} overload returns for the converted shape
     */
    public static INDArray newShapeNoCopy(INDArray arr, int[] newShape, boolean isFOrder) {
        final long[] widenedShape = ArrayUtil.toLongArray(newShape);
        return Shape.newShapeNoCopy(arr, widenedShape, isFOrder);
    }

    /**
     * Attempts to build an array with a new shape on top of the data of {@code arr}, deriving
     * the new strides from the existing ones instead of copying elements.
     *
     * @param arr the source array; its shape, strides, data and offset are read
     * @param newShape the requested shape
     * @param isFOrder {@code true} to require/produce column-major stride chaining and create
     *                 the result with order 'f'; {@code false} for row-major and order 'c'
     * @return an array created from {@code arr.data()} with the new shape and derived strides,
     *         or {@code null} if the element counts differ, the element count is 0, or the
     *         existing strides of dimensions that must be merged do not chain in the requested order
     */
    public static INDArray newShapeNoCopy(INDArray arr, long[] newShape, boolean isFOrder) {
        int nk;
        int ni;
        int oi;
        long[] olddims = ArrayUtil.copy((long[])arr.shape());
        long[] oldstrides = ArrayUtil.copy((long[])arr.stride());
        long[] newStrides = new long[newShape.length];
        int oldnd = 0;
        // Drop size-1 dimensions: compact the remaining extents and strides to the front.
        for (oi = 0; oi < arr.rank(); ++oi) {
            if (arr.size(oi) == 1L) continue;
            olddims[oldnd] = arr.size(oi);
            oldstrides[oldnd] = arr.stride(oi);
            ++oldnd;
        }
        // np / op: element counts of the requested shape and of the compacted source shape.
        long np = 1L;
        for (ni = 0; ni < newShape.length; ++ni) {
            np *= newShape[ni];
        }
        long op = 1L;
        for (oi = 0; oi < oldnd; ++oi) {
            op *= olddims[oi];
        }
        if (np != op) {
            return null;
        }
        if (np == 0L) {
            return null;
        }
        // [oi, oj) and [ni, nj) delimit the current group of old / new dimensions being matched.
        oi = 0;
        int oj = 1;
        ni = 0;
        int nj = 1;
        while (ni < newShape.length && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Grow whichever group is smaller until both cover the same number of elements.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                    continue;
                }
                op *= olddims[oj++];
            }
            // Old dimensions merged into one group must have strides that chain in the
            // requested order; otherwise the reshape cannot be expressed and null is returned.
            for (int ok = oi; ok < oj - 1; ++ok) {
                if (!(isFOrder ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok] : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1])) continue;
                return null;
            }
            // Derive the strides of the new group from the first (f order) or last (c order) old stride.
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }
        // New dimensions not reached by the loop above inherit the last computed stride
        // (or 1 when no stride was computed at all).
        long last_stride = ni >= 1 ? newStrides[ni - 1] : 1L;
        if (isFOrder && ni >= 1) {
            last_stride *= newShape[ni - 1];
        }
        for (nk = ni; nk < newShape.length; ++nk) {
            newStrides[nk] = last_stride;
        }
        INDArray ret = Nd4j.create(arr.data(), newShape, newStrides, arr.offset(), isFOrder ? (char)'f' : 'c');
        return ret;
    }

    /**
     * Tells whether the strides describe a contiguous row-major or column-major layout.
     * <p>
     * The row-major check walks the dimensions from last to first starting with an expected
     * stride of 1; the column-major check walks from first to last starting with
     * {@code elementStride}. In both walks the expected stride is multiplied by each dimension
     * size, and a dimension of size 0 ends the walk.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array, one entry per dimension
     * @param elementStride the starting stride for the column-major check
     * @return {@code true} if either check succeeds
     */
    public static boolean cOrFortranOrder(long[] shape, long[] stride, long elementStride) {
        boolean rowMajor = true;
        long expected = 1L;
        int axis = shape.length - 1;
        while (axis >= 0) {
            if (stride[axis] != expected) {
                rowMajor = false;
                break;
            }
            if (shape[axis] == 0L) {
                break;
            }
            expected *= shape[axis];
            axis--;
        }
        if (rowMajor) {
            return true;
        }
        expected = elementStride;
        for (int k = 0; k < shape.length; k++) {
            if (stride[k] != expected) {
                return false;
            }
            if (shape[k] == 0L) {
                break;
            }
            expected *= shape[k];
        }
        return true;
    }

    /**
     * {@code int[]} convenience overload: widens all arguments to {@code long} and delegates
     * to the {@code long[]} variant.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array
     * @param elementStride the starting stride for the column-major check
     * @return the result of the {@code long[]} overload
     */
    @Deprecated
    public static boolean cOrFortranOrder(int[] shape, int[] stride, int elementStride) {
        final long[] wideShape = ArrayUtil.toLongArray(shape);
        final long[] wideStride = ArrayUtil.toLongArray(stride);
        return Shape.cOrFortranOrder(wideShape, wideStride, (long) elementStride);
    }

    /**
     * Infers the ordering character from a shape and its strides.
     * <p>
     * The row-major check walks the dimensions from last to first starting with an expected
     * stride of 1; the column-major check walks from first to last starting with
     * {@code elementStride}. A dimension of size 0 ends either walk.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array, one entry per dimension
     * @param elementStride the starting stride for the column-major check
     * @return {@code 'a'} if both checks succeed, {@code 'f'} if only the column-major check
     *         succeeds, {@code 'c'} in every other case
     */
    public static char getOrder(int[] shape, int[] stride, int elementStride) {
        boolean rowMajor = true;
        int expected = 1;
        for (int axis = shape.length - 1; axis >= 0; axis--) {
            if (stride[axis] != expected) {
                rowMajor = false;
                break;
            }
            final int extent = shape[axis];
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        boolean columnMajor = true;
        expected = elementStride;
        for (int axis = 0; axis < shape.length && columnMajor; axis++) {
            if (stride[axis] != expected) {
                columnMajor = false;
            }
            final int extent = shape[axis];
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        if (!columnMajor) {
            return 'c';
        }
        return rowMajor ? 'a' : 'f';
    }

    /**
     * Infers the ordering character from a shape and its strides.
     * <p>
     * The row-major check walks the dimensions from last to first starting with an expected
     * stride of 1; the column-major check walks from first to last starting with
     * {@code elementStride}. A dimension of size 0 ends either walk. When the column-major
     * check fails, the strides themselves decide: non-decreasing strides give {@code 'f'}.
     *
     * @param shape the shape of the array
     * @param stride the stride of the array, one entry per dimension
     * @param elementStride the starting stride for the column-major check
     * @return {@code 'a'} if both checks succeed; {@code 'f'} if only the column-major check
     *         succeeds or the strides never decrease; {@code 'c'} otherwise
     */
    public static char getOrder(long[] shape, long[] stride, long elementStride) {
        boolean rowMajor = true;
        long expected = 1L;
        for (int axis = shape.length - 1; axis >= 0; axis--) {
            if (stride[axis] != expected) {
                rowMajor = false;
                break;
            }
            final long extent = shape[axis];
            if (extent == 0L) {
                break;
            }
            expected *= extent;
        }
        boolean columnMajor = true;
        expected = elementStride;
        for (int axis = 0; axis < shape.length && columnMajor; axis++) {
            if (stride[axis] != expected) {
                columnMajor = false;
            }
            final long extent = shape[axis];
            if (extent == 0L) {
                break;
            }
            expected *= extent;
        }
        if (columnMajor) {
            return rowMajor ? 'a' : 'f';
        }
        // Neither exact layout matched as column-major: fall back to the stride trend.
        for (int j = 1; j < stride.length; j++) {
            if (stride[j] < stride[j - 1]) {
                return 'c';
            }
        }
        return 'f';
    }

    /**
     * Infers the ordering character of an array from its shape and strides, using an
     * element stride of 1.
     *
     * @param arr the array to inspect
     * @return the ordering character computed by the {@code long[]} overload
     */
    public static char getOrder(INDArray arr) {
        final long[] dims = arr.shape();
        final long[] strides = arr.stride();
        return Shape.getOrder(dims, strides, 1L);
    }

    /**
     * Converts per-dimension subscripts into a linear index, with the first dimension
     * varying fastest: {@code index = i0 + shape[0] * (i1 + shape[1] * (i2 + ...))}.
     * <p>
     * All arithmetic is done in {@code long} so that the running multiplier does not wrap
     * around for arrays with more than {@code Integer.MAX_VALUE} elements.
     *
     * @param shape the shape of the array
     * @param indices one subscript per dimension of {@code shape}
     * @return the linear index of the element
     */
    public static long sub2Ind(int[] shape, int[] indices) {
        long index = 0L;
        // Running product of the sizes of all dimensions visited so far.
        long shift = 1L;
        for (int i = 0; i < shape.length; ++i) {
            index += shift * indices[i];
            shift *= shape[i];
        }
        return index;
    }

    /**
     * Converts a linear index into per-dimension subscripts, processing the dimensions from
     * last to first (so the first dimension varies fastest).
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @param numIndices the total number of elements, used as the initial divisor
     * @return one subscript per dimension
     * @throws IllegalArgumentException if a computed subscript is >= {@code Integer.MAX_VALUE}
     */
    public static int[] ind2sub(int[] shape, long index, long numIndices) {
        final int[] subs = new int[shape.length];
        long remaining = index;
        long divisor = numIndices;
        int axis = subs.length;
        while (--axis >= 0) {
            divisor /= (long) shape[axis];
            final long quotient = remaining / divisor;
            if (quotient >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            subs[axis] = (int) quotient;
            remaining %= divisor;
        }
        return subs;
    }

    /**
     * Converts a linear index into per-dimension subscripts, processing the dimensions from
     * last to first (so the first dimension varies fastest).
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @param numIndices the total number of elements, used as the initial divisor
     * @return one subscript per dimension
     * @throws IllegalArgumentException if a computed subscript is >= {@code Integer.MAX_VALUE}
     */
    public static long[] ind2sub(long[] shape, long index, long numIndices) {
        final long[] subs = new long[shape.length];
        long remaining = index;
        long divisor = numIndices;
        int axis = subs.length;
        while (--axis >= 0) {
            divisor /= shape[axis];
            final long quotient = remaining / divisor;
            if (quotient >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            subs[axis] = quotient;
            remaining %= divisor;
        }
        return subs;
    }

    /**
     * Converts a linear index into subscripts, using the product of {@code shape} as the
     * number of elements.
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @return one subscript per dimension
     */
    public static int[] ind2sub(int[] shape, long index) {
        final long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2sub(shape, index, elementCount);
    }

    /**
     * Converts a linear index into subscripts, using the product of {@code shape} as the
     * number of elements.
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @return one subscript per dimension
     */
    public static long[] ind2sub(long[] shape, long index) {
        final long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2sub(shape, index, elementCount);
    }

    /**
     * Converts a linear index into subscripts for the given array.
     * <p>
     * For rank 1 arrays the index itself is the only subscript. It is returned as a full
     * {@code long}; narrowing it to {@code int} first would corrupt indices above
     * {@code Integer.MAX_VALUE} (the {@code ind2subC} counterpart does not narrow either).
     *
     * @param arr the array whose shape is used
     * @param index the linear index to convert
     * @return one subscript per dimension of {@code arr}
     */
    public static long[] ind2sub(INDArray arr, long index) {
        if (arr.rank() == 1) {
            return new long[]{index};
        }
        return Shape.ind2sub(arr.shape(), index, ArrayUtil.prodLong((long[])arr.shape()));
    }

    /**
     * Converts a linear index into per-dimension subscripts, processing the dimensions from
     * first to last (so the last dimension varies fastest).
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @param numIndices the total number of elements, used as the initial divisor
     * @return one subscript per dimension
     * @throws IllegalArgumentException if a computed subscript is >= {@code Integer.MAX_VALUE}
     */
    public static int[] ind2subC(int[] shape, long index, long numIndices) {
        final int[] subs = new int[shape.length];
        long remaining = index;
        long divisor = numIndices;
        int axis = 0;
        while (axis < shape.length) {
            divisor /= (long) shape[axis];
            final long quotient = remaining / divisor;
            if (quotient >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            subs[axis] = (int) quotient;
            remaining %= divisor;
            axis++;
        }
        return subs;
    }

    /**
     * Converts a linear index into per-dimension subscripts, processing the dimensions from
     * first to last (so the last dimension varies fastest).
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @param numIndices the total number of elements, used as the initial divisor
     * @return one subscript per dimension
     * @throws IllegalArgumentException if a computed subscript is >= {@code Integer.MAX_VALUE}
     */
    public static long[] ind2subC(long[] shape, long index, long numIndices) {
        final long[] subs = new long[shape.length];
        long remaining = index;
        long divisor = numIndices;
        int axis = 0;
        while (axis < shape.length) {
            divisor /= shape[axis];
            final long quotient = remaining / divisor;
            if (quotient >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            subs[axis] = quotient;
            remaining %= divisor;
            axis++;
        }
        return subs;
    }

    /**
     * Converts a linear index into subscripts (last dimension fastest), using the product of
     * {@code shape} as the number of elements.
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @return one subscript per dimension
     */
    public static int[] ind2subC(int[] shape, long index) {
        final long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2subC(shape, index, elementCount);
    }

    /**
     * Converts a linear index into subscripts (last dimension fastest), using the product of
     * {@code shape} as the number of elements.
     *
     * @param shape the shape of the array
     * @param index the linear index to convert
     * @return one subscript per dimension
     */
    public static long[] ind2subC(long[] shape, long index) {
        final long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2subC(shape, index, elementCount);
    }

    /**
     * Converts a linear index into subscripts (last dimension fastest) for the given array.
     * For rank 1 arrays the index itself is the only subscript.
     *
     * @param arr the array whose shape is used
     * @param index the linear index to convert
     * @return one subscript per dimension of {@code arr}
     */
    public static long[] ind2subC(INDArray arr, long index) {
        final boolean rankOne = arr.rank() == 1;
        if (!rankOne) {
            return Shape.ind2subC(arr.shape(), index, ArrayUtil.prodLong(arr.shape()));
        }
        return new long[]{index};
    }

    /**
     * Verifies that every entry of {@code shape} is strictly smaller than the entry of
     * {@code lessThan} at the same position.
     *
     * @param shape the shape to check
     * @param lessThan the exclusive upper bound per dimension
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws IllegalStateException if some entry of {@code shape} is not smaller than its bound
     */
    public static void assertShapeLessThan(int[] shape, int[] lessThan) {
        final int rank = shape.length;
        if (rank != lessThan.length) {
            throw new IllegalArgumentException("Shape length must be == less than length");
        }
        int i = 0;
        while (i < rank) {
            if (shape[i] >= lessThan[i]) {
                throw new IllegalStateException("Shape[" + i + "] should be less than lessThan[" + i + "]");
            }
            i++;
        }
    }

    /**
     * Verifies that every entry of {@code shape} is strictly smaller than the entry of
     * {@code lessThan} at the same position.
     *
     * @param shape the shape to check
     * @param lessThan the exclusive upper bound per dimension
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws IllegalStateException if some entry of {@code shape} is not smaller than its bound
     */
    public static void assertShapeLessThan(long[] shape, long[] lessThan) {
        final int rank = shape.length;
        if (rank != lessThan.length) {
            throw new IllegalArgumentException("Shape length must be == less than length");
        }
        int i = 0;
        while (i < rank) {
            if (shape[i] >= lessThan[i]) {
                throw new IllegalStateException("Shape[" + i + "] should be less than lessThan[" + i + "]");
            }
            i++;
        }
    }

    /**
     * Shortens a stride array by dropping its first entry when it is longer than requested.
     * <p>
     * At most one leading entry is removed, however large the difference; {@code indexes}
     * is not consulted.
     *
     * @param strides the strides to adjust
     * @param newLength the length above which the leading stride is dropped
     * @param indexes unused
     * @return {@code strides} itself if its length is at most {@code newLength}, otherwise a
     *         copy without the first entry
     */
    public static int[] newStrides(int[] strides, int newLength, INDArrayIndex[] indexes) {
        if (strides.length <= newLength) {
            return strides;
        }
        final int[] trimmed = new int[strides.length - 1];
        System.arraycopy(strides, 1, trimmed, 0, trimmed.length);
        return trimmed;
    }

    /**
     * Checks that the strides of an array are ordered as its ordering character suggests:
     * strictly descending for 'c', strictly ascending for 'f'.
     * <p>
     * Arrays of rank 0 or 1, vectors whose first two strides are both 1, and arrays with
     * order 'a' always pass.
     *
     * @param array the array to inspect
     * @return {@code true} if the strides match the ordering
     * @throws RuntimeException if the ordering is none of 'c', 'f' and 'a'
     */
    public static boolean strideDescendingCAscendingF(INDArray array) {
        if (array.rank() <= 1) {
            return true;
        }
        final long[] strides = array.stride();
        if (array.isVector() && strides[0] == 1L && strides[1] == 1L) {
            return true;
        }
        final char order = array.ordering();
        switch (order) {
            case 'c':
                for (int i = 1; i < strides.length; i++) {
                    if (strides[i - 1] <= strides[i]) {
                        return false;
                    }
                }
                return true;
            case 'f':
                for (int i = 1; i < strides.length; i++) {
                    if (strides[i - 1] >= strides[i]) {
                        return false;
                    }
                }
                return true;
            case 'a':
                return true;
            default:
                throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
        }
    }

    /**
     * Multiplies together the shape entries of a shape-information buffer.
     * The product is accumulated in an {@code int}.
     *
     * @param buffer the shape-information buffer
     * @return the product of the first {@code rank} shape entries (1 for rank 0)
     */
    public static int length(IntBuffer buffer) {
        // shapeOf is called before rank, as in the original call sequence;
        // both helpers reposition the buffer.
        IntBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        int product = 1;
        int i = 0;
        while (i < rank) {
            product *= dims.get(i);
            i++;
        }
        return product;
    }

    /**
     * Multiplies together the shape entries of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the product of the first {@code rank} shape entries (1 for rank 0)
     */
    public static long length(DataBuffer buffer) {
        // The previous version also called buffer.asLong() and discarded the
        // result; that unused copy has been removed.
        DataBuffer shape = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        long ret = 1L;
        for (int i = 0; i < rank; ++i) {
            ret *= shape.getLong((long)i);
        }
        return ret;
    }

    /**
     * Multiplies together the shape entries of a shape-information array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return the product of {@code buffer[1..rank]} as a {@code long} (1 for rank 0)
     */
    public static long length(int[] buffer) {
        final int end = Shape.rank(buffer) + 1;
        long product = 1L;
        int idx = 1;
        while (idx < end) {
            product *= (long)buffer[idx];
            idx++;
        }
        return product;
    }

    /**
     * Multiplies together the shape entries of a shape-information array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return the product of {@code buffer[1..rank]} (1 for rank 0)
     */
    public static long length(long[] buffer) {
        final int end = Shape.rank(buffer) + 1;
        long product = 1L;
        int idx = 1;
        while (idx < end) {
            product *= buffer[idx];
            idx++;
        }
        return product;
    }

    /**
     * Reads the rank from a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the int stored at index 0 of {@code buffer}
     */
    public static int rank(DataBuffer buffer) {
        return buffer.getInt(0L);
    }

    /**
     * Reads the rank from a shape-information buffer.
     * <p>
     * Side effect: the buffer's position is reset to 0.
     *
     * @param buffer the shape-information buffer
     * @return the value stored at index 0 of {@code buffer}
     */
    public static int rank(IntBuffer buffer) {
        // The upcast makes the call bind to Buffer.position(int).
        ((Buffer)buffer).position(0);
        return buffer.get(0);
    }

    /**
     * Reads the rank from a shape-information buffer.
     * <p>
     * Side effect: the buffer's position is reset to 0.
     *
     * @param buffer the shape-information buffer
     * @return the value stored at index 0 of {@code buffer}, narrowed to {@code int}
     */
    public static int rank(LongBuffer buffer) {
        // The upcast makes the call bind to Buffer.position(int).
        ((Buffer)buffer).position(0);
        long rawRank = buffer.get(0);
        return (int)rawRank;
    }

    /**
     * Reads the rank from a shape-information array.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[0]} narrowed to {@code int}
     */
    public static int rank(long[] buffer) {
        return (int)buffer[0];
    }

    /**
     * Reads the rank from a shape-information array.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[0]}
     */
    public static int rank(int[] buffer) {
        return buffer[0];
    }

    /**
     * Returns the extent of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (rank at index 0, shape from index 1)
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the shape entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int size(IntBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through; -1 silently returned the rank slot.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + dimension);
    }

    /**
     * Returns the extent of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (rank at index 0, shape from index 1)
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the shape entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static long size(LongBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through; -1 silently returned the rank slot.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + dimension);
    }

    /**
     * Returns the extent of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (rank at index 0, shape from index 1)
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the shape entry for {@code dimension}, read as an {@code int}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int size(DataBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through; -1 silently returned the rank slot.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.getInt((long)(1 + dimension));
    }

    /**
     * Returns the extent of one dimension from a shape-information array.
     *
     * @param buffer    the shape-information array (rank at index 0, shape from index 1)
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the shape entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int size(int[] buffer, int dimension) {
        // Index 0 holds the rank.
        int rank = buffer[0];
        // A negative dimension used to slip through; -1 silently returned the rank slot.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + dimension];
    }

    /**
     * Returns the extent of one dimension from a shape-information array.
     *
     * @param buffer    the shape-information array (rank at index 0, shape from index 1)
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the shape entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static long size(long[] buffer, int dimension) {
        // Index 0 holds the rank.
        int rank = (int)buffer[0];
        // A negative dimension used to slip through; -1 silently returned the rank slot.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + dimension];
    }

    /**
     * Reads the shape entry for {@code dimension} without validating it
     * against the rank.
     *
     * @param buffer    the shape-information buffer
     * @param dimension zero-based dimension; not range-checked here
     * @return the int stored at index {@code 1 + dimension}
     */
    public static int sizeUnsafe(DataBuffer buffer, int dimension) {
        return buffer.getInt((long)(1 + dimension));
    }

    /**
     * Reads the shape entry for {@code dimension} without validating it
     * against the rank.
     *
     * @param buffer    the shape-information array
     * @param dimension zero-based dimension; not range-checked here
     * @return {@code buffer[1 + dimension]}
     */
    public static int sizeUnsafe(int[] buffer, int dimension) {
        return buffer[1 + dimension];
    }

    /**
     * Reads the shape entry for {@code dimension} without validating it
     * against the rank.
     *
     * @param buffer    the shape-information array
     * @param dimension zero-based dimension; not range-checked here
     * @return {@code buffer[1 + dimension]}
     */
    public static long sizeUnsafe(long[] buffer, int dimension) {
        return buffer[1 + dimension];
    }

    /**
     * Copies the shape entries of a shape-information buffer into a new array.
     *
     * @param buffer the shape-information buffer (rank at index 0, shape from index 1)
     * @return a new {@code long[]} of length rank holding the shape
     */
    public static long[] shape(IntBuffer buffer) {
        final int rank = Shape.rank(buffer);
        long[] dims = new long[rank];
        int d = 0;
        while (d < dims.length) {
            dims[d] = buffer.get(d + 1);
            d++;
        }
        return dims;
    }

    /**
     * Copies the shape entries of a shape-information buffer into a new array.
     *
     * @param buffer the shape-information buffer (rank at index 0, shape from index 1)
     * @return a new {@code long[]} of length rank holding the shape
     */
    public static long[] shape(LongBuffer buffer) {
        final int rank = Shape.rank(buffer);
        long[] dims = new long[rank];
        int d = 0;
        while (d < dims.length) {
            dims[d] = buffer.get(d + 1);
            d++;
        }
        return dims;
    }

    /**
     * Copies the shape entries of a shape-information buffer into a new array.
     *
     * @param buffer the shape-information buffer (rank at index 0, shape from index 1)
     * @return a new {@code long[]} of length rank holding the shape
     */
    public static long[] shape(DataBuffer buffer) {
        long[] ret = new long[Shape.rank(buffer)];
        for (int i = 0; i < ret.length; ++i) {
            // Read as long: the result is a long[], and getInt would truncate
            // extents that do not fit in 32 bits.
            ret[i] = buffer.getLong((long)(1 + i));
        }
        return ret;
    }

    /**
     * Copies the shape entries of a shape-information array into a new array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return a new {@code int[]} of length rank holding {@code buffer[1..rank]}
     */
    public static int[] shape(int[] buffer) {
        final int rank = Shape.rank(buffer);
        int[] dims = new int[rank];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = buffer[d + 1];
        }
        return dims;
    }

    /**
     * Copies the shape entries of a shape-information array into a new array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return a new {@code long[]} of length rank holding {@code buffer[1..rank]}
     */
    public static long[] shape(long[] buffer) {
        final int rank = Shape.rank(buffer);
        long[] dims = new long[rank];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = buffer[d + 1];
        }
        return dims;
    }

    /**
     * Returns the stride of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (strides start at index {@code 1 + rank})
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the stride entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int stride(IntBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through and read a shape slot instead of a stride.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + rank + dimension);
    }

    /**
     * Returns the stride of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (strides start at index {@code 1 + rank})
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the stride entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static long stride(LongBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through and read a shape slot instead of a stride.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + rank + dimension);
    }

    /**
     * Returns the stride of one dimension from a shape-information buffer.
     *
     * @param buffer    the shape-information buffer (strides start at index {@code 1 + rank})
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the stride entry for {@code dimension}, read as an {@code int}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int stride(DataBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        // A negative dimension used to slip through and read a shape slot instead of a stride.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.getInt((long)(1 + rank + dimension));
    }

    /**
     * Returns the stride of one dimension from a shape-information array.
     *
     * @param buffer    the shape-information array (strides start at index {@code 1 + rank})
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the stride entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static int stride(int[] buffer, int dimension) {
        // Index 0 holds the rank.
        int rank = buffer[0];
        // A negative dimension used to slip through and read a shape slot instead of a stride.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + rank + dimension];
    }

    /**
     * Returns the stride of one dimension from a shape-information array.
     *
     * @param buffer    the shape-information array (strides start at index {@code 1 + rank})
     * @param dimension zero-based dimension, {@code 0 <= dimension < rank}
     * @return the stride entry for {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is negative or not below the rank
     */
    public static long stride(long[] buffer, int dimension) {
        // Index 0 holds the rank.
        int rank = (int)buffer[0];
        // A negative dimension used to slip through and read a shape slot instead of a stride.
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + rank + dimension];
    }

    /**
     * Copies the stride entries of a shape-information buffer into a new array.
     *
     * @param buffer the shape-information buffer
     * @return a new {@code long[]} of length rank holding the strides
     */
    public static long[] strideArr(DataBuffer buffer) {
        long[] ret = new long[Shape.rank(buffer)];
        DataBuffer stride = Shape.stride(buffer);
        for (int i = 0; i < ret.length; ++i) {
            // Read as long: the result is a long[], and getInt would truncate
            // strides that do not fit in 32 bits.
            ret[i] = stride.getLong((long)i);
        }
        return ret;
    }

    /**
     * Reads the stride entry for {@code dimension} without validating it;
     * the caller supplies the rank.
     *
     * @param buffer    the shape-information buffer
     * @param dimension zero-based dimension; not range-checked here
     * @param rank      the rank used to locate the stride section
     * @return the int stored at index {@code 1 + rank + dimension}
     */
    public static int strideUnsafe(DataBuffer buffer, int dimension, int rank) {
        return buffer.getInt((long)(1 + rank + dimension));
    }

    /**
     * Reads the stride entry for {@code dimension} without validating it;
     * the caller supplies the rank.
     *
     * @param buffer    the shape-information array
     * @param dimension zero-based dimension; not range-checked here
     * @param rank      the rank used to locate the stride section
     * @return {@code buffer[1 + rank + dimension]}
     */
    public static int strideUnsafe(int[] buffer, int dimension, int rank) {
        return buffer[1 + rank + dimension];
    }

    /**
     * Reads the stride entry for {@code dimension} without validating it;
     * the caller supplies the rank.
     *
     * @param buffer    the shape-information array
     * @param dimension zero-based dimension; not range-checked here
     * @param rank      the rank used to locate the stride section
     * @return {@code buffer[1 + rank + dimension]}
     */
    public static long strideUnsafe(long[] buffer, int dimension, int rank) {
        return buffer[1 + rank + dimension];
    }

    /**
     * Computes the number of slots in a shape-information buffer of the given
     * rank: two slots per dimension plus four additional slots.
     *
     * @param rank the array rank
     * @return {@code 2 * rank + 4}
     */
    public static int shapeInfoLength(int rank) {
        final int perDimensionSlots = 2 * rank;
        return perDimensionSlots + 4;
    }

    /**
     * Computes the shape-information length for the rank stored in the first
     * element of the argument.
     * <p>
     * Note: despite the parameter name, {@code shape[0]} is read as the rank,
     * so the argument is treated as a shape-information array rather than a
     * plain list of dimension sizes.
     *
     * @param shape array whose element 0 is taken as the rank
     * @return {@code shapeInfoLength((int) shape[0])}
     */
    public static int shapeInfoLength(long[] shape) {
        return Shape.shapeInfoLength((int)shape[0]);
    }

    /**
     * Returns a slice of the shape-information buffer that starts at the first
     * stride entry.
     * <p>
     * Side effect: the position of {@code buffer} is left at {@code 1 + rank}.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at index {@code 1 + rank} and running to the buffer's limit
     */
    public static IntBuffer stride(IntBuffer buffer) {
        final int stridesStart = 1 + Shape.rank(buffer);
        ((Buffer)buffer).position(stridesStart);
        return buffer.slice();
    }

    /**
     * Returns a slice of the shape-information buffer that starts at the first
     * stride entry.
     * <p>
     * Side effect: the position of {@code buffer} is left at {@code 1 + rank}.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at index {@code 1 + rank} and running to the buffer's limit
     */
    public static LongBuffer stride(LongBuffer buffer) {
        final int stridesStart = 1 + Shape.rank(buffer);
        ((Buffer)buffer).position(stridesStart);
        return buffer.slice();
    }

    /**
     * Creates a buffer over the stride section of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the result of {@code Nd4j.createBuffer(buffer, 1 + rank, rank)}
     */
    public static DataBuffer stride(DataBuffer buffer) {
        final int rank = Shape.rank(buffer);
        final long stridesStart = (long)(1 + rank);
        final long strideCount = (long)rank;
        return Nd4j.createBuffer(buffer, stridesStart, strideCount);
    }

    /**
     * Copies the stride entries of a shape-information array into a new array.
     *
     * @param buffer the shape-information array (strides start at index {@code 1 + rank})
     * @return a new {@code int[]} of length rank holding the strides
     */
    public static int[] stride(int[] buffer) {
        final int rank = Shape.rank(buffer);
        int[] strides = new int[rank];
        int d = 0;
        while (d < rank) {
            strides[d] = buffer[1 + rank + d];
            d++;
        }
        return strides;
    }

    /**
     * Copies the stride entries of a shape-information array into a new array.
     *
     * @param buffer the shape-information array (strides start at index {@code 1 + rank})
     * @return a new {@code long[]} of length rank holding the strides
     */
    public static long[] stride(long[] buffer) {
        final int rank = Shape.rank(buffer);
        long[] strides = new long[rank];
        int d = 0;
        while (d < rank) {
            strides[d] = buffer[1 + rank + d];
            d++;
        }
        return strides;
    }

    /**
     * Creates a buffer over the shape section of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the result of {@code Nd4j.createBuffer(buffer, 1, rank)}, where
     *         rank is the value at index 0 narrowed to {@code int}
     */
    public static DataBuffer shapeOf(DataBuffer buffer) {
        final long rawRank = buffer.getLong(0L);
        final int rank = (int)rawRank;
        return Nd4j.createBuffer(buffer, 1L, (long)rank);
    }

    /**
     * Returns a slice of the shape-information buffer that starts at the first
     * shape entry.
     * <p>
     * Side effect: the position of {@code buffer} is left at 1. The slice runs
     * to the buffer's limit, so it also covers whatever follows the shape.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at index 1
     */
    public static IntBuffer shapeOf(IntBuffer buffer) {
        ((Buffer)buffer).position(1);
        return buffer.slice();
    }

    /**
     * Returns a slice of the shape-information buffer that starts at the first
     * shape entry.
     * <p>
     * Side effect: the position of {@code buffer} is left at 1. The slice runs
     * to the buffer's limit, so it also covers whatever follows the shape.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at index 1
     */
    public static LongBuffer shapeOf(LongBuffer buffer) {
        ((Buffer)buffer).position(1);
        return buffer.slice();
    }

    /**
     * Copies the shape entries out of a shape-information array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return a copy of {@code buffer[1 .. rank]}
     */
    public static int[] shapeOf(int[] buffer) {
        final int from = 1;
        final int to = 1 + buffer[0];
        return Arrays.copyOfRange(buffer, from, to);
    }

    /**
     * Copies the shape entries out of a shape-information array.
     *
     * @param buffer the shape-information array (rank at index 0, shape from index 1)
     * @return a copy of {@code buffer[1 .. rank]}
     */
    public static long[] shapeOf(long[] buffer) {
        final int from = 1;
        final int to = 1 + (int)buffer[0];
        return Arrays.copyOfRange(buffer, from, to);
    }

    /**
     * Copies the stride entries out of a shape-information array.
     *
     * @param buffer the shape-information array (strides start at index {@code 1 + rank})
     * @return a copy of {@code buffer[1 + rank .. 2 * rank]}
     */
    public static int[] stridesOf(int[] buffer) {
        final int rank = buffer[0];
        final int from = 1 + rank;
        final int to = 1 + rank * 2;
        return Arrays.copyOfRange(buffer, from, to);
    }

    /**
     * Copies the stride entries out of a shape-information array.
     *
     * @param buffer the shape-information array (strides start at index {@code 1 + rank})
     * @return a copy of {@code buffer[1 + rank .. 2 * rank]}
     */
    public static long[] stridesOf(long[] buffer) {
        final int rank = (int)buffer[0];
        final int from = 1 + rank;
        final int to = 1 + rank * 2;
        return Arrays.copyOfRange(buffer, from, to);
    }

    /**
     * Reads the flags section of a sparse-information buffer: a count at
     * index 0 followed by that many int entries.
     *
     * @param buffer the sparse-information buffer
     * @return a new array holding the entries at indices {@code 1 .. count}
     */
    public static int[] flags(DataBuffer buffer) {
        final int count = buffer.getInt(0L);
        int[] result = new int[count];
        int i = 0;
        while (i < result.length) {
            result[i] = buffer.getInt((long)(1 + i));
            i++;
        }
        return result;
    }

    /**
     * Reads the offsets section of a sparse-information buffer. The section
     * follows the flags section and is laid out as a count followed by that
     * many int entries.
     *
     * @param buffer the sparse-information buffer
     * @return a new array holding the offset entries
     */
    public static int[] sparseOffsets(DataBuffer buffer) {
        final int flagsLength = buffer.getInt(0L);
        final int offLength = buffer.getInt((long)(flagsLength + 1));
        // First offset entry sits right after the flags count, the flags and the offsets count.
        final int firstEntry = flagsLength + 2;
        int[] offsets = new int[offLength];
        for (int i = 0; i < offLength; i++) {
            offsets[i] = buffer.getInt((long)(firstEntry + i));
        }
        return offsets;
    }

    /**
     * Reads the hidden-dimension section of a sparse-information buffer. The
     * section follows the flags and offsets sections and is laid out as a
     * count followed by that many int entries.
     *
     * @param buffer the sparse-information buffer
     * @return a new array holding the hidden-dimension entries
     */
    public static int[] hiddenDimension(DataBuffer buffer) {
        final int flagsLength = buffer.getInt(0L);
        final int offLength = buffer.getInt((long)(flagsLength + 1));
        final int countIndex = flagsLength + offLength + 2;
        final int hiddenDimLength = buffer.getInt((long)countIndex);
        final int firstEntry = flagsLength + offLength + 3;
        int[] hidden = new int[hiddenDimLength];
        for (int i = 0; i < hiddenDimLength; i++) {
            hidden[i] = buffer.getInt((long)(firstEntry + i));
        }
        return hidden;
    }

    /**
     * Reads the underlying rank from a sparse-information buffer; it is the
     * single entry stored after the flags, offsets and hidden-dimension sections.
     *
     * @param buffer the sparse-information buffer
     * @return the int stored directly after the hidden-dimension entries
     */
    public static int underlyingRank(DataBuffer buffer) {
        final int flagsLength = buffer.getInt(0L);
        final int offLength = buffer.getInt((long)(flagsLength + 1));
        final int hiddenDimLength = buffer.getInt((long)(flagsLength + offLength + 2));
        final int rankIndex = flagsLength + offLength + hiddenDimLength + 3;
        return buffer.getInt((long)rankIndex);
    }

    /**
     * Renders the shape information of an array as text by delegating to the
     * {@code shapeToString} overload that accepts the type returned by
     * {@code arr.shapeInfo()}.
     *
     * @param arr the array to describe
     * @return the textual description of the array's shape information
     */
    public static String shapeToString(INDArray arr) {
        return Shape.shapeToString(arr.shapeInfo());
    }

    /**
     * Renders a shape-information buffer as text: rank, offset, order, shape
     * and stride.
     * <p>
     * The helper calls reposition {@code buffer} as a side effect.
     *
     * @param buffer the shape-information buffer
     * @return the textual description
     */
    public static String shapeToString(IntBuffer buffer) {
        // Same helper call order as before: each helper moves the buffer position.
        IntBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        IntBuffer strides = Shape.stride(buffer);
        StringBuilder out = new StringBuilder();
        out.append("Rank: ").append(rank).append(",");
        out.append("Offset: ").append(Shape.offset(buffer)).append("\n");
        out.append(" Order: ").append(Shape.order(buffer));
        out.append(" Shape: [");
        for (int d = 0; d < rank; d++) {
            if (d > 0) {
                out.append(",");
            }
            out.append(dims.get(d));
        }
        out.append("], ");
        out.append(" stride: [");
        for (int d = 0; d < rank; d++) {
            if (d > 0) {
                out.append(",");
            }
            out.append(strides.get(d));
        }
        return out.append("]").toString();
    }

    /**
     * Renders a shape-information buffer as text: rank, data type, offset,
     * order, shape and stride.
     * <p>
     * The data type is decoded from the value stored three slots before the
     * buffer's capacity. The helper calls reposition {@code buffer} as a side effect.
     *
     * @param buffer the shape-information buffer
     * @return the textual description
     */
    public static String shapeToString(LongBuffer buffer) {
        final int capacity = buffer.capacity();
        final long options = buffer.get(capacity - 3);
        // Same helper call order as before: each helper moves the buffer position.
        LongBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        LongBuffer strides = Shape.stride(buffer);
        StringBuilder out = new StringBuilder();
        out.append("Rank: ").append(rank).append(",");
        out.append(" DataType: ").append(ArrayOptionsHelper.dataType(options)).append(",");
        out.append(" Offset: ").append(Shape.offset(buffer)).append(",");
        out.append(" Order: ").append(Shape.order(buffer)).append(",");
        out.append(" Shape: [");
        for (int d = 0; d < rank; d++) {
            if (d > 0) {
                out.append(",");
            }
            out.append(dims.get(d));
        }
        out.append("], ");
        out.append(" Stride: [");
        for (int d = 0; d < rank; d++) {
            if (d > 0) {
                out.append(",");
            }
            out.append(strides.get(d));
        }
        return out.append("]").toString();
    }

    /**
     * Renders a compact {@code dataType,[d0,d1,...],order} description of an array.
     *
     * @param arr the array to describe
     * @return the data type, the shape without spaces ({@code "[]"} when the
     *         shape is {@code null}) and the ordering, separated by commas
     */
    public static String shapeToStringShort(INDArray arr) {
        long[] dims = arr.shape();
        String dimText;
        if (dims == null) {
            dimText = "[]";
        } else {
            dimText = Arrays.toString(dims).replace(" ", "");
        }
        return arr.dataType() + "," + dimText + "," + arr.ordering();
    }

    /**
     * Always returns 0; the argument is not read.
     *
     * @param buffer ignored
     * @return 0
     */
    @Deprecated
    public static int offset(DataBuffer buffer) {
        return 0;
    }

    /**
     * Reads the options value of a shape-information array, stored three
     * slots before the end of the shape information.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[shapeInfoLength(rank) - 3]}
     */
    public static long options(long[] buffer) {
        final int optionsIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 3;
        return buffer[optionsIndex];
    }

    /**
     * Alias for {@code options(long[])}.
     *
     * @param buffer the shape-information array
     * @return the value returned by {@code Shape.options(buffer)}
     */
    public static long extras(long[] buffer) {
        return Shape.options(buffer);
    }

    /**
     * Always returns 0; the argument is not read.
     *
     * @param buffer ignored
     * @return 0
     */
    @Deprecated
    public static int offset(int[] buffer) {
        return 0;
    }

    /**
     * Always returns 0; the argument is not read.
     *
     * @param buffer ignored
     * @return 0
     */
    @Deprecated
    public static int offset(long[] buffer) {
        return 0;
    }

    /**
     * Always returns 0; the argument is not read.
     *
     * @param buffer ignored
     * @return 0
     */
    @Deprecated
    public static int offset(IntBuffer buffer) {
        return 0;
    }

    /**
     * Always returns 0; the argument is not read.
     *
     * @param buffer ignored
     * @return 0
     */
    @Deprecated
    public static long offset(LongBuffer buffer) {
        return 0L;
    }

    /**
     * Reads the element-wise stride of a shape-information buffer, stored in
     * the second-to-last slot of the shape information.
     *
     * @param buffer the shape-information buffer
     * @return the int at index {@code shapeInfoLength(rank) - 2}
     */
    public static int elementWiseStride(DataBuffer buffer) {
        final int rank = buffer.getInt(0L);
        final int ewsIndex = Shape.shapeInfoLength(rank) - 2;
        return buffer.getInt((long)ewsIndex);
    }

    /**
     * Reads the element-wise stride of a shape-information buffer, stored in
     * the second-to-last slot of the shape information.
     *
     * @param buffer the shape-information buffer
     * @return the value at index {@code shapeInfoLength(rank) - 2}
     */
    public static int elementWiseStride(IntBuffer buffer) {
        final int rank = buffer.get(0);
        final int ewsIndex = Shape.shapeInfoLength(rank) - 2;
        return buffer.get(ewsIndex);
    }

    /**
     * Reads the element-wise stride of a shape-information array, stored in
     * the second-to-last slot of the shape information.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[shapeInfoLength(buffer) - 2]}
     */
    public static long elementWiseStride(long[] buffer) {
        final int ewsIndex = Shape.shapeInfoLength(buffer) - 2;
        return buffer[ewsIndex];
    }

    /**
     * Writes the element-wise stride into the second-to-last slot of a
     * shape-information buffer.
     *
     * @param buffer            the shape-information buffer to modify
     * @param elementWiseStride the value to store
     */
    public static void setElementWiseStride(IntBuffer buffer, int elementWiseStride) {
        final int rank = buffer.get(0);
        final int ewsIndex = Shape.shapeInfoLength(rank) - 2;
        buffer.put(ewsIndex, elementWiseStride);
    }

    /**
     * Writes the element-wise stride into the second-to-last slot of a
     * shape-information buffer.
     *
     * @param buffer            the shape-information buffer to modify
     * @param elementWiseStride the value to store
     */
    public static void setElementWiseStride(DataBuffer buffer, int elementWiseStride) {
        final int rank = Shape.rank(buffer);
        final long ewsIndex = (long)(Shape.shapeInfoLength(rank) - 2);
        buffer.put(ewsIndex, elementWiseStride);
    }

    /**
     * Renders all {@code 2 * rank + 4} slots of a shape-information buffer as
     * a bracketed, comma-separated list. The rank is read from index 0.
     *
     * @param buffer the shape-information buffer
     * @return text of the form {@code "[ rank, v1, v2, ...]"}
     */
    public static String bufferToString(IntBuffer buffer) {
        final int rank = buffer.get(0);
        final int slotCount = rank * 2 + 4;
        StringBuilder text = new StringBuilder();
        text.append("[ ").append(rank).append(", ");
        for (int slot = 1; slot < slotCount; slot++) {
            if (slot > 1) {
                text.append(", ");
            }
            text.append(buffer.get(slot));
        }
        return text.append("]").toString();
    }

    /**
     * Reads the ordering character stored in the last slot of the shape information.
     *
     * @param buffer the shape-information buffer
     * @return the value at index {@code shapeInfoLength(rank) - 1} as a {@code char}
     */
    public static char order(IntBuffer buffer) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        return (char)buffer.get(orderIndex);
    }

    /**
     * Reads the ordering character stored in the last slot of the shape information.
     *
     * @param buffer the shape-information buffer
     * @return the value at index {@code shapeInfoLength(rank) - 1} as a {@code char}
     */
    public static char order(LongBuffer buffer) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        return (char)buffer.get(orderIndex);
    }

    /**
     * Reads the ordering character stored in the last slot of the shape information.
     *
     * @param buffer the shape-information buffer
     * @return the int at index {@code shapeInfoLength(rank) - 1} as a {@code char}
     */
    public static char order(DataBuffer buffer) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        return (char)buffer.getInt((long)orderIndex);
    }

    /**
     * Reads the ordering character stored in the last slot of the shape information.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[shapeInfoLength(rank) - 1]} as a {@code char}
     */
    public static char order(int[] buffer) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        return (char)buffer[orderIndex];
    }

    /**
     * Reads the ordering character stored in the last slot of the shape information.
     *
     * @param buffer the shape-information array
     * @return {@code buffer[shapeInfoLength(rank) - 1]} as a {@code char}
     */
    public static char order(long[] buffer) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        return (char)buffer[orderIndex];
    }

    /**
     * Writes the ordering character into the last slot of the shape
     * information and then unconditionally throws.
     * <p>
     * NOTE(review): the write happens before the exception, so the buffer is
     * modified even though the call never returns normally.
     *
     * @param buffer the shape-information buffer to modify
     * @param order  the ordering character to store
     * @throws RuntimeException always, after the write
     */
    @Deprecated
    public static void setOrder(IntBuffer buffer, char order) {
        final int orderIndex = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        buffer.put(orderIndex, order);
        throw new RuntimeException("setOrder called");
    }

    /**
     * Packs shape, stride, offset, element-wise stride and order into an int
     * shape-information buffer laid out as
     * {@code [rank, shape..., stride..., offset, elementWiseStride, order]}.
     * The buffer is created detached and marked constant.
     *
     * @param shape             the dimension sizes
     * @param stride            the strides, same length as {@code shape}
     * @param offset            the offset; narrowed to {@code int} when stored
     * @param elementWiseStride the element-wise stride
     * @param order             the ordering character
     * @return the new shape-information buffer
     * @throws IllegalStateException if {@code shape} and {@code stride} differ in length
     */
    public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
        final int rank = shape.length;
        if (rank != stride.length) {
            throw new IllegalStateException("Shape and stride must be the same length");
        }
        int[] info = new int[rank * 2 + 4];
        info[0] = rank;
        System.arraycopy(shape, 0, info, 1, rank);
        System.arraycopy(stride, 0, info, 1 + rank, rank);
        final int tail = 1 + 2 * rank;
        info[tail] = (int)offset;
        info[tail + 1] = elementWiseStride;
        info[tail + 2] = order;
        DataBuffer result = Nd4j.createBufferDetached(info);
        result.setConstant(true);
        return result;
    }

    /**
     * Packs shape, stride, data-type options, element-wise stride and order
     * into a long shape-information buffer laid out as
     * {@code [rank, shape..., stride..., options, elementWiseStride, order]}.
     * The buffer is created detached and marked constant.
     *
     * @param shape             the dimension sizes
     * @param stride            the strides, same length as {@code shape}
     * @param elementWiseStride the element-wise stride
     * @param order             the ordering character
     * @param dataType          the data type to encode in the options slot
     * @return the new shape-information buffer
     * @throws IllegalStateException if {@code shape} and {@code stride} differ in length
     */
    public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType) {
        int e;
        // setOptionBit takes the storage word by value, so its result must be
        // kept; it was previously discarded and the options slot stayed 0.
        long options = ArrayOptionsHelper.setOptionBit(0L, dataType);
        if (shape.length != stride.length) {
            throw new IllegalStateException("Shape and stride must be the same length");
        }
        int rank = shape.length;
        long[] shapeBuffer = new long[Shape.shapeInfoLength(rank)];
        shapeBuffer[0] = rank;
        int count = 1;
        for (e = 0; e < shape.length; ++e) {
            shapeBuffer[count++] = shape[e];
        }
        for (e = 0; e < stride.length; ++e) {
            shapeBuffer[count++] = stride[e];
        }
        shapeBuffer[count++] = options;
        shapeBuffer[count++] = elementWiseStride;
        shapeBuffer[count] = order;
        DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer);
        ret.setConstant(true);
        return ret;
    }

    /**
     * Packs shape, stride, extras, element-wise stride and order into a long
     * shape-information buffer laid out as
     * {@code [rank, shape..., stride..., extras, elementWiseStride, order]}.
     * The buffer is created detached and marked constant.
     *
     * @param shape             the dimension sizes
     * @param stride            the strides, same length as {@code shape}
     * @param elementWiseStride the element-wise stride
     * @param order             the ordering character
     * @param extras            the value stored in the slot after the strides
     * @return the new shape-information buffer
     * @throws IllegalStateException if {@code shape} and {@code stride} differ in length
     */
    public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, long extras) {
        final int rank = shape.length;
        if (rank != stride.length) {
            throw new IllegalStateException("Shape and stride must be the same length");
        }
        long[] info = new long[Shape.shapeInfoLength(rank)];
        info[0] = rank;
        System.arraycopy(shape, 0, info, 1, rank);
        System.arraycopy(stride, 0, info, 1 + rank, rank);
        final int tail = 1 + 2 * rank;
        info[tail] = extras;
        info[tail + 1] = elementWiseStride;
        info[tail + 2] = order;
        DataBuffer result = Nd4j.createBufferDetached(info);
        result.setConstant(true);
        return result;
    }

    /**
     * Packs sparse metadata into a single int buffer laid out as
     * {@code [flagCount, flags..., offsetCount, offsets..., hiddenCount, hiddenDims..., underlyingRank]}.
     * Offsets are narrowed from {@code long} to {@code int} when stored.
     *
     * @param flags            the flag entries
     * @param sparseOffsets    the sparse offsets
     * @param hiddenDimensions the hidden dimensions
     * @param underlyingRank   the underlying rank, stored last
     * @return the buffer created from the packed values
     */
    public static DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
        final int flagLength = flags.length;
        final int offsetsLength = sparseOffsets.length;
        final int hiddenDimLength = hiddenDimensions.length;
        // Three section counts plus the trailing rank account for the extra four slots.
        int[] packed = new int[flagLength + offsetsLength + hiddenDimLength + 4];
        int pos = 0;
        packed[pos++] = flagLength;
        for (int flag : flags) {
            packed[pos++] = flag;
        }
        packed[pos++] = offsetsLength;
        for (long off : sparseOffsets) {
            packed[pos++] = (int)off;
        }
        packed[pos++] = hiddenDimLength;
        for (int dim : hiddenDimensions) {
            packed[pos++] = dim;
        }
        packed[pos] = underlyingRank;
        return Nd4j.createBuffer(packed);
    }

    /**
     * Copies the given ints into a newly allocated direct, native-order
     * {@link IntBuffer}. The returned buffer's position is 0.
     *
     * @param arr the values to copy
     * @return a direct buffer of capacity {@code arr.length} holding the values
     */
    public static IntBuffer toBuffer(int ... arr) {
        final int byteCount = arr.length * 4;
        IntBuffer out = ByteBuffer.allocateDirect(byteCount).order(ByteOrder.nativeOrder()).asIntBuffer();
        int idx = 0;
        for (int value : arr) {
            // Absolute put: leaves the buffer position untouched.
            out.put(idx++, value);
        }
        return out;
    }

    /**
     * Joins every element of the buffer, from index 0 up to its capacity,
     * with commas. Elements are read with absolute gets, so the buffer's
     * position is neither consulted nor changed.
     *
     * @param buffer the buffer to render
     * @return the comma-separated values, or an empty string for a zero-capacity buffer
     */
    public static String toString(IntBuffer buffer) {
        StringBuilder text = new StringBuilder();
        final int last = buffer.capacity() - 1;
        int i = 0;
        while (i <= last) {
            text.append(buffer.get(i));
            if (i < last) {
                text.append(",");
            }
            i++;
        }
        return text.toString();
    }

    /**
     * Returns the buffer's own string representation.
     *
     * @param buffer the buffer to render
     * @return {@code buffer.toString()}
     */
    public static String toString(DataBuffer buffer) {
        return buffer.toString();
    }

    /**
     * Tells whether a dimension argument denotes "the whole array": it is
     * {@code null}, empty, or the single value {@link Integer#MAX_VALUE}.
     *
     * @param arr the dimension argument
     * @return {@code true} if {@code arr} denotes the whole array
     */
    public static boolean wholeArrayDimension(int ... arr) {
        if (arr == null || arr.length == 0) {
            return true;
        }
        if (arr.length != 1) {
            return false;
        }
        return arr[0] == Integer.MAX_VALUE;
    }

    /**
     * Removes duplicate values while keeping first-occurrence order.
     * <p>
     * Arrays with at most one element are returned as-is (the same instance);
     * longer arrays yield a new array.
     *
     * @param array the values to de-duplicate
     * @return the distinct values in first-occurrence order
     */
    public static int[] uniquify(int[] array) {
        final int count = array.length;
        if (count <= 1) {
            return array;
        }
        // A LinkedHashSet keeps insertion order, so the first occurrence of each value wins.
        LinkedHashSet<Integer> distinct = new LinkedHashSet<Integer>();
        for (int i = 0; i < count; i++) {
            distinct.add(array[i]);
        }
        return Ints.toArray(distinct);
    }

    /**
     * Normalizes an axis list for an array of the given rank.
     * <ul>
     *   <li>{@code null} or empty input yields {@code {Integer.MAX_VALUE}}.</li>
     *   <li>For rank 0 the only accepted inputs are {@code {0}} and
     *       {@code {Integer.MAX_VALUE}}; both yield {@code {Integer.MAX_VALUE}}.</li>
     *   <li>Otherwise negative axes are shifted by {@code rank}, the result is
     *       validated, sorted ascending and de-duplicated.
     *       {@code Integer.MAX_VALUE} passes validation unchanged.</li>
     * </ul>
     *
     * @param rank the array rank
     * @param axis the axes to normalize
     * @return the normalized axes
     * @throws ND4JIllegalStateException if an axis is out of range for {@code rank}
     */
    public static int[] normalizeAxis(int rank, int ... axis) {
        if (axis == null || axis.length == 0) {
            return new int[]{Integer.MAX_VALUE};
        }
        if (rank == 0) {
            boolean single = axis.length == 1;
            if (!single || (axis[0] != 0 && axis[0] != Integer.MAX_VALUE)) {
                throw new ND4JIllegalStateException("Array axis for scalar (rank 0) array invalid: rank " + Arrays.toString(axis));
            }
            return axis[0] == Integer.MAX_VALUE ? axis : new int[]{Integer.MAX_VALUE};
        }
        int[] resolved = new int[axis.length];
        for (int i = 0; i < axis.length; i++) {
            int raw = axis[i];
            int shifted = raw < 0 ? raw + rank : raw;
            boolean tooLarge = shifted >= rank && shifted != Integer.MAX_VALUE;
            if (tooLarge || shifted < 0) {
                throw new ND4JIllegalStateException("Axis array " + Arrays.toString(axis) + " contains values above array rank (rank=" + rank + ")");
            }
            resolved[i] = shifted;
        }
        if (resolved.length > 1) {
            Arrays.sort(resolved);
        }
        return Shape.uniquify(resolved);
    }

    /**
     * Compares {@code arr} against the leading {@code arr.length} entries of
     * {@code other}. Entries of {@code other} beyond that are not examined.
     *
     * @param arr   the expected values
     * @param other the buffer to compare with
     * @return {@code true} if every {@code arr[i]} equals {@code other.getInt(i)}
     */
    public static boolean contentEquals(int[] arr, DataBuffer other) {
        int i = 0;
        while (i < arr.length) {
            if (other.getInt((long)i) != arr[i]) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Compares {@code arr} against the leading {@code arr.length} entries of
     * {@code other}. Entries of {@code other} beyond that are not examined,
     * so a longer {@code other} with a matching prefix compares equal.
     *
     * @param arr   the expected values
     * @param other the array to compare with; must have at least {@code arr.length} entries
     * @return {@code true} if every {@code arr[i]} equals {@code other[i]}
     */
    public static boolean contentEquals(long[] arr, long[] other) {
        int i = 0;
        while (i < arr.length) {
            if (other[i] != arr[i]) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Compares {@code arr} against the leading {@code arr.length} entries of
     * {@code other}. Entries of {@code other} beyond that are not examined.
     *
     * @param arr   the expected values
     * @param other the buffer to compare with
     * @return {@code true} if every {@code arr[i]} equals {@code other.getLong(i)}
     */
    public static boolean contentEquals(long[] arr, DataBuffer other) {
        int i = 0;
        while (i < arr.length) {
            if (other.getLong((long)i) != arr[i]) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Compares {@code arr} against the leading {@code arr.length} entries of
     * {@code other}.
     * <p>
     * Side effect: each comparison repositions {@code other} and performs a
     * relative read, so the buffer's position ends one past the last index
     * that was compared.
     *
     * @param arr   the expected values
     * @param other the buffer to compare with
     * @return {@code true} if every {@code arr[i]} equals the buffer entry at index {@code i}
     */
    public static boolean contentEquals(int[] arr, IntBuffer other) {
        int i = 0;
        while (i < arr.length) {
            ((Buffer)other).position(i);
            if (arr[i] != other.get()) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Compares {@code arr} against the leading {@code arr.length} entries of
     * {@code other}, widening each buffer entry to {@code long}.
     * <p>
     * Side effect: each comparison repositions {@code other} and performs a
     * relative read, so the buffer's position ends one past the last index
     * that was compared.
     *
     * @param arr   the expected values
     * @param other the buffer to compare with
     * @return {@code true} if every {@code arr[i]} equals the buffer entry at index {@code i}
     */
    public static boolean contentEquals(long[] arr, IntBuffer other) {
        int i = 0;
        while (i < arr.length) {
            ((Buffer)other).position(i);
            if (arr[i] != (long)other.get()) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Tells whether an array is laid out contiguously in its data buffer.
     * <p>
     * An array whose length equals the length of its data buffer is
     * contiguous. Otherwise its strides are compared with the strides
     * computed from its shape for its ordering ({@code {1, 1}} is used for
     * ordering {@code 'a'}).
     *
     * @param in the array to inspect
     * @return {@code true} if the array is contiguous in its buffer
     * @throws RuntimeException if the ordering is not {@code 'c'}, {@code 'f'} or {@code 'a'}
     */
    public static boolean isContiguousInBuffer(INDArray in) {
        long arrayLength = in.length();
        long bufferLength = in.data().length();
        if (arrayLength == bufferLength) {
            return true;
        }
        char order = in.ordering();
        long[] shape = in.shape();
        long[] expected;
        switch (order) {
            case 'f':
                expected = ArrayUtil.calcStridesFortran((long[])shape);
                break;
            case 'c':
                expected = ArrayUtil.calcStrides((long[])shape);
                break;
            case 'a':
                expected = new long[]{1L, 1L};
                break;
            default:
                throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
        }
        return Arrays.equals(in.stride(), expected);
    }

    /**
     * Returns a matrix usable for matrix multiplication: the input itself when
     * its strides already match its ordering, otherwise a copy produced by
     * {@code toOffsetZeroCopyAnyOrder}.
     * <p>
     * For {@code 'c'} ordering the expected strides are
     * {@code [size(1), 1]}; for {@code 'f'} ordering {@code [1, size(0)]}.
     * Any other ordering is returned unchanged.
     *
     * @param input a rank-2 array
     * @return {@code input} or a copy of it
     * @throws IllegalArgumentException if {@code input} is not rank 2
     */
    public static INDArray toMmulCompatible(INDArray input) {
        if (input.rank() != 2) {
            throw new IllegalArgumentException("Input must be rank 2 (matrix)");
        }
        boolean needsCopy;
        if (input.ordering() == 'c') {
            needsCopy = (long)input.stride(0) != input.size(1) || input.stride(1) != 1;
        } else if (input.ordering() == 'f') {
            needsCopy = input.stride(0) != 1 || (long)input.stride(1) != input.size(0);
        } else {
            needsCopy = false;
        }
        return needsCopy ? Shape.toOffsetZeroCopyAnyOrder(input) : input;
    }

    /**
     * Returns the rank implied by a shape array, i.e. its length.
     *
     * @param shape the dimension sizes
     * @return {@code shape.length}
     * @throws ND4JIllegalStateException if {@code shape} is {@code null}
     */
    public static int rankFromShape(int[] shape) {
        if (shape != null) {
            return shape.length;
        }
        throw new ND4JIllegalStateException("Cannot get rank from null shape array");
    }

    /**
     * Returns the rank implied by a shape array, i.e. its length.
     *
     * @param shape the dimension sizes
     * @return {@code shape.length}
     * @throws ND4JIllegalStateException if {@code shape} is {@code null}
     */
    public static int rankFromShape(long[] shape) {
        if (shape != null) {
            return shape.length;
        }
        throw new ND4JIllegalStateException("Cannot get rank from null shape array");
    }

    /**
     * Asserts that the shapes of two arrays can be broadcast against each other.
     *
     * @param x the first array
     * @param y the second array
     * @throws NullPointerException if either argument is {@code null}
     */
    public static void assertBroadcastable(@NonNull INDArray x, @NonNull INDArray y) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        long[] xShape = x.shape();
        long[] yShape = y.shape();
        Shape.assertBroadcastable(xShape, yShape);
    }

    /**
     * Asserts that two shapes can be broadcast against each other.
     *
     * @param x the first shape
     * @param y the second shape
     * @throws NullPointerException      if either argument is {@code null}
     * @throws ND4JIllegalStateException if the shapes are not broadcastable
     */
    public static void assertBroadcastable(@NonNull int[] x, @NonNull int[] y) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        if (Shape.areShapesBroadcastable(x, y)) {
            return;
        }
        throw new ND4JIllegalStateException("Arrays are different shape and are not broadcastable. Array 1 shape = " + Arrays.toString(x) + ", array 2 shape = " + Arrays.toString(y));
    }

    /**
     * Asserts that two shapes can be broadcast against each other; delegates
     * to the three-argument overload with a {@code null} op class.
     *
     * @param x the first shape
     * @param y the second shape
     * @throws NullPointerException if either argument is {@code null}
     */
    public static void assertBroadcastable(@NonNull long[] x, @NonNull long[] y) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        Shape.assertBroadcastable(x, y, null);
    }

    /**
     * Asserts that two shapes can be broadcast against each other.
     *
     * @param x       the first shape
     * @param y       the second shape
     * @param opClass optional class whose name is appended to the error message; may be {@code null}
     * @throws NullPointerException      if {@code x} or {@code y} is {@code null}
     * @throws ND4JIllegalStateException if the shapes are not broadcastable
     */
    public static void assertBroadcastable(@NonNull long[] x, @NonNull long[] y, Class<?> opClass) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        if (Shape.areShapesBroadcastable(x, y)) {
            return;
        }
        String opSuffix = opClass == null ? "" : " - op: " + opClass.getName();
        throw new ND4JIllegalStateException("Arrays are different shape and are not broadcastable. Array 1 shape = " + Arrays.toString(x) + ", array 2 shape = " + Arrays.toString(y) + opSuffix);
    }

    /**
     * Tells whether two shapes can be broadcast against each other. The shapes
     * are aligned at their trailing dimensions; each aligned pair must be
     * equal or contain a 1. Leading dimensions of the longer shape are not
     * checked.
     *
     * @param x the first shape
     * @param y the second shape
     * @return {@code true} if every aligned pair of dimensions is compatible
     * @throws NullPointerException if either argument is {@code null}
     */
    public static boolean areShapesBroadcastable(@NonNull int[] x, @NonNull int[] y) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        final int shared = Math.min(x.length, y.length);
        // Walk from the last dimension backwards.
        for (int back = 1; back <= shared; back++) {
            int xDim = x[x.length - back];
            int yDim = y[y.length - back];
            boolean compatible = xDim == yDim || xDim == 1 || yDim == 1;
            if (!compatible) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether two shapes can be broadcast against each other. The shapes
     * are aligned at their trailing dimensions; each aligned pair must be
     * equal or contain a 1. Leading dimensions of the longer shape are not
     * checked.
     *
     * @param x the first shape
     * @param y the second shape
     * @return {@code true} if every aligned pair of dimensions is compatible
     * @throws NullPointerException if either argument is {@code null}
     */
    public static boolean areShapesBroadcastable(@NonNull long[] x, @NonNull long[] y) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (y == null) {
            throw new NullPointerException("y is marked @NonNull but is null");
        }
        final int shared = Math.min(x.length, y.length);
        // Walk from the last dimension backwards.
        for (int back = 1; back <= shared; back++) {
            long xDim = x[x.length - back];
            long yDim = y[y.length - back];
            boolean compatible = xDim == yDim || xDim == 1L || yDim == 1L;
            if (!compatible) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the number of elements described by a shape.
     *
     * @param shape the dimension sizes
     * @return 1 for an empty shape, otherwise the product of the entries
     */
    public static long lengthOf(long[] shape) {
        return shape.length == 0 ? 1L : ArrayUtil.prodLong((long[])shape);
    }

    /**
     * Computes the buffer length spanned by an array with the given shape and
     * strides: {@code 1 + sum((shape[i] - 1) * stride[i])}.
     *
     * @param shape  the dimension sizes
     * @param stride the strides, same length as {@code shape}
     * @return the computed span
     * @throws NullPointerException if either argument is {@code null}
     */
    public static long lengthOfBuffer(@NonNull long[] shape, @NonNull long[] stride) {
        if (shape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        if (stride == null) {
            throw new NullPointerException("stride is marked @NonNull but is null");
        }
        boolean sameRank = shape.length == stride.length;
        Preconditions.checkArgument(sameRank, (String)"Shape and strides must be same length: shape %s, stride %s", (Object)shape, (Object)stride);
        long span = 1L;
        int d = 0;
        while (d < shape.length) {
            long lastIndex = shape[d] - 1L;
            span += lastIndex * stride[d];
            d++;
        }
        return span;
    }

    /**
     * Computes the buffer length spanned by an array with the given shape and
     * strides: {@code 1 + sum((shape[i] - 1) * stride[i])}.
     *
     * @param shape  the dimension sizes
     * @param stride the strides, same length as {@code shape}
     * @return the computed span as a {@code long}
     * @throws NullPointerException if either argument is {@code null}
     */
    public static long lengthOfBuffer(@NonNull int[] shape, @NonNull int[] stride) {
        if (shape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        if (stride == null) {
            throw new NullPointerException("stride is marked @NonNull but is null");
        }
        Preconditions.checkArgument((shape.length == stride.length ? 1 : 0) != 0, (String)"Shape and strides must be same length: shape %s, stride %s", (Object)shape, (Object)stride);
        long length = 1L;
        for (int i = 0; i < shape.length; ++i) {
            // Widen before multiplying: the product used to be computed in
            // int and could overflow before being added to the long total.
            length += (long)(shape[i] - 1) * (long)stride[i];
        }
        return length;
    }

    /**
     * Tells whether an array's strides equal the strides computed from its
     * shape for its ordering (Fortran strides for {@code 'f'}, C strides for
     * anything else). Rank-0 arrays always qualify.
     *
     * @param input the array to inspect
     * @return {@code true} if the strides match the computed defaults
     */
    public static boolean hasDefaultStridesForShape(INDArray input) {
        if (input.rank() == 0) {
            return true;
        }
        if (!Shape.strideDescendingCAscendingF(input)) {
            return false;
        }
        long[] expected;
        if (input.ordering() == 'f') {
            expected = ArrayUtil.calcStridesFortran((long[])input.shape());
        } else {
            expected = ArrayUtil.calcStrides((long[])input.shape());
        }
        return Arrays.equals(input.stride(), expected);
    }

    /**
     * Tells whether a data type is the string type.
     *
     * @param x the data type to test
     * @return {@code true} if {@code x} is {@code DataType.UTF8}
     * @throws NullPointerException if {@code x} is {@code null}
     */
    public static boolean isS(@NonNull DataType x) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        return x == DataType.UTF8;
    }

    /**
     * Tells whether a data type is the boolean type.
     *
     * @param x the data type to test
     * @return {@code true} if {@code x} is {@code DataType.BOOL}
     * @throws NullPointerException if {@code x} is {@code null}
     */
    public static boolean isB(@NonNull DataType x) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        return x == DataType.BOOL;
    }

    /**
     * Tells whether a data type is none of the types accepted by
     * {@code isR}, {@code isS} or {@code isB}.
     *
     * @param x the data type to test
     * @return {@code true} if all three of those checks reject {@code x}
     * @throws NullPointerException if {@code x} is {@code null}
     */
    public static boolean isZ(@NonNull DataType x) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        boolean otherCategory = Shape.isR(x) || Shape.isS(x) || Shape.isB(x);
        return !otherCategory;
    }

    /**
     * Tells whether a data type is one of {@code FLOAT}, {@code HALF} or
     * {@code DOUBLE}.
     *
     * @param x the data type to test
     * @return {@code true} for those three types
     * @throws NullPointerException if {@code x} is {@code null}
     */
    public static boolean isR(@NonNull DataType x) {
        if (x == null) {
            throw new NullPointerException("x is marked @NonNull but is null");
        }
        if (x == DataType.FLOAT) {
            return true;
        }
        if (x == DataType.HALF) {
            return true;
        }
        return x == DataType.DOUBLE;
    }

    /**
     * Returns whichever of the two data types has the larger enum ordinal.
     *
     * @param typeX the first data type
     * @param typeY the second data type
     * @return the argument with the greater ordinal
     * @throws NullPointerException if either argument is {@code null}
     */
    private static DataType max(@NonNull DataType typeX, @NonNull DataType typeY) {
        if (typeX == null) {
            throw new NullPointerException("typeX is marked @NonNull but is null");
        }
        if (typeY == null) {
            throw new NullPointerException("typeY is marked @NonNull but is null");
        }
        // Equal ordinals mean the same constant, so either operand may be returned.
        return typeX.ordinal() >= typeY.ordinal() ? typeX : typeY;
    }

    /**
     * Picks the result data type for a pairwise operation between an array of
     * type {@code typeX} and a boxed Java number.
     * <p>
     * Outside experimental mode {@code typeX} is returned unchanged.
     * Otherwise the number's boxed class is mapped to a data type and the
     * decision is delegated to the two-data-type overload.
     *
     * @param typeX  the data type of the array operand
     * @param number the scalar operand
     * @return the data type to use for the result
     * @throws NullPointerException          if either argument is {@code null}
     * @throws UnsupportedOperationException if {@code number} is not a Double,
     *                                       Float, Long, Integer, Short or Byte
     */
    public static DataType pickPairwiseDataType(@NonNull DataType typeX, @NonNull Number number) {
        if (typeX == null) {
            throw new NullPointerException("typeX is marked @NonNull but is null");
        }
        if (number == null) {
            throw new NullPointerException("number is marked @NonNull but is null");
        }
        if (!Nd4j.isExperimentalMode()) {
            return typeX;
        }
        final DataType typeY;
        if (number instanceof Double) {
            typeY = DataType.DOUBLE;
        } else if (number instanceof Float) {
            typeY = DataType.FLOAT;
        } else if (number instanceof Long) {
            typeY = DataType.LONG;
        } else if (number instanceof Integer) {
            typeY = DataType.INT;
        } else if (number instanceof Short) {
            typeY = DataType.SHORT;
        } else if (number instanceof Byte) {
            typeY = DataType.BYTE;
        } else {
            throw new UnsupportedOperationException("Unknown Number used: [" + number.getClass().getCanonicalName() + "]");
        }
        return Shape.pickPairwiseDataType(typeX, typeY);
    }

    /**
     * Picks the result data type for a pairwise operation between two data types.
     *
     * <p>Rules, in order:
     * <ol>
     *   <li>If {@code Nd4j.isExperimentalMode()} is off, or both types are the same,
     *       {@code typeX} is returned.</li>
     *   <li>If exactly one of the types satisfies {@link #isR(DataType)}, that type is
     *       returned.</li>
     *   <li>Otherwise (both or neither satisfy {@code isR}): when
     *       {@code Nd4j.isPrecisionBoostAllowed()} is on, the type with the higher
     *       ordinal is returned; when it is off, {@code typeX} is returned.</li>
     * </ol>
     *
     * @param typeX data type of the first operand; must not be null
     * @param typeY data type of the second operand; must not be null
     * @return the selected data type
     * @throws NullPointerException if either argument is null
     */
    public static DataType pickPairwiseDataType(@NonNull DataType typeX, @NonNull DataType typeY) {
        if (typeX == null) {
            throw new NullPointerException("typeX is marked @NonNull but is null");
        }
        if (typeY == null) {
            throw new NullPointerException("typeY is marked @NonNull but is null");
        }
        if (!Nd4j.isExperimentalMode() || typeX == typeY) {
            return typeX;
        }

        final boolean realX = Shape.isR(typeX);
        final boolean realY = Shape.isR(typeY);

        // Exactly one operand is accepted by isR: that operand's type wins.
        if (realX != realY) {
            return realX ? typeX : typeY;
        }

        // Both operands are in the same category: only widen when precision boost is enabled.
        if (Nd4j.isPrecisionBoostAllowed()) {
            return Shape.max(typeX, typeY);
        }
        return typeX;
    }

    /**
     * Tells whether the array type decoded from the given shape-info buffer by
     * {@code ArrayOptionsHelper.arrayType} is {@link ArrayType#EMPTY}.
     *
     * @param shapeInfo shape information buffer to inspect
     * @return {@code true} if the decoded array type is {@code ArrayType.EMPTY}
     */
    public static boolean isEmpty(long[] shapeInfo) {
        final ArrayType decodedType = ArrayOptionsHelper.arrayType(shapeInfo);
        return ArrayType.EMPTY == decodedType;
    }

    /**
     * Validates an array ordering character.
     *
     * <p>Only the lower-case characters {@code 'c'}, {@code 'f'} and {@code 'a'} are
     * accepted; any other character is rejected.
     *
     * @param order ordering character to validate
     * @throws IllegalArgumentException if {@code order} is not {@code 'c'}, {@code 'f'} or {@code 'a'}
     */
    public static void assertValidOrder(char order) {
        switch (order) {
            case 'c':
            case 'f':
            case 'a':
                return;
            default:
                throw new IllegalArgumentException("Invalid order arg: must be 'c' or 'f' (or 'a' for vectors), got '" + order + "'");
        }
    }

    /**
     * Wraps a list of dimension indices in an array.
     *
     * @param dimensions dimension indices; may be null or empty
     * @return {@code Nd4j.empty(DataType.INT)} when {@code dimensions} is null or has no
     *         elements, otherwise the result of {@code Nd4j.createFromArray(dimensions)}
     */
    public static INDArray ndArrayDimFromInt(int ... dimensions) {
        if (dimensions != null && dimensions.length > 0) {
            return Nd4j.createFromArray(dimensions);
        }
        // Nothing to wrap: hand back an empty INT array instead.
        return Nd4j.empty(DataType.INT);
    }

    /**
     * Calculates the shape produced by reducing {@code x} along {@code dimension}.
     *
     * <p>The reduction is treated as covering the whole array when
     * {@code Shape.wholeArrayDimension(dimension)} is true or when the number of
     * dimensions given equals the rank of {@code x}.
     *
     * <ul>
     *   <li><b>Legacy format</b> ({@code newFormat == false}): a whole-array reduction
     *       yields {@code [1, 1]}; otherwise the reduced dimensions are removed from the
     *       shape of {@code x}. A resulting length-1 shape {@code [n]} is expanded to
     *       {@code [1, n]} if the first reduced dimension is 0, or {@code [n, 1]}
     *       otherwise; a resulting length-0 shape becomes {@code [1, 1]}.
     *       {@code keepDims} is ignored in this mode.</li>
     *   <li><b>New format with {@code keepDims}</b>: a copy of the shape of {@code x} with
     *       every reduced dimension (or every dimension, for a whole-array reduction)
     *       set to 1.</li>
     *   <li><b>New format without {@code keepDims}</b>: an empty shape for a whole-array
     *       reduction, otherwise the shape of {@code x} with the reduced dimensions
     *       removed.</li>
     * </ul>
     *
     * @param x         array being reduced
     * @param dimension dimensions to reduce along
     * @param newFormat {@code false} to produce the legacy shape (always at least length 2 here)
     * @param keepDims  in the new format, whether reduced dimensions are kept with size 1
     * @return the shape of the reduction result
     */
    public static long[] reductionShape(INDArray x, int[] dimension, boolean newFormat, boolean keepDims) {
        final boolean wholeArray = Shape.wholeArrayDimension(dimension) || dimension.length == x.rank();

        long[] retShape;
        if (!newFormat) {
            // Both branches must assign retShape: previously the whole-array branch only
            // populated a temporary, leaving retShape unassigned before it was read below.
            if (wholeArray) {
                retShape = new long[]{1L, 1L};
            } else {
                retShape = ArrayUtil.removeIndex(x.shape(), dimension);
            }
            // Pad the result so the legacy shape never has fewer than two entries.
            if (retShape.length == 1) {
                retShape = dimension[0] == 0 ? new long[]{1L, retShape[0]} : new long[]{retShape[0], 1L};
            } else if (retShape.length == 0) {
                retShape = new long[]{1L, 1L};
            }
        } else if (keepDims) {
            // Work on a copy so the shape array returned by x is never modified.
            retShape = x.shape().clone();
            if (wholeArray) {
                Arrays.fill(retShape, 1L);
            } else {
                for (int d : dimension) {
                    retShape[d] = 1L;
                }
            }
        } else {
            retShape = wholeArray ? new long[]{} : ArrayUtil.removeIndex(x.shape(), dimension);
        }
        return retShape;
    }
}

