package io.improbable.keanu.tensor;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:io/improbable/keanu/tensor/TensorShape.class */
/**
 * Holds a defensive copy of a tensor shape and offers static helpers for
 * shape arithmetic: lengths, strides, flat/multi-dimensional index
 * conversion, broadcasting, reshaping, concatenation and permutation.
 *
 * <p>Shapes are {@code long[]} arrays with one entry per dimension; a
 * zero-length array is treated as a scalar (see {@link #isScalar(long[])}).
 */
public class TensorShape {
    private final long[] shape;

    /**
     * Creates a shape from a copy of the given dimension lengths.
     *
     * @param jArr the dimension lengths; copied, so later changes to the
     *             argument do not affect this instance
     */
    public TensorShape(long[] jArr) {
        this.shape = Arrays.copyOf(jArr, jArr.length);
    }

    /**
     * @return a fresh copy of the dimension lengths
     */
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    /**
     * @return {@code true} if this shape has rank zero
     */
    public boolean isScalar() {
        return isScalar(this.shape);
    }

    /**
     * @return {@code true} if the product of all dimension lengths is one
     */
    public boolean isLengthOne() {
        return isLengthOne(this.shape);
    }

    /**
     * @return the number of dimensions
     */
    public int getRank() {
        return this.shape.length;
    }

    /**
     * Two shapes are equal when they are of the same class and have
     * element-wise equal dimension lengths.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Arrays.equals(this.shape, ((TensorShape) obj).shape);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(this.shape);
    }

    /**
     * Multiplies all dimension lengths together.
     *
     * @param jArr the shape
     * @return the product of the entries; {@code 1} for an empty (scalar) shape.
     *         The multiplication is not overflow-checked.
     */
    public static long getLength(long[] jArr) {
        long length = 1;
        for (long dimensionLength : jArr) {
            length *= dimensionLength;
        }
        return length;
    }

    /**
     * Same as {@link #getLength(long[])} but narrowed to an {@code int} with
     * Guava's {@code Ints.checkedCast}.
     *
     * @param jArr the shape
     * @return the length as an int
     * @throws IllegalArgumentException if the length does not fit in an int
     */
    public static int getLengthAsInt(long[] jArr) {
        return Ints.checkedCast(getLength(jArr));
    }

    /**
     * Computes strides where the last dimension has stride one and every
     * earlier dimension's stride is the product of all later dimension lengths.
     *
     * @param jArr the shape
     * @return a new stride array of the same rank; empty for an empty shape
     */
    public static long[] getRowFirstStride(long[] jArr) {
        long[] stride = new long[jArr.length];
        if (jArr.length == 0) {
            return stride;
        }
        stride[stride.length - 1] = 1;
        // Accumulate in a long: an int accumulator silently truncates strides
        // larger than Integer.MAX_VALUE for big shapes.
        long runningProduct = 1;
        for (int dim = stride.length - 2; dim >= 0; dim--) {
            runningProduct *= jArr[dim + 1];
            stride[dim] = runningProduct;
        }
        return stride;
    }

    /**
     * Converts a multi-dimensional index into a flat offset by summing
     * {@code stride[d] * index[d]} over every dimension of the shape.
     *
     * @param jArr  the shape
     * @param jArr2 the stride, one entry per shape dimension
     * @param jArr3 the index; must supply a value for every shape dimension
     *              (entries beyond the shape's rank are ignored)
     * @return the flat offset
     * @throws IllegalArgumentException if the index has too few entries, or any
     *                                  used entry is negative or not less than
     *                                  the matching dimension length
     */
    public static long getFlatIndex(long[] jArr, long[] jArr2, long... jArr3) {
        long flatIndex = 0;
        for (int dim = 0; dim < jArr.length; dim++) {
            // Negative entries are rejected too; they would otherwise yield a
            // meaningless (possibly negative) flat offset.
            if (dim >= jArr3.length || jArr3[dim] < 0 || jArr3[dim] >= jArr[dim]) {
                throw new IllegalArgumentException("Invalid index " + Arrays.toString(jArr3) + " for shape " + Arrays.toString(jArr));
            }
            flatIndex += jArr2[dim] * jArr3[dim];
        }
        return flatIndex;
    }

