/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers for inspecting, squeezing and comparing shapes expressed as
 * {@code int[]} arrays of dimension sizes.
 */
public class Shape {

    /**
     * Removes every dimension of size 1 from {@code shape}.
     *
     * <p>Returns exactly what {@link #squeeze(int[])} returns for {@code shape}.
     * NOTE(review): {@code stride} is never read; confirm whether callers expect a
     * correspondingly squeezed stride.
     *
     * @param shape  the shape to squeeze
     * @param stride ignored
     * @return a new array holding the non-1 dimensions of {@code shape}, in order
     */
    public static int[] squeeze(int[] shape, int[] stride) {
        return squeeze(shape);
    }

    /**
     * Looks up the sizes of the selected dimensions.
     *
     * @param axes  indices into {@code shape}
     * @param shape the shape to read sizes from
     * @return an array {@code ret} with {@code ret[i] == shape[axes[i]]}
     * @throws ArrayIndexOutOfBoundsException if an axis is not a valid index of {@code shape}
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        int[] ret = new int[axes.length];
        for (int i = 0; i < axes.length; i++) {
            ret[i] = shape[axes[i]];
        }
        return ret;
    }

    /**
     * Tells whether {@code shape} describes a vector: a 1-D shape, or a 2-D shape in which one
     * dimension alone accounts for the total element count (e.g. {@code {1, n}} or {@code {n, 1}}).
     *
     * @param shape the shape to test
     * @return {@code true} for 1-D shapes and qualifying 2-D shapes; {@code false} for empty
     *         shapes and shapes with more than two dimensions
     */
    public static boolean isVector(int[] shape) {
        // Empty shapes used to throw ArrayIndexOutOfBoundsException on shape[0].
        if (shape.length == 0 || shape.length > 2) {
            return false;
        }
        if (shape.length == 1) {
            // The single dimension is always equal to the element count.
            return true;
        }
        int length = shape[0] * shape[1];
        return shape[0] == length || shape[1] == length;
    }

    /**
     * Tells whether {@code shape} is two-dimensional and not a vector shape.
     *
     * @param shape the shape to test
     * @return {@code true} if {@code shape} has exactly two dimensions and
     *         {@link #isVector(int[])} is {@code false}
     */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !isVector(shape);
    }

    /**
     * Removes every dimension of size 1 from {@code shape}.
     *
     * @param shape the shape to squeeze
     * @return a new array holding the non-1 dimensions of {@code shape}, in their original order
     *         (empty if every dimension is 1)
     */
    public static int[] squeeze(int[] shape) {
        int[] kept = new int[shape.length];
        int count = 0;
        for (int dim : shape) {
            if (dim != 1) {
                kept[count++] = dim;
            }
        }
        return Arrays.copyOf(kept, count);
    }

    /**
     * Returns {@code shape[1]} when the shape has more than one dimension and its first
     * dimension is 1; otherwise returns {@code shape[0]}.
     *
     * @param shape the shape to inspect
     * @return the selected dimension size
     * @throws ArrayIndexOutOfBoundsException if {@code shape} is empty
     */
    public static int nonZeroDimension(int[] shape) {
        if (shape.length > 1 && shape[0] == 1) {
            return shape[1];
        }
        return shape[0];
    }

    /**
     * Compares two shapes.
     *
     * <ul>
     *   <li>Two column-vector shapes must be identical.</li>
     *   <li>Two row-vector shapes are compared after squeezing, so {@code {1, n}} equals
     *       {@code {n}}.</li>
     *   <li>Otherwise the shapes are equal when both are scalar shapes (see
     *       {@link #scalarEquals(int[], int[])}) or when the arrays are identical.</li>
     * </ul>
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return whether the shapes are considered equal
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        if (isColumnVectorShape(shape1) && isColumnVectorShape(shape2)) {
            return Arrays.equals(shape1, shape2);
        }
        if (isRowVectorShape(shape1) && isRowVectorShape(shape2)) {
            return Arrays.equals(squeeze(shape1), squeeze(shape2));
        }
        return scalarEquals(shape1, shape2) || Arrays.equals(shape1, shape2);
    }

    /**
     * Tells whether both shapes are scalar shapes. A scalar shape is either empty or
     * {@code {1}}.
     *
     * <p>Previously this returned {@code false} for two identical scalar shapes
     * ({@code {}} vs {@code {}}, {@code {1}} vs {@code {1}}) while returning {@code true}
     * for {@code {}} vs {@code {1}}. It is now reflexive and symmetric.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if both shapes are scalar shapes
     */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        return isScalarShape(shape1) && isScalarShape(shape2);
    }

    /** Tells whether {@code shape} is empty or exactly {@code {1}}. */
    private static boolean isScalarShape(int[] shape) {
        return shape.length == 0 || (shape.length == 1 && shape[0] == 1);
    }

    /**
     * Tells whether {@code shape} is a row-vector shape: either 1-D, or 2-D with a leading
     * dimension of 1.
     *
     * @param shape the shape to test
     * @return {@code true} for {@code {n}} and {@code {1, n}}
     */
    public static boolean isRowVectorShape(int[] shape) {
        return shape.length == 1 || (shape.length == 2 && shape[0] == 1);
    }

    /**
     * Tells whether {@code shape} is a 2-D shape whose second dimension is 1.
     *
     * @param shape the shape to test
     * @return {@code true} for {@code {n, 1}}
     */
    public static boolean isColumnVectorShape(int[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    /**
     * Compares two shapes after removing every size-1 dimension from both.
     *
     * @param test1 first shape
     * @param test2 second shape
     * @return {@code true} if the squeezed shapes are identical
     */
    public static boolean squeezeEquals(int[] test1, int[] test2) {
        // Squeezed shapes never contain a 1, so a separate scalar check could only succeed
        // when both are empty, which Arrays.equals already covers.
        return Arrays.equals(squeeze(test1), squeeze(test2));
    }
}

