/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;

public class TensorShape {
    /** Defensive copy of the shape this instance wraps; never exposed directly. */
    private final long[] shape;

    /**
     * Creates a shape wrapper.
     *
     * @param shape dimension lengths; the array is copied, so later mutation by the caller has no effect
     */
    public TensorShape(long[] shape) {
        this.shape = Arrays.copyOf(shape, shape.length);
    }

    /**
     * @return a copy of the wrapped dimension lengths
     */
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    /**
     * @return true if this shape has rank 0
     */
    public boolean isScalar() {
        return TensorShape.isScalar(this.shape);
    }

    /**
     * @return true if the product of all dimension lengths is 1
     */
    public boolean isLengthOne() {
        return TensorShape.isLengthOne(this.shape);
    }

    /**
     * @return the number of dimensions
     */
    public int getRank() {
        return this.shape.length;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TensorShape that = (TensorShape) o;
        return Arrays.equals(this.shape, that.shape);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(this.shape);
    }

    /**
     * Computes the number of elements described by a shape.
     *
     * @param shape dimension lengths
     * @return the product of all dimension lengths (1 for a rank-0 shape)
     */
    public static long getLength(long[] shape) {
        long length = 1L;
        for (long dim : shape) {
            length *= dim;
        }
        return length;
    }

    /**
     * Same as {@link #getLength(long[])} but narrowed to an int.
     *
     * @throws IllegalArgumentException if the length does not fit in an int
     */
    public static int getLengthAsInt(long[] shape) {
        return Ints.checkedCast(TensorShape.getLength(shape));
    }

    /**
     * Computes row-major strides for a shape: the last dimension has stride 1 and each earlier
     * dimension's stride is the product of all later dimension lengths.
     *
     * @param shape dimension lengths
     * @return strides, one per dimension (empty for a rank-0 shape)
     */
    public static long[] getRowFirstStride(long[] shape) {
        long[] stride = new long[shape.length];
        if (shape.length == 0) {
            return stride;
        }
        stride[stride.length - 1] = 1L;
        // Accumulate in a long: an int accumulator silently truncated strides of
        // shapes with more than Integer.MAX_VALUE trailing elements.
        long runningProduct = 1L;
        for (int i = stride.length - 2; i >= 0; --i) {
            runningProduct *= shape[i + 1];
            stride[i] = runningProduct;
        }
        return stride;
    }

    /**
     * Converts a multi-dimensional index into a flat offset using the given strides.
     *
     * @param shape  dimension lengths, used for bounds checking
     * @param stride strides matching {@code shape}
     * @param index  one index per dimension; extra trailing entries are ignored
     * @return the flat offset
     * @throws IllegalArgumentException if fewer indices than dimensions are supplied, or any index
     *                                  is negative or not less than its dimension length
     */
    public static long getFlatIndex(long[] shape, long[] stride, long... index) {
        long flatIndex = 0L;
        for (int i = 0; i < shape.length; ++i) {
            // Negative components would otherwise yield a silently wrong offset.
            if (i >= index.length || index[i] < 0L || index[i] >= shape[i]) {
                throw new IllegalArgumentException("Invalid index " + Arrays.toString(index) + " for shape " + Arrays.toString(shape));
            }
            flatIndex += stride[i] * index[i];
        }
        return flatIndex;
    }

    /**
     * Converts a flat offset back into one index per dimension.
     *
     * @param shape     dimension lengths, used for bounds checking
     * @param stride    strides matching {@code shape}
     * @param flatIndex non-negative flat offset
     * @return the per-dimension indices
     * @throws IllegalArgumentException if {@code flatIndex} is negative or maps outside the shape
     */
    public static long[] getShapeIndices(long[] shape, long[] stride, long flatIndex) {
        Preconditions.checkArgument(flatIndex >= 0L, "Flat index must be >= 0 and less than the length of the shape");
        long[] shapeIndices = new long[stride.length];
        long remainder = flatIndex;
        for (int i = 0; i < stride.length; ++i) {
            shapeIndices[i] = remainder / stride[i];
            if (shapeIndices[i] >= shape[i]) {
                throw new IllegalArgumentException("The requested index is out of the bounds of this shape.");
            }
            remainder -= shapeIndices[i] * stride[i];
        }
        return shapeIndices;
    }

    /**
     * @return true if {@code shape} has rank 0
     */
    public static boolean isScalar(long[] shape) {
        return shape.length == 0;
    }

    /**
     * @return true if the product of the dimension lengths is exactly 1
     */
    public static boolean isLengthOne(long[] shape) {
        return TensorShape.getLength(shape) == 1L;
    }