    /**
     * Converts a flat offset into a multi-dimensional index by repeatedly
     * dividing by each stride in order and carrying the remainder forward.
     *
     * @param jArr  the shape, used to bounds-check each computed index
     * @param jArr2 the stride; its length determines the rank of the result
     * @param j     the flat offset, must be non-negative
     * @return a new index array with one entry per stride
     * @throws IllegalArgumentException if {@code j} is negative or a computed
     *                                  index is not less than its dimension length
     */
    public static long[] getShapeIndices(long[] jArr, long[] jArr2, long j) {
        Preconditions.checkArgument(j >= 0, "Flat index must be >= 0 and less than the length of the shape");
        long[] indices = new long[jArr2.length];
        long remainder = j;
        for (int dim = 0; dim < jArr2.length; dim++) {
            indices[dim] = remainder / jArr2[dim];
            if (indices[dim] >= jArr[dim]) {
                throw new IllegalArgumentException("The requested index is out of the bounds of this shape.");
            }
            remainder -= indices[dim] * jArr2[dim];
        }
        return indices;
    }

    /**
     * @param jArr the shape
     * @return {@code true} if the shape has no dimensions
     */
    public static boolean isScalar(long[] jArr) {
        return jArr.length == 0;
    }

    /**
     * @param jArr the shape
     * @return {@code true} if {@link #getLength(long[])} of the shape is one
     */
    public static boolean isLengthOne(long[] jArr) {
        return getLength(jArr) == 1;
    }

    /**
     * Joins two arrays end to end.
     *
     * @param jArr  the leading entries
     * @param jArr2 the trailing entries
     * @return a new array containing {@code jArr} followed by {@code jArr2}
     */
    public static long[] concat(long[] jArr, long[] jArr2) {
        long[] joined = new long[jArr.length + jArr2.length];
        System.arraycopy(jArr, 0, joined, 0, jArr.length);
        System.arraycopy(jArr2, 0, joined, jArr.length, jArr2.length);
        return joined;
    }

    /**
     * Builds the consecutive integers {@code i, i + 1, ..., i2 - 1}.
     *
     * @param i  the first value (inclusive)
     * @param i2 the end value (exclusive)
     * @return a new array of length {@code i2 - i}; empty when both are equal
     * @throws IllegalArgumentException if {@code i > i2}
     */
    public static int[] dimensionRange(int i, int i2) {
        if (i > i2) {
            throw new IllegalArgumentException("from dimension must be less than to dimension");
        }
        int count = i2 - i;
        int[] range = new int[count];
        for (int offset = 0; offset < count; offset++) {
            range[offset] = offset + i;
        }
        return range;
    }

    /**
     * Copies the entries {@code [i, i2)} out of a shape.
     *
     * @param i    the first dimension to keep (inclusive)
     * @param i2   the end dimension (exclusive)
     * @param jArr the shape to select from
     * @return a new array of length {@code i2 - i}
     * @throws IllegalArgumentException  if {@code i > i2}
     * @throws IndexOutOfBoundsException if the range is not within {@code jArr}
     */
    public static long[] selectDimensions(int i, int i2, long[] jArr) {
        if (i > i2) {
            throw new IllegalArgumentException("from dimension must not be greater than to dimension");
        }
        long[] selected = new long[i2 - i];
        System.arraycopy(jArr, i, selected, 0, selected.length);
        return selected;
    }

