/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.shape;

import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.nd4j.bytebuddy.shape.IndexMapper;
import org.nd4j.bytebuddy.shape.ShapeMapper;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.ShapeOffsetResolution;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers for reasoning about ndarray shapes, strides, orderings and
 * linear-index / subscript conversions.
 *
 * <p>Several routines here ({@link #createConcatStrides}, {@link #newShapeNoCopy},
 * {@link #getOrder}) follow the structure of NumPy's C implementations of the same ideas.
 */
public class Shape {
    /** Cache of Fortran-order ({@code 'f'}) index mappers, keyed by array rank. */
    private static final Map<Integer, IndexMapper> indexMappers =
            Collections.synchronizedMap(new HashMap<Integer, IndexMapper>());
    /** Cache of C-order ({@code 'c'}) index mappers, keyed by array rank. */
    private static final Map<Integer, IndexMapper> indexMappersC =
            Collections.synchronizedMap(new HashMap<Integer, IndexMapper>());

    static {
        // Pre-populate both caches for the common ranks 0..9.
        for (int rank = 0; rank < 10; ++rank) {
            indexMappersC.put(rank, ShapeMapper.getInd2SubInstance('c', rank));
            indexMappers.put(rank, ShapeMapper.getInd2SubInstance('f', rank));
        }
    }

    /**
     * Returns {@code arr} unchanged when it already starts at offset 0 and its buffer holds
     * exactly its elements (or, for complex arrays, exactly twice as many values as elements);
     * otherwise returns a freshly allocated copy with the same shape.
     *
     * @param arr the array to normalize
     * @return {@code arr} itself, or a new array holding the same values
     */
    public static INDArray toOffsetZero(INDArray arr) {
        if (arr.offset() < 1 && arr.data().length() == arr.length()
                || arr instanceof IComplexNDArray && arr.length() * 2 == arr.data().length()) {
            return arr;
        }
        if (arr.isRowVector()) {
            return copyRowVector(arr);
        }
        if (arr instanceof IComplexNDArray) {
            return copyComplexBySlices(arr);
        }
        INDArray ret = Nd4j.create(arr.shape());
        for (int i = 0; i < ret.slices(); ++i) {
            ret.putSlice(i, arr.slice(i));
        }
        return ret;
    }

    /**
     * Always returns a freshly allocated copy of {@code arr} with the same shape.
     *
     * @param arr the array to copy
     * @return a new array holding the same values
     */
    public static INDArray toOffsetZeroCopy(INDArray arr) {
        if (arr.isRowVector()) {
            return copyRowVector(arr);
        }
        if (arr instanceof IComplexNDArray) {
            return copyComplexBySlices(arr);
        }
        INDArray ret = Nd4j.create(arr.shape());
        for (int i = 0; i < arr.vectorsAlongDimension(0); ++i) {
            ret.vectorAlongDimension(i, 0).assign(arr.vectorAlongDimension(i, 0));
        }
        return ret;
    }

    /** Copies a row vector element by element into a new (real or complex) array of the same shape. */
    private static INDArray copyRowVector(INDArray arr) {
        if (arr instanceof IComplexNDArray) {
            IComplexNDArray src = (IComplexNDArray) arr;
            IComplexNDArray ret = Nd4j.createComplex(arr.shape());
            for (int i = 0; i < ret.length(); ++i) {
                ret.putScalar(i, src.getComplex(i));
            }
            return ret;
        }
        INDArray ret = Nd4j.create(arr.shape());
        for (int i = 0; i < ret.length(); ++i) {
            ret.putScalar(i, arr.getDouble(i));
        }
        return ret;
    }

    /** Copies a complex array slice by slice into a new complex array of the same shape. */
    private static INDArray copyComplexBySlices(INDArray arr) {
        IComplexNDArray ret = Nd4j.createComplex(arr.shape());
        for (int i = 0; i < ret.slices(); ++i) {
            ret.putSlice(i, arr.slice(i));
        }
        return ret;
    }

    /**
     * Looks up the size of each axis listed in {@code axes}.
     *
     * <p>NOTE(review): the result has length {@code shape.length}, not {@code axes.length};
     * when fewer axes than dimensions are given the tail is zero-padded. Preserved for
     * compatibility — confirm whether callers rely on it.
     *
     * @param axes axis indices into {@code shape}
     * @param shape the full shape
     * @return {@code ret[i] = shape[axes[i]]} for {@code i < axes.length}, zeros afterwards
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        int[] ret = new int[shape.length];
        for (int i = 0; i < axes.length; ++i) {
            ret[i] = shape[axes[i]];
        }
        return ret;
    }

    /**
     * Returns true for rank-1 shapes, and for rank-2 shapes where one dimension holds every
     * element (i.e. the other dimension is 1).
     */
    public static boolean isVector(int[] shape) {
        if (shape.length > 2 || shape.length < 1) {
            return false;
        }
        int len = ArrayUtil.prod(shape);
        // For rank 1, shape[0] == len always holds, so shape[1] is never read.
        return shape[0] == len || shape[1] == len;
    }

    /** Returns true for rank-2 shapes that are not vectors. */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !isVector(shape);
    }

    /**
     * Removes every size-1 dimension from {@code shape}. Column-vector shapes
     * (rank 2 with a trailing 1) are returned unchanged.
     */
    public static int[] squeeze(int[] shape) {
        if (isColumnVectorShape(shape)) {
            return shape;
        }
        ArrayList<Integer> ret = new ArrayList<Integer>();
        for (int dim : shape) {
            if (dim != 1) {
                ret.add(dim);
            }
        }
        return ArrayUtil.toArray(ret);
    }

    /**
     * Returns the indices {@code i} for which {@code shape[i] != 1}. The values in
     * {@code dimensions} are not read; only its length is checked.
     *
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static int[] nonOneDimensions(int[] dimensions, int[] shape) {
        if (dimensions.length != shape.length) {
            throw new IllegalArgumentException("Dimensions and shape must be the same length");
        }
        ArrayList<Integer> list = new ArrayList<Integer>();
        for (int i = 0; i < dimensions.length; ++i) {
            if (shape[i] != 1) {
                list.add(i);
            }
        }
        return Ints.toArray(list);
    }

    /**
     * For non-vector shapes, returns the shape with every size-1 dimension removed
     * (not only leading/trailing ones, despite the name); vector shapes are returned unchanged.
     */
    public static int[] leadingAndTrailingOnes(int[] original) {
        if (isVector(original)) {
            return original;
        }
        ArrayList<Integer> ints = new ArrayList<Integer>();
        for (int dim : original) {
            if (dim != 1) {
                ints.add(dim);
            }
        }
        return Ints.toArray(ints);
    }

    /**
     * Shape equality that ignores size-1 dimensions, except that two column-vector shapes
     * must match exactly.
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        if (isColumnVectorShape(shape1) && isColumnVectorShape(shape2)) {
            return Arrays.equals(shape1, shape2);
        }
        if (isRowVectorShape(shape1) && isRowVectorShape(shape2)) {
            return Arrays.equals(squeeze(shape1), squeeze(shape2));
        }
        return squeezeEquals(shape1, shape2);
    }

    /** Returns true when exactly one shape is empty and the other is {@code [1]}. */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        if (shape1.length == 0) {
            return shape2.length == 1 && shape2[0] == 1;
        }
        return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1;
    }

    /** Returns true for rank-1 shapes and for rank-2 shapes whose first dimension is 1. */
    public static boolean isRowVectorShape(int[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    /** Returns true for rank-2 shapes whose second dimension is 1. */
    public static boolean isColumnVectorShape(int[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    /**
     * Computes an axis permutation ordering axes by descending absolute stride, agreed on by
     * all input arrays (a port of NumPy's {@code PyArray_CreateMultiSortedStridePerm}).
     * Axes of size 1 in any array do not influence that array's vote, and ambiguous pairs
     * keep their original relative order.
     *
     * @param arrays arrays that must all share the same rank
     * @return a permutation of {@code 0 .. rank-1}
     * @throws IllegalArgumentException if the arrays do not all have the same rank
     */
    public static int[] createConcatStrides(INDArray... arrays) {
        int rank = arrays[0].rank();
        for (INDArray arr : arrays) {
            if (arr.rank() != rank) {
                throw new IllegalArgumentException("All arrays must have same rank");
            }
        }
        int[] ret = new int[rank];
        for (int i = 0; i < rank; ++i) {
            ret[i] = i;
        }
        // Stable insertion sort over the axis permutation.
        for (int i0 = 1; i0 < rank; ++i0) {
            int ipos = i0;
            int axJ0 = ret[i0];
            for (int i1 = i0 - 1; i1 >= 0; --i1) {
                boolean ambig = true;
                boolean shouldSwap = false;
                int axJ1 = ret[i1];
                for (INDArray arr : arrays) {
                    // Strides of size-1 axes are meaningless; such arrays don't vote on this pair.
                    if (arr.size(axJ0) == 1 || arr.size(axJ1) == 1) {
                        continue;
                    }
                    // Compare stride against stride (previously compared against size(axJ1)).
                    if (Math.abs(arr.stride(axJ0)) <= Math.abs(arr.stride(axJ1))) {
                        shouldSwap = false;
                    } else if (ambig) {
                        shouldSwap = true;
                    }
                    ambig = false;
                }
                if (ambig) {
                    continue;
                }
                if (!shouldSwap) {
                    break;
                }
                ipos = i1;
            }
            if (ipos != i0) {
                // Shift ret[ipos .. i0-1] right by one and drop axJ0 into the gap.
                System.arraycopy(ret, ipos, ret, ipos + 1, i0 - ipos);
                ret[ipos] = axJ0;
            }
        }
        return ret;
    }

    /**
     * Attempts to reshape {@code arr} to {@code newShape} as a view over the same buffer
     * (a port of NumPy's {@code _attempt_nocopy_reshape}).
     *
     * @param arr the source array
     * @param newShape the requested shape
     * @param isFOrder whether strides are laid out Fortran-style ({@code true}) or C-style
     * @return a view with the new shape, or {@code null} if the element counts differ, the
     *         array is empty, or the existing strides cannot be reinterpreted without a copy
     */
    public static INDArray newShapeNoCopy(INDArray arr, int[] newShape, boolean isFOrder) {
        int[] olddims = ArrayUtil.copy(arr.shape());
        int[] oldstrides = ArrayUtil.copy(arr.stride());
        int[] newStrides = new int[newShape.length];

        // Compact the old dims/strides in place, dropping size-1 axes.
        int oldnd = 0;
        for (int i = 0; i < arr.rank(); ++i) {
            if (arr.size(i) == 1) {
                continue;
            }
            olddims[oldnd] = arr.size(i);
            oldstrides[oldnd] = arr.stride(i);
            ++oldnd;
        }

        int np = 1;
        for (int dim : newShape) {
            np *= dim;
        }
        int op = 1;
        for (int i = 0; i < oldnd; ++i) {
            op *= olddims[i];
        }
        if (np != op || np == 0) {
            return null;
        }

        int oi = 0;
        int oj = 1;
        int ni = 0;
        int nj = 1;
        while (ni < newShape.length && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Grow whichever group is smaller until old axes [oi, oj) and new axes [ni, nj)
            // span the same number of elements.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                } else {
                    op *= olddims[oj++];
                }
            }
            // The grouped old axes must be mutually contiguous to be merged or split.
            for (int ok = oi; ok < oj - 1; ++ok) {
                boolean notContiguous = isFOrder
                        ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok]
                        : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1];
                if (notContiguous) {
                    return null;
                }
            }
            // Derive strides for the new axes in this group.
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (int nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (int nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }

        // Any remaining new axes (trailing size-1 dims) get a stride derived from the last one.
        int lastStride = ni >= 1 ? newStrides[ni - 1] : arr.elementStride();
        if (isFOrder && ni >= 1) {
            lastStride *= newShape[ni - 1];
        }
        for (int nk = ni; nk < newShape.length; ++nk) {
            newStrides[nk] = lastStride;
        }
        if (arr instanceof IComplexNDArray) {
            return Nd4j.createComplex(arr.data(), newShape, newStrides, arr.offset());
        }
        return Nd4j.create(arr.data(), newShape, newStrides, arr.offset());
    }

    /**
     * Infers the memory ordering from shape and strides.
     *
     * <p>NOTE(review): the C-contiguity check seeds the running stride with 1, while the
     * Fortran check seeds it with {@code elementStride}; the results differ only when
     * {@code elementStride != 1}. Preserved as-is — confirm intent.
     *
     * @return {@code 'a'} if both C- and Fortran-contiguous, {@code 'f'} if only
     *         Fortran-contiguous, otherwise {@code 'c'}
     */
    public static char getOrder(int[] shape, int[] stride, int elementStride) {
        boolean cContiguous = true;
        int sd = 1;
        for (int i = shape.length - 1; i >= 0; --i) {
            if (stride[i] != sd) {
                cContiguous = false;
                break;
            }
            if (shape[i] == 0) {
                break;
            }
            sd *= shape[i];
        }

        boolean isFortran = true;
        sd = elementStride;
        for (int i = 0; i < shape.length; ++i) {
            if (stride[i] != sd) {
                isFortran = false;
                break;
            }
            if (shape[i] == 0) {
                break;
            }
            sd *= shape[i];
        }

        if (isFortran) {
            return cContiguous ? 'a' : 'f';
        }
        return 'c';
    }

    /** Infers the memory ordering of {@code arr}; see {@link #getOrder(int[], int[], int)}. */
    public static char getOrder(INDArray arr) {
        return getOrder(arr.shape(), arr.stride(), arr.elementStride());
    }

    /**
     * Converts subscripts to a linear index in column-major order (the first index varies
     * fastest).
     */
    public static int sub2Ind(int[] shape, int[] indices) {
        int index = 0;
        int shift = 1;
        for (int i = 0; i < shape.length; ++i) {
            index += shift * indices[i];
            shift *= shape[i];
        }
        return index;
    }

    /**
     * Returns the cached {@link IndexMapper} for the given ordering and rank, creating and
     * caching it (keyed by rank) on a miss. Concurrent misses may each create a mapper; the
     * last one stored wins.
     *
     * @param order {@code 'c'} selects the C-order cache, anything else the Fortran cache
     * @param rank the array rank
     */
    private static IndexMapper mapperFor(char order, int rank) {
        Map<Integer, IndexMapper> cache = order == 'c' ? indexMappersC : indexMappers;
        IndexMapper mapper = cache.get(rank);
        if (mapper == null) {
            mapper = ShapeMapper.getInd2SubInstance(order, rank);
            cache.put(rank, mapper);
        }
        return mapper;
    }

    /**
     * Converts a linear index to subscripts using Fortran ({@code 'f'}) ordering.
     *
     * @param shape the array shape
     * @param index the linear index
     * @param numIndices the total number of elements (normally the product of {@code shape})
     */
    public static int[] ind2sub(int[] shape, int index, int numIndices) {
        return mapperFor('f', shape.length).ind2sub(shape, index, numIndices, 'f');
    }

    /** Fortran-order {@link #ind2sub(int[], int, int)} with {@code numIndices = prod(shape)}. */
    public static int[] ind2sub(int[] shape, int index) {
        return ind2sub(shape, index, ArrayUtil.prod(shape));
    }

    /** Fortran-order {@link #ind2sub(int[], int, int)} over {@code arr.shape()}. */
    public static int[] ind2sub(INDArray arr, int index) {
        return ind2sub(arr.shape(), index, ArrayUtil.prod(arr.shape()));
    }

    /**
     * Converts a linear index to subscripts using C ({@code 'c'}) ordering.
     *
     * @param shape the array shape
     * @param index the linear index
     * @param numIndices the total number of elements (normally the product of {@code shape})
     */
    public static int[] ind2subC(int[] shape, int index, int numIndices) {
        return mapperFor('c', shape.length).ind2sub(shape, index, numIndices, 'c');
    }

    /** C-order {@link #ind2subC(int[], int, int)} with {@code numIndices = prod(shape)}. */
    public static int[] ind2subC(int[] shape, int index) {
        return ind2subC(shape, index, ArrayUtil.prod(shape));
    }

    /** C-order {@link #ind2subC(int[], int, int)} over {@code arr.shape()}. */
    public static int[] ind2subC(INDArray arr, int index) {
        return ind2subC(arr.shape(), index, ArrayUtil.prod(arr.shape()));
    }

    /** Resolves the buffer offset of the element at {@code indexes} via {@link ShapeOffsetResolution}. */
    public static int offsetFor(INDArray arr, int[] indexes) {
        ShapeOffsetResolution resolution = new ShapeOffsetResolution(arr);
        resolution.exec(toIndexes(indexes));
        return resolution.getOffset();
    }

    /**
     * Resolves the buffer offset of the element at linear {@code index}, interpreting the
     * index in {@code arr}'s own ordering ({@code 'c'} → C order, otherwise Fortran order).
     */
    public static int offsetFor(INDArray arr, int index) {
        int[] indexes = arr.ordering() == 'c' ? ind2subC(arr, index) : ind2sub(arr, index);
        return offsetFor(arr, indexes);
    }

    /**
     * Asserts that {@code shape[i] < lessThan[i]} for every {@code i}.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     * @throws IllegalStateException at the first dimension that is not strictly smaller
     */
    public static void assertShapeLessThan(int[] shape, int[] lessThan) {
        if (shape.length != lessThan.length) {
            throw new IllegalArgumentException("Shape length must be == less than length");
        }
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] >= lessThan[i]) {
                throw new IllegalStateException("Shape[" + i + "] should be less than lessThan[" + i + "]");
            }
        }
    }

    /**
     * Returns a permutation of dimension indices: indices of non-1 dimensions first, then
     * indices of size-1 dimensions, each group in ascending order.
     */
    public static int[] moveOnesToEnd(int[] shape) {
        ArrayList<Integer> nonOnes = new ArrayList<Integer>();
        ArrayList<Integer> ones = new ArrayList<Integer>();
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] == 1) {
                ones.add(i);
            } else {
                nonOnes.add(i);
            }
        }
        return Ints.concat(Ints.toArray(nonOnes), Ints.toArray(ones));
    }

    /** Wraps each integer position in an {@link NDArrayIndex}. */
    public static INDArrayIndex[] toIndexes(int[] indices) {
        INDArrayIndex[] ret = new INDArrayIndex[indices.length];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = new NDArrayIndex(indices[i]);
        }
        return ret;
    }

    /**
     * Drops the first stride when {@code strides} is longer than {@code newLength};
     * otherwise returns {@code strides} unchanged. {@code indexes} is currently unused.
     */
    public static int[] newStrides(int[] strides, int newLength, INDArrayIndex[] indexes) {
        if (strides.length > newLength) {
            return Arrays.copyOfRange(strides, 1, strides.length);
        }
        return strides;
    }

    /**
     * Drops the first offset when {@code offsets} is longer than {@code newLength};
     * otherwise returns {@code offsets} unchanged. {@code indexes} is currently unused.
     */
    public static int[] newOffsets(int[] offsets, int newLength, INDArrayIndex[] indexes) {
        if (offsets.length > newLength) {
            return Arrays.copyOfRange(offsets, 1, offsets.length);
        }
        return offsets;
    }

    /**
     * Removes zero entries among the first {@code shape.length} offsets, then right-pads the
     * result with zeros back to {@code shape.length}.
     *
     * @throws IllegalStateException if the remaining offsets are still longer than {@code shape}
     */
    public static int[] squeezeOffsets(int[] shape, int[] offsets) {
        ArrayList<Integer> squeezeIndices = new ArrayList<Integer>();
        for (int i = 0; i < shape.length; ++i) {
            if (offsets[i] == 0) {
                squeezeIndices.add(i);
            }
        }
        int[] ret = ArrayUtil.removeIndex(offsets, Ints.toArray(squeezeIndices));
        if (ret.length == shape.length) {
            return ret;
        }
        if (ret.length > shape.length) {
            throw new IllegalStateException("Unable to squeeze offsets");
        }
        int[] retOffsets = new int[shape.length];
        System.arraycopy(ret, 0, retOffsets, 0, ret.length);
        return retOffsets;
    }

    /**
     * Returns true when the squeezed shapes are equal, or when one squeezes to an empty
     * shape and the other to {@code [1]}.
     */
    public static boolean squeezeEquals(int[] test1, int[] test2) {
        int[] s1 = squeeze(test1);
        int[] s2 = squeeze(test2);
        return scalarEquals(s1, s2) || Arrays.equals(s1, s2);
    }
}