    /**
     * @return a new array holding the dimensions of {@code shape1} followed by those of {@code shape2}
     */
    public static long[] concat(long[] shape1, long[] shape2) {
        long[] result = new long[shape1.length + shape2.length];
        System.arraycopy(shape1, 0, result, 0, shape1.length);
        System.arraycopy(shape2, 0, result, shape1.length, shape2.length);
        return result;
    }

    /**
     * Builds the half-open range of dimension numbers {@code [fromDimension, toDimension)}.
     *
     * @throws IllegalArgumentException if {@code fromDimension > toDimension}
     */
    public static int[] dimensionRange(int fromDimension, int toDimension) {
        if (fromDimension > toDimension) {
            throw new IllegalArgumentException("from dimension must be less than to dimension");
        }
        int dimensionCount = toDimension - fromDimension;
        int[] dims = new int[dimensionCount];
        for (int i = 0; i < dimensionCount; ++i) {
            dims[i] = i + fromDimension;
        }
        return dims;
    }

    /**
     * Copies the dimensions in the half-open range {@code [from, to)} out of {@code shape}.
     *
     * @throws IllegalArgumentException if {@code from > to}
     */
    public static long[] selectDimensions(int from, int to, long[] shape) {
        if (from > to) {
            throw new IllegalArgumentException("from dimension must be less than or equal to to dimension");
        }
        long[] newShape = new long[to - from];
        for (int i = 0; i < to - from; ++i) {
            newShape[i] = shape[i + from];
        }
        return newShape;
    }

    /**
     * Builds a permutation of {@code [0, rank)} in which dimension {@code from} has been removed and
     * re-inserted at position {@code to}.
     */
    public static int[] slideDimension(int from, int to, int rank) {
        ArrayList<Integer> dims = new ArrayList<>(Ints.asList(TensorShape.dimensionRange(0, rank)));
        // remove(int) removes by position and returns the removed element
        Integer moved = dims.remove(from);
        dims.add(to, moved);
        return Ints.toArray(dims);
    }

    /**
     * Pads {@code lowRankTensorShape} with trailing 1s up to {@code desiredRank}.
     *
     * @throws IllegalArgumentException if the shape already exceeds {@code desiredRank}
     */
    public static long[] shapeDesiredToRankByAppendingOnes(long[] lowRankTensorShape, int desiredRank) {
        return TensorShape.increaseRankByPaddingValue(lowRankTensorShape, desiredRank, true);
    }

    /**
     * Pads {@code lowRankTensorShape} with leading 1s up to {@code desiredRank}.
     *
     * @throws IllegalArgumentException if the shape already exceeds {@code desiredRank}
     */
    public static long[] shapeToDesiredRankByPrependingOnes(long[] lowRankTensorShape, int desiredRank) {
        return TensorShape.increaseRankByPaddingValue(lowRankTensorShape, desiredRank, false);
    }

    /**
     * @return whichever of the two shapes has the higher rank ({@code shape1} on ties); the
     *         returned array is one of the arguments, not a copy
     */
    public static long[] calculateShapeForLengthOneBroadcast(long[] shape1, long[] shape2) {
        return shape1.length >= shape2.length ? shape1 : shape2;
    }

    /**
     * Computes the result shape of broadcasting two shapes, aligning dimensions from the right.
     * Aligned dimensions must be equal or one of them must be 1.
     *
     * @throws IllegalArgumentException if the shapes are not broadcast-compatible
     */
    public static long[] getBroadcastResultShape(long[] left, long[] right) {
        long[] shapeOfHighestRank = left.length > right.length ? left : right;
        long[] resultShape = Arrays.copyOf(shapeOfHighestRank, shapeOfHighestRank.length);
        int lowestRank = Math.min(left.length, right.length);
        for (int i = 1; i <= lowestRank; ++i) {
            long lDim = left[left.length - i];
            long rDim = right[right.length - i];
            if (lDim != rDim && lDim != 1L && rDim != 1L) {
                throw new IllegalArgumentException("Shape " + Arrays.toString(left) + " is not broadcastable with shape " + Arrays.toString(right));
            }
            resultShape[resultShape.length - i] = Math.max(lDim, rDim);
        }
        return resultShape;
    }

    /**
     * Pads a shape with 1s to reach {@code desiredRank}. Returns the argument itself (not a copy)
     * when it already has that rank.
     *
     * @param append true to pad at the end, false to pad at the front
     */
    private static long[] increaseRankByPaddingValue(long[] lowRankTensorShape, int desiredRank, boolean append) {
        if (lowRankTensorShape.length == desiredRank) {
            return lowRankTensorShape;
        }
        if (lowRankTensorShape.length > desiredRank) {
            throw new IllegalArgumentException("low rank tensor must be rank less than or equal to desired rank");
        }
        long[] paddedShape = new long[desiredRank];
        Arrays.fill(paddedShape, 1L);
        int destinationOffset = append ? 0 : paddedShape.length - lowRankTensorShape.length;
        System.arraycopy(lowRankTensorShape, 0, paddedShape, destinationOffset, lowRankTensorShape.length);
        return paddedShape;
    }