    /**
     * Builds the dimension ordering {@code 0..i3-1} with the entry at position
     * {@code i} removed and re-inserted at position {@code i2}.
     *
     * @param i  the position to take the dimension from
     * @param i2 the position to insert it at (after removal)
     * @param i3 the rank, i.e. how many dimensions the ordering covers
     * @return the resulting ordering as a new array
     * @throws IndexOutOfBoundsException if either position is out of range
     */
    public static int[] slideDimension(int i, int i2, int i3) {
        ArrayList<Integer> dimensions = new ArrayList<>(Ints.asList(dimensionRange(0, i3)));
        // remove(int) is the index-based overload, so this removes position i.
        Integer moved = dimensions.remove(i);
        dimensions.add(i2, moved);
        return Ints.toArray(dimensions);
    }

    /**
     * Pads a shape with trailing ones until it reaches the desired rank.
     *
     * @param jArr the shape to pad
     * @param i    the desired rank
     * @return the padded shape; the input array itself when it already has rank {@code i}
     * @throws IllegalArgumentException if the shape's rank exceeds {@code i}
     */
    public static long[] shapeDesiredToRankByAppendingOnes(long[] jArr, int i) {
        return increaseRankByPaddingValue(jArr, i, true);
    }

    /**
     * Pads a shape with leading ones until it reaches the desired rank.
     *
     * @param jArr the shape to pad
     * @param i    the desired rank
     * @return the padded shape; the input array itself when it already has rank {@code i}
     * @throws IllegalArgumentException if the shape's rank exceeds {@code i}
     */
    public static long[] shapeToDesiredRankByPrependingOnes(long[] jArr, int i) {
        return increaseRankByPaddingValue(jArr, i, false);
    }

    /**
     * Picks whichever of the two shapes has the higher rank, preferring the
     * first on a tie. The chosen array is returned as-is, not copied.
     *
     * @param jArr  the first shape
     * @param jArr2 the second shape
     * @return {@code jArr} if its rank is at least that of {@code jArr2}, otherwise {@code jArr2}
     */
    public static long[] calculateShapeForLengthOneBroadcast(long[] jArr, long[] jArr2) {
        return jArr.length >= jArr2.length ? jArr : jArr2;
    }

    /**
     * Computes the shape produced by broadcasting two shapes against each
     * other, aligning them at their trailing dimensions. Aligned dimensions
     * must be equal or one of them must be one; the result takes the larger.
     * Leading dimensions of the higher-rank shape are carried over unchanged.
     *
     * @param jArr  the first shape
     * @param jArr2 the second shape
     * @return a new array holding the broadcast result shape
     * @throws IllegalArgumentException if an aligned pair of dimensions differs
     *                                  and neither is one
     */
    public static long[] getBroadcastResultShape(long[] jArr, long[] jArr2) {
        long[] higherRank = jArr.length > jArr2.length ? jArr : jArr2;
        long[] result = Arrays.copyOf(higherRank, higherRank.length);
        int sharedRank = Math.min(jArr.length, jArr2.length);
        // Walk backwards from the last dimension of each shape.
        for (int fromEnd = 1; fromEnd <= sharedRank; fromEnd++) {
            long left = jArr[jArr.length - fromEnd];
            long right = jArr2[jArr2.length - fromEnd];
            if (left != right && left != 1 && right != 1) {
                throw new IllegalArgumentException("Shape " + Arrays.toString(jArr) + " is not broadcastable with shape " + Arrays.toString(jArr2));
            }
            result[result.length - fromEnd] = Math.max(left, right);
        }
        return result;
    }

    /**
     * Pads a shape with ones up to the requested rank.
     *
     * @param jArr the shape to pad
     * @param i    the desired rank
     * @param z    {@code true} to keep the original entries at the front (ones
     *             appended), {@code false} to keep them at the back (ones prepended)
     * @return the input array itself when no padding is needed, otherwise a new array
     * @throws IllegalArgumentException if the shape's rank exceeds {@code i}
     */
    private static long[] increaseRankByPaddingValue(long[] jArr, int i, boolean z) {
        if (jArr.length == i) {
            return jArr;
        }
        if (jArr.length > i) {
            throw new IllegalArgumentException("low rank tensor must be rank less than or equal to desired rank");
        }
        long[] padded = new long[i];
        Arrays.fill(padded, 1L);
        int destinationOffset = z ? 0 : padded.length - jArr.length;
        System.arraycopy(jArr, 0, padded, destinationOffset, jArr.length);
        return padded;
    }

    /**
     * Replaces every entry of {@code iArr}, in place, with its absolute
     * (non-negative) dimension as computed by {@link #getAbsoluteDimension(int, int)}.
     *
     * @param i    the rank the dimensions refer to
     * @param iArr the dimensions to normalise; this array is modified
     * @return the same {@code iArr} instance
     * @throws IllegalArgumentException if any entry is invalid for the rank
     */
    public static int[] setToAbsoluteDimensions(int i, int[] iArr) {
        for (int position = 0; position < iArr.length; position++) {
            iArr[position] = getAbsoluteDimension(iArr[position], i);
        }
        return iArr;
    }

    /**
     * Returns a copy of the shape without the entry at the given dimension,
     * after validating the dimension with
     * {@code TensorShapeValidation.checkDimensionExistsInShape}.
     *
     * @param i    the dimension to drop
     * @param jArr the shape
     * @return a new array one entry shorter than {@code jArr}
     */
    public static long[] removeDimension(int i, long[] jArr) {
        TensorShapeValidation.checkDimensionExistsInShape(i, jArr);
        return ArrayUtils.remove(jArr, i);
    }

    /**
     * Normalises a possibly negative dimension: values in {@code [-rank, -1]}
     * are shifted up by the rank, values in {@code [0, rank)} are returned as-is.
     *
     * @param i  the dimension, possibly negative
     * @param i2 the rank
     * @return the equivalent dimension in {@code [0, rank)}
     * @throws IllegalArgumentException if {@code i} lies outside {@code [-rank, rank)}
     */
    public static int getAbsoluteDimension(int i, int i2) {
        if (i >= i2 || i < (-i2)) {
            throw new IllegalArgumentException("Dimension " + i + " is invalid for rank " + i2 + " tensor.");
        }
        if (i < 0) {
            i += i2;
        }
        return i;
    }

    /**
     * Computes the shape left after removing the given dimensions.
     *
     * @param jArr the shape being reduced
     * @param iArr the indices of the dimensions to remove
     * @return for a non-empty shape, a new array without the listed entries;
     *         for an empty shape, the input array itself
     * @throws IllegalArgumentException if the shape is empty but dimensions were requested
     */
    public static long[] getReductionResultShape(long[] jArr, int[] iArr) {
        if (jArr.length > 0) {
            return ArrayUtils.removeAll(jArr, iArr);
        }
        // Nothing can be removed from a rank-zero shape.
        Preconditions.checkArgument(iArr.length == 0);
        return jArr;
    }

    /**
     * Reorders values so that position {@code k} of the result holds
     * {@code jArr[iArr[k]]}.
     *
     * @param jArr the values to reorder
     * @param iArr the source position for each output position; must have at
     *             least {@code jArr.length} entries
     * @return a new array of the same length as {@code jArr}
     */
    public static long[] getPermutedIndices(long[] jArr, int... iArr) {
        long[] permuted = new long[jArr.length];
        for (int position = 0; position < jArr.length; position++) {
            permuted[position] = jArr[iArr[position]];
        }
        return permuted;
    }

    /**
     * Builds the inverse mapping of a permutation: if {@code iArr[k] == v}
     * then the result has {@code k} at position {@code v}.
     *
     * @param iArr the permutation
     * @return a new array holding the inverse
     */
    public static int[] invertedPermute(int[] iArr) {
        int[] inverse = new int[iArr.length];
        for (int position = 0; position < iArr.length; position++) {
            inverse[iArr[position]] = position;
        }
        return inverse;
    }