    /**
     * Converts every entry of {@code dimensions} to its non-negative form via
     * {@link #getAbsoluteDimension(int, int)}. The array is modified in place and also returned.
     */
    public static int[] setToAbsoluteDimensions(int rank, int[] dimensions) {
        for (int i = 0; i < dimensions.length; ++i) {
            dimensions[i] = TensorShape.getAbsoluteDimension(dimensions[i], rank);
        }
        return dimensions;
    }

    /**
     * @return a new shape with {@code dimension} removed, after validating it via
     *         {@code TensorShapeValidation.checkDimensionExistsInShape}
     */
    public static long[] removeDimension(int dimension, long[] shape) {
        TensorShapeValidation.checkDimensionExistsInShape(dimension, shape);
        return ArrayUtils.remove(shape, dimension);
    }

    /**
     * Maps a possibly-negative dimension (counted from the end) to its non-negative equivalent.
     *
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [-rank, rank)}
     */
    public static int getAbsoluteDimension(int dimension, int rank) {
        if (dimension >= rank || dimension < -rank) {
            throw new IllegalArgumentException("Dimension " + dimension + " is invalid for rank " + rank + " tensor.");
        }
        return dimension < 0 ? dimension + rank : dimension;
    }

    /**
     * Computes the shape left after reducing over {@code sumOverDimensions}. For a rank-0 input,
     * no dimensions may be given and the input is returned unchanged.
     *
     * @throws IllegalArgumentException if dimensions are given for a rank-0 input
     */
    public static long[] getReductionResultShape(long[] inputShape, int[] sumOverDimensions) {
        if (inputShape.length > 0) {
            return ArrayUtils.removeAll(inputShape, sumOverDimensions);
        }
        Preconditions.checkArgument(sumOverDimensions.length == 0);
        return inputShape;
    }

    /**
     * @return an array whose i-th entry is {@code indices[rearrange[i]]}
     */
    public static long[] getPermutedIndices(long[] indices, int... rearrange) {
        long[] permutedIndices = new long[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            permutedIndices[i] = indices[rearrange[i]];
        }
        return permutedIndices;
    }

    /**
     * @return the inverse permutation {@code inv} such that {@code inv[rearrange[i]] == i}
     */
    public static int[] invertedPermute(int[] rearrange) {
        int[] inverted = new int[rearrange.length];
        for (int i = 0; i < rearrange.length; ++i) {
            inverted[rearrange[i]] = i;
        }
        return inverted;
    }

    /**
     * Maps a flat offset in the original layout to the matching flat offset in the permuted layout.
     */
    public static long convertFromFlatIndexToPermutedFlatIndex(long fromFlatIndex, long[] shape, long[] stride, long[] permutedShape, long[] permutedStride, int[] rearrange) {
        long[] shapeIndices = TensorShape.getShapeIndices(shape, stride, fromFlatIndex);
        long[] permutedIndex = TensorShape.getPermutedIndices(shapeIndices, rearrange);
        return TensorShape.getFlatIndex(permutedShape, permutedStride, permutedIndex);
    }

    /**
     * Resolves a reshape target that may contain a single negative "wildcard" dimension, whose
     * length is inferred from {@code oldShapeLength}. Zero-length dimensions are not counted
     * towards the known length.
     *
     * @param oldShape       the shape being reshaped (used only in error messages)
     * @param oldShapeLength the number of elements in {@code oldShape}
     * @param newShape       the requested shape; it is not modified
     * @return a copy of {@code newShape} with any wildcard replaced by its inferred length
     * @throws IllegalArgumentException if more than one wildcard is given, if the lengths do not
     *                                  match, or if the wildcard length would not be a whole number
     */
    public static long[] getReshapeAllowingWildcard(long[] oldShape, long oldShapeLength, long[] newShape) {
        long[] newShapeCopy = Arrays.copyOf(newShape, newShape.length);
        long knownLength = 1L;
        int wildcardDimension = -1;
        for (int i = 0; i < newShapeCopy.length; ++i) {
            long dimLength = newShapeCopy[i];
            if (dimLength > 0L) {
                knownLength *= dimLength;
            } else if (dimLength < 0L) {
                if (wildcardDimension >= 0) {
                    throw TensorShape.cannotReshape(oldShape, newShapeCopy);
                }
                wildcardDimension = i;
            }
        }
        if (wildcardDimension < 0) {
            if (knownLength != oldShapeLength) {
                throw TensorShape.cannotReshape(oldShape, newShapeCopy);
            }
            return newShapeCopy;
        }
        // Without this check an indivisible length silently produced a shape of the wrong size.
        if (oldShapeLength % knownLength != 0L) {
            throw TensorShape.cannotReshape(oldShape, newShapeCopy);
        }
        newShapeCopy[wildcardDimension] = oldShapeLength / knownLength;
        return newShapeCopy;
    }