    /**
     * Maps a flat offset in one layout to the flat offset in another by
     * expanding it to indices ({@link #getShapeIndices}), reordering them
     * ({@link #getPermutedIndices}) and flattening again ({@link #getFlatIndex}).
     *
     * @param j     the flat offset in the source layout
     * @param jArr  the source shape
     * @param jArr2 the source stride
     * @param jArr3 the destination shape
     * @param jArr4 the destination stride
     * @param iArr  the reordering applied to the source indices
     * @return the flat offset in the destination layout
     */
    public static long convertFromFlatIndexToPermutedFlatIndex(long j, long[] jArr, long[] jArr2, long[] jArr3, long[] jArr4, int[] iArr) {
        long[] sourceIndices = getShapeIndices(jArr, jArr2, j);
        long[] permutedIndices = getPermutedIndices(sourceIndices, iArr);
        return getFlatIndex(jArr3, jArr4, permutedIndices);
    }

    /**
     * Resolves a requested shape that may contain one negative "wildcard"
     * entry. The wildcard is replaced by the length needed for the requested
     * shape to cover {@code j} elements. Zero entries are left untouched and
     * are not included in the length product.
     *
     * @param jArr  the current shape; only used in error messages
     * @param j     the current total length
     * @param jArr2 the requested shape; not modified
     * @return a new array holding the resolved shape
     * @throws IllegalArgumentException if there is more than one wildcard, if
     *                                  there is no wildcard and the lengths
     *                                  differ, or if the current length is not
     *                                  evenly divisible by the known dimensions
     */
    public static long[] getReshapeAllowingWildcard(long[] jArr, long j, long[] jArr2) {
        long knownLength = 1;
        int wildcardDimension = -1;
        long[] resolved = Arrays.copyOf(jArr2, jArr2.length);
        for (int dim = 0; dim < resolved.length; dim++) {
            long dimensionLength = resolved[dim];
            if (dimensionLength > 0) {
                knownLength *= dimensionLength;
            } else if (dimensionLength < 0) {
                if (wildcardDimension >= 0) {
                    throw new IllegalArgumentException("Cannot reshape " + Arrays.toString(jArr) + " to " + Arrays.toString(resolved));
                }
                wildcardDimension = dim;
            }
        }
        if (wildcardDimension < 0) {
            if (knownLength != j) {
                throw new IllegalArgumentException("Cannot reshape " + Arrays.toString(jArr) + " to " + Arrays.toString(resolved));
            }
            return resolved;
        }
        // Without this check integer division would silently drop elements,
        // e.g. length 7 reshaped to [-1, 2] would become [3, 2].
        if (j % knownLength != 0) {
            throw new IllegalArgumentException("Cannot reshape " + Arrays.toString(jArr) + " to " + Arrays.toString(resolved));
        }
        resolved[wildcardDimension] = j / knownLength;
        return resolved;
    }