    /** Builds the exception used for every rejected reshape request. */
    private static IllegalArgumentException cannotReshape(long[] oldShape, long[] newShape) {
        return new IllegalArgumentException("Cannot reshape " + Arrays.toString(oldShape) + " to " + Arrays.toString(newShape));
    }

    /**
     * Computes the shape produced by concatenating tensors along {@code dimension}. A rank-0 tensor
     * is treated as shape {@code [1]}; scalars may only be concatenated on dimension 0. Negative
     * dimensions count from the end.
     *
     * @throws IllegalArgumentException if no tensors are given, the dimension is invalid, ranks
     *                                  differ, or any non-concatenated dimension differs
     */
    public static long[] getConcatResultShape(int dimension, Tensor ... toConcat) {
        Preconditions.checkArgument(toConcat.length > 0);
        long[] firstShape = toConcat[0].getShape();
        if (firstShape.length == 0 && dimension != 0) {
            throw new IllegalArgumentException("Cannot concat scalars on dimension " + dimension);
        }
        long[] concatShape = firstShape.length == 0 ? new long[]{1L} : Arrays.copyOf(firstShape, firstShape.length);
        // Previously an out-of-range dimension was silently ignored, returning the first shape.
        int concatDimension = TensorShape.getAbsoluteDimension(dimension, concatShape.length);
        for (int i = 1; i < toConcat.length; ++i) {
            long[] cShape = toConcat[i].getShape();
            long[] effectiveShape = cShape.length == 0 ? new long[]{1L} : cShape;
            if (effectiveShape.length != concatShape.length) {
                throw new IllegalArgumentException("Cannot concat shape " + Arrays.toString(cShape));
            }
            for (int dim = 0; dim < concatShape.length; ++dim) {
                if (dim == concatDimension) {
                    concatShape[dim] += effectiveShape[dim];
                } else if (effectiveShape[dim] != concatShape[dim]) {
                    throw new IllegalArgumentException("Cannot concat shape " + Arrays.toString(cShape));
                }
            }
        }
        return concatShape;
    }

    /**
     * Builds a permutation that moves {@code dimension} to position 0 and keeps the remaining
     * dimensions in their original relative order. Expects {@code shape} to have rank >= 1.
     */
    public static int[] getPermutationForDimensionToDimensionZero(int dimension, long[] shape) {
        int[] rearrange = new int[shape.length];
        rearrange[0] = dimension;
        for (int i = 1; i < rearrange.length; ++i) {
            rearrange[i] = i > dimension ? i : i - 1;
        }
        return rearrange;
    }

    /**
     * Maps a flat offset in a (possibly higher-rank) source layout to the matching flat offset in a
     * layout that is broadcast into it. Leading source dimensions with no counterpart are dropped,
     * and each remaining index is wrapped modulo the target dimension length. Presumably expects
     * {@code fromStride.length >= toStride.length} — TODO confirm against callers.
     */
    public static long getBroadcastedFlatIndex(long fromFlatIndex, long[] fromStride, long[] toShape, long[] toStride) {
        int rankDiff = fromStride.length - toStride.length;
        long remainder = fromFlatIndex;
        long toFlatIndex = 0L;
        for (int i = 0; i < fromStride.length; ++i) {
            long fromShapeIndex = remainder / fromStride[i];
            remainder -= fromShapeIndex * fromStride[i];
            if (i < rankDiff) {
                continue;
            }
            long toShapeIndex = fromShapeIndex % toShape[i - rankDiff];
            toFlatIndex += toStride[i - rankDiff] * toShapeIndex;
        }
        return toFlatIndex;
    }

    /**
     * Advances {@code index} in place by one step, visiting dimensions in {@code dimensionOrder}
     * and carrying into the next dimension whenever one wraps back to 0.
     *
     * @return true if the increment stopped without wrapping every dimension, false if all
     *         dimensions in {@code dimensionOrder} wrapped to 0
     */
    public static boolean incrementIndexByShape(long[] shape, long[] index, int[] dimensionOrder) {
        for (int i : dimensionOrder) {
            index[i] = (index[i] + 1L) % shape[i];
            if (index[i] != 0L) {
                return true;
            }
        }
        return false;
    }
}