    /**
     * Computes the shape that results from concatenating tensors along one
     * dimension. A rank-zero shape is treated as a single-element vector.
     * Every tensor must match the first one in all dimensions other than the
     * concatenation dimension, whose lengths are summed.
     *
     * @param i         the dimension to concatenate along; must be a valid
     *                  non-negative index into the first tensor's shape
     * @param tensorArr the tensors to concatenate; at least one is required
     * @return a new array holding the concatenated shape
     * @throws IllegalArgumentException if no tensors are given, the dimension
     *                                  is out of range, or a tensor's rank or
     *                                  non-concatenated dimensions differ from
     *                                  the first tensor's
     */
    public static long[] getConcatResultShape(int i, Tensor... tensorArr) {
        Preconditions.checkArgument(tensorArr.length > 0);
        long[] firstShape = tensorArr[0].getShape();
        if (firstShape.length == 0 && i != 0) {
            throw new IllegalArgumentException("Cannot concat scalars on dimension " + i);
        }
        long[] resultShape = firstShape.length == 0 ? new long[]{1} : Arrays.copyOf(firstShape, firstShape.length);
        // An out-of-range dimension would otherwise be ignored and the first
        // shape returned as though nothing had been concatenated.
        if (i < 0 || i >= resultShape.length) {
            throw new IllegalArgumentException("Cannot concat shape " + Arrays.toString(firstShape) + " on dimension " + i);
        }
        for (int tensorIndex = 1; tensorIndex < tensorArr.length; tensorIndex++) {
            long[] nextShape = tensorArr[tensorIndex].getShape();
            long[] effectiveShape = nextShape.length == 0 ? new long[]{1} : nextShape;
            // Ranks must agree; otherwise extra dimensions would be ignored or
            // missing ones would trigger an ArrayIndexOutOfBoundsException.
            if (effectiveShape.length != resultShape.length) {
                throw new IllegalArgumentException("Cannot concat shape " + Arrays.toString(nextShape));
            }
            for (int dim = 0; dim < resultShape.length; dim++) {
                if (dim == i) {
                    resultShape[i] = resultShape[i] + effectiveShape[i];
                } else if (effectiveShape[dim] != resultShape[dim]) {
                    throw new IllegalArgumentException("Cannot concat shape " + Arrays.toString(nextShape));
                }
            }
        }
        return resultShape;
    }

    /**
     * Builds a dimension ordering that puts dimension {@code i} first and keeps
     * the remaining dimensions in their original relative order, e.g.
     * {@code i = 2} with rank 4 gives {@code [2, 0, 1, 3]}.
     *
     * @param i    the dimension to move to the front
     * @param jArr the shape; only its rank is used and it must not be empty
     * @return the ordering as a new array
     */
    public static int[] getPermutationForDimensionToDimensionZero(int i, long[] jArr) {
        int[] permutation = new int[jArr.length];
        permutation[0] = i;
        for (int position = 1; position < permutation.length; position++) {
            // Dimensions before i shift one place right; those after i stay put.
            permutation[position] = position > i ? position : position - 1;
        }
        return permutation;
    }

    /**
     * Translates a flat offset, decomposed using the strides in {@code jArr},
     * into a flat offset for an operand described by {@code jArr2}/{@code jArr3}.
     * The operand is aligned with the trailing dimensions of {@code jArr}, and
     * each index is wrapped modulo the operand's dimension length.
     *
     * @param j     the flat offset to translate
     * @param jArr  the strides used to decompose {@code j} (presumably the
     *              broadcast result's stride - confirm against callers)
     * @param jArr2 the operand's shape
     * @param jArr3 the operand's stride
     * @return the flat offset into the operand
     */
    public static long getBroadcastedFlatIndex(long j, long[] jArr, long[] jArr2, long[] jArr3) {
        int rankOffset = jArr.length - jArr3.length;
        long remainder = j;
        long operandFlatIndex = 0;
        for (int dim = 0; dim < jArr.length; dim++) {
            long index = remainder / jArr[dim];
            remainder -= index * jArr[dim];
            if (dim >= rankOffset) {
                int operandDim = dim - rankOffset;
                operandFlatIndex += jArr3[operandDim] * (index % jArr2[operandDim]);
            }
        }
        return operandFlatIndex;
    }

    /**
     * Advances a multi-dimensional index by one step, in place, visiting
     * dimensions in the order given. Each visited dimension is incremented
     * modulo its length; the walk stops at the first dimension that did not
     * wrap back to zero.
     *
     * @param jArr  the shape bounding each index entry
     * @param jArr2 the index to advance; this array is modified
     * @param iArr  the dimensions to increment, in carry order
     * @return {@code true} if some dimension was incremented without wrapping;
     *         {@code false} if every listed dimension wrapped to zero
     */
    public static boolean incrementIndexByShape(long[] jArr, long[] jArr2, int[] iArr) {
        for (int dimension : iArr) {
            jArr2[dimension] = (jArr2[dimension] + 1) % jArr[dimension];
            if (jArr2[dimension] != 0) {
                return true;
            }
        }
        return false;
    }
}
