package io.improbable.keanu.tensor;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:io/improbable/keanu/tensor/TensorShapeValidation.class */
/**
 * Static validation helpers for tensor shapes.
 *
 * <p>Shapes are passed as {@code long[]} arrays with one entry per dimension. Every check reports an
 * invalid shape by throwing {@link IllegalArgumentException}; several checks also return the shape
 * that an operation over the validated inputs would produce.
 */
public class TensorShapeValidation {

    private TensorShapeValidation() {
    }

    /**
     * Checks that all shapes in {@code shapes} that are not length one (as reported by
     * {@code TensorShape#isLengthOne}) are one single shape equal to {@code proposalShape}.
     * Passes trivially when every shape is length one (or no shapes are given).
     *
     * @param proposalShape the shape every non-length-one operand must have
     * @param shapes        the operand shapes to check
     * @throws IllegalArgumentException if more than one distinct non-length-one shape is present, or
     *                                  the single such shape differs from {@code proposalShape}
     */
    public static void checkTensorsMatchNonLengthOneShapeOrAreLengthOne(long[] proposalShape, long[]... shapes) {
        Set<TensorShape> nonLengthOneShapes = getNonLengthOneShapes(shapes);
        if (nonLengthOneShapes.isEmpty()) {
            return;
        }
        if (nonLengthOneShapes.size() > 1) {
            throw new IllegalArgumentException("More than a single non length one shape");
        }
        long[] nonLengthOneShape = nonLengthOneShapes.iterator().next().getShape();
        if (!Arrays.equals(nonLengthOneShape, proposalShape)) {
            throw new IllegalArgumentException("Proposed shape " + Arrays.toString(proposalShape) + " does not match other non length one shapes " + Arrays.toString(nonLengthOneShape));
        }
    }

    /**
     * Checks that {@code dimension} is below the rank of {@code shape}.
     * Only the upper bound is checked here; negative dimensions are not rejected.
     *
     * @param dimension the dimension index to check
     * @param shape     the tensor shape
     * @throws IllegalArgumentException if {@code dimension >= shape.length}
     */
    public static void checkDimensionExistsInShape(int dimension, long[] shape) {
        if (dimension >= shape.length) {
            throw new IllegalArgumentException(String.format("Dimension %d does not exist in tensor of rank %d", dimension, shape.length));
        }
    }

    /**
     * Checks that every shape is scalar (as reported by {@code TensorShape#isScalar}).
     *
     * @param message the exception message to use when a shape is not scalar
     * @param shapes  the shapes to check
     * @throws IllegalArgumentException with {@code message} if any shape is not scalar
     */
    public static void checkTensorsAreScalar(String message, long[]... shapes) {
        if (!getNonScalarShapes(shapes).isEmpty()) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Checks that the shapes contain at most one distinct non-length-one shape.
     *
     * @param shapes the shapes to check
     * @return the single non-length-one shape if there is one; otherwise, when every shape is length
     *         one, the length-one shape with the highest rank
     * @throws IllegalArgumentException if no shapes are given, or more than one distinct
     *                                  non-length-one shape is present
     */
    public static long[] checkHasOneNonLengthOneShapeOrAllLengthOne(long[]... shapes) {
        Set<TensorShape> nonLengthOneShapes = getNonLengthOneShapes(shapes);
        if (nonLengthOneShapes.size() == 1) {
            return nonLengthOneShapes.iterator().next().getShape();
        }
        if (nonLengthOneShapes.isEmpty()) {
            // Only consult the length-one shapes when they are the sole candidates.
            List<TensorShape> lengthOneShapes = getLengthOneShapesSortedByRank(shapes);
            if (!lengthOneShapes.isEmpty()) {
                return lengthOneShapes.get(0).getShape();
            }
        }
        throw new IllegalArgumentException("Shapes must match or be length one but were: " + Arrays.stream(shapes).map(Arrays::toString).collect(Collectors.joining(",")));
    }

    /**
     * Returns the broadcast result shape of {@code left} and {@code right}, delegating to
     * {@code TensorShape#getBroadcastResultShape}. Any exception it throws propagates to the caller.
     *
     * @param left  the left operand shape
     * @param right the right operand shape
     * @return the broadcast result shape
     */
    public static long[] checkIsBroadcastable(long[] left, long[] right) {
        return TensorShape.getBroadcastResultShape(left, right);
    }

    /**
     * Reports whether {@code TensorShape#getBroadcastResultShape} accepts the two shapes.
     *
     * @param left  the left operand shape
     * @param right the right operand shape
     * @return {@code false} if computing the broadcast shape throws {@link IllegalArgumentException},
     *         {@code true} otherwise
     */
    public static boolean isBroadcastable(long[] left, long[] right) {
        try {
            TensorShape.getBroadcastResultShape(left, right);
            return true;
        } catch (IllegalArgumentException notBroadcastable) {
            return false;
        }
    }

    /**
     * Validates the operand shapes of a ternary (where/if-then-else) operation.
     *
     * @param predicateShape the condition shape
     * @param thenShape      the shape of the "then" operand
     * @param elseShape      the shape of the "else" operand
     * @return the result shape, as computed by {@link #checkHasOneNonLengthOneShapeOrAllLengthOne}
     * @throws IllegalArgumentException if {@code thenShape} differs from {@code elseShape}, or the three
     *                                  shapes contain more than one distinct non-length-one shape
     */
    public static long[] checkTernaryConditionShapeIsValid(long[] predicateShape, long[] thenShape, long[] elseShape) {
        Preconditions.checkArgument(Arrays.equals(thenShape, elseShape),
            "Then shape %s must match else shape %s", Arrays.toString(thenShape), Arrays.toString(elseShape));
        return checkHasOneNonLengthOneShapeOrAllLengthOne(predicateShape, thenShape, elseShape);
    }

    /**
     * Checks that {@code shape} is a rank-2 shape whose two dimensions are equal.
     *
     * @param shape the shape to check
     * @throws IllegalArgumentException if the shape is not rank 2, or its dimensions differ
     */
    public static void checkShapeIsSquareMatrix(long[] shape) {
        if (shape.length != 2) {
            throw new IllegalArgumentException("Input tensor must be a matrix");
        }
        if (shape[0] != shape[1]) {
            throw new IllegalArgumentException("Input matrix must be square");
        }
    }

    // Distinct shapes that are not length one. De-duplication relies on TensorShape#equals/hashCode,
    // presumed to compare shape contents — TODO confirm.
    private static Set<TensorShape> getNonLengthOneShapes(long[]... shapes) {
        return Arrays.stream(shapes)
            .map(TensorShape::new)
            .filter(shape -> !shape.isLengthOne())
            .collect(Collectors.toSet());
    }

    // Length-one shapes ordered from highest to lowest rank.
    private static List<TensorShape> getLengthOneShapesSortedByRank(long[]... shapes) {
        return Arrays.stream(shapes)
            .map(TensorShape::new)
            .filter(TensorShape::isLengthOne)
            .sorted(Comparator.comparingInt(TensorShape::getRank).reversed())
            .collect(Collectors.toList());
    }

    // Distinct shapes that are not scalar.
    private static Set<TensorShape> getNonScalarShapes(long[]... shapes) {
        return Arrays.stream(shapes)
            .map(TensorShape::new)
            .filter(shape -> !shape.isScalar())
            .collect(Collectors.toSet());
    }

    /**
     * Checks that {@code actual} equals {@code expected} element-wise.
     *
     * @param actual   the shape that was supplied
     * @param expected the shape that was required
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void checkShapesMatch(long[] actual, long[] expected) {
        if (!Arrays.equals(actual, expected)) {
            throw new IllegalArgumentException(String.format("Expected shape %s but got %s", Arrays.toString(expected), Arrays.toString(actual)));
        }
    }

    /**
     * Checks that all given shapes are identical.
     *
     * @param shapes the shapes to check
     * @return the common shape
     * @throws IllegalArgumentException if the shapes differ or none are given
     */
    public static long[] checkAllShapesMatch(long[]... shapes) {
        return checkStreamShapesMatch(Arrays.stream(shapes), Optional.empty());
    }

    /**
     * Checks that all given shapes are identical.
     *
     * @param errorMessage the exception message to use on failure
     * @param shapes       the shapes to check
     * @return the common shape
     * @throws IllegalArgumentException with {@code errorMessage} if the shapes differ or none are given
     */
    public static long[] checkAllShapesMatch(String errorMessage, long[]... shapes) {
        return checkStreamShapesMatch(Arrays.stream(shapes), Optional.of(errorMessage));
    }

    /**
     * Checks that all shapes in the collection are identical.
     *
     * @param errorMessage the exception message to use on failure
     * @param shapes       the shapes to check
     * @return the common shape
     * @throws IllegalArgumentException with {@code errorMessage} if the shapes differ or none are given
     */
    public static long[] checkAllShapesMatch(String errorMessage, Collection<long[]> shapes) {
        return checkStreamShapesMatch(shapes.stream(), Optional.of(errorMessage));
    }

    /**
     * Checks that all shapes in the collection are identical.
     *
     * @param shapes the shapes to check
     * @return the common shape
     * @throws IllegalArgumentException if the shapes differ or none are given
     */
    public static long[] checkAllShapesMatch(Collection<long[]> shapes) {
        return checkStreamShapesMatch(shapes.stream(), Optional.empty());
    }

    // Shared implementation of the checkAllShapesMatch overloads: exactly one distinct shape must remain.
    private static long[] checkStreamShapesMatch(Stream<long[]> shapes, Optional<String> errorMessage) {
        Set<TensorShape> distinctShapes = shapes.map(TensorShape::new).collect(Collectors.toSet());
        if (distinctShapes.size() != 1) {
            throw new IllegalArgumentException(errorMessage.orElse("Shapes must match"));
        }
        return distinctShapes.iterator().next().getShape();
    }

    /**
     * Checks that the shapes can be concatenated along {@code dimension} and returns the result shape.
     *
     * <p>All shapes must have the same rank, greater than {@code dimension}, and must agree on every
     * dimension except {@code dimension}, whose sizes are summed in the result.
     *
     * @param dimension the (non-negative) dimension to concatenate along
     * @param shapes    the operand shapes; at least one is required
     * @return a new array holding the concatenated shape
     * @throws IllegalArgumentException if no shapes are given, {@code dimension} is negative or not
     *                                  below an operand's rank, the ranks differ, or any other
     *                                  dimension does not match
     */
    public static long[] checkShapesCanBeConcatenated(int dimension, long[]... shapes) {
        if (shapes.length == 0) {
            throw new IllegalArgumentException("Cannot concat zero shapes");
        }
        if (dimension < 0) {
            throw new IllegalArgumentException(String.format("Cannot concat on negative dimension %d", dimension));
        }
        long[] resultShape = Arrays.copyOf(shapes[0], shapes[0].length);
        for (int operand = 0; operand < shapes.length; operand++) {
            long[] shape = shapes[operand];
            int rank = shape.length;
            // Operand 0 is validated too, otherwise a lone low-rank operand would slip through.
            if (rank <= dimension) {
                throw new IllegalArgumentException(String.format("Cannot concat operand %d because dimension %d is greater than or equal to its rank %d", operand, dimension, rank));
            }
            if (rank != resultShape.length) {
                throw new IllegalArgumentException("Cannot concat shapes of different ranks");
            }
            if (operand == 0) {
                // Already seeded into resultShape; nothing to accumulate.
                continue;
            }
            for (int d = 0; d < rank; d++) {
                if (d == dimension) {
                    resultShape[d] += shape[d];
                } else if (shape[d] != resultShape[d]) {
                    throw new IllegalArgumentException("Cannot concat mismatched shapes");
                }
            }
        }
        return resultShape;
    }

    /**
     * Checks that {@code index} addresses an element of a tensor of shape {@code shape}.
     *
     * @param shape the tensor shape
     * @param index one coordinate per dimension
     * @throws IllegalArgumentException if the index length differs from the rank, or any coordinate is
     *                                  negative or not below the corresponding dimension size
     */
    public static void checkIndexIsValid(long[] shape, long... index) {
        if (shape.length != index.length) {
            throw new IllegalArgumentException("Length of desired index " + Arrays.toString(index) + " must match the length of the shape " + Arrays.toString(shape));
        }
        for (int i = 0; i < index.length; i++) {
            if (index[i] < 0 || index[i] >= shape[i]) {
                throw new IllegalArgumentException("Invalid index " + Arrays.toString(index) + " for shape " + Arrays.toString(shape));
            }
        }
    }

    /**
     * Validates a tensor contraction and returns its result shape.
     *
     * <p>Dimension {@code leftDimensions[i]} of {@code leftShape} is paired with
     * {@code rightDimensions[i]} of {@code rightShape}; paired sizes must be equal. The result is the
     * left shape without its contracted dimensions followed by the right shape without its contracted
     * dimensions.
     *
     * @param leftShape       the left operand shape
     * @param rightShape      the right operand shape
     * @param leftDimensions  contracted dimensions of the left operand
     * @param rightDimensions contracted dimensions of the right operand
     * @return the result shape
     * @throws IllegalArgumentException if the dimension lists differ in length, a dimension is out of
     *                                  range, or paired dimension sizes differ
     */
    public static long[] getTensorMultiplyResultShape(long[] leftShape, long[] rightShape, int[] leftDimensions, int[] rightDimensions) {
        if (leftDimensions.length != rightDimensions.length) {
            throw new IllegalArgumentException("Tensor multiply must match dimension lengths " + toStringArgs(leftShape, rightShape, leftDimensions, rightDimensions));
        }
        for (int i = 0; i < leftDimensions.length; i++) {
            int leftDimension = leftDimensions[i];
            int rightDimension = rightDimensions[i];
            if (leftDimension >= leftShape.length || leftDimension < 0) {
                throw new IllegalArgumentException("Left dimensions " + Arrays.toString(leftDimensions) + " is invalid for left shape " + Arrays.toString(leftShape));
            }
            if (rightDimension >= rightShape.length || rightDimension < 0) {
                throw new IllegalArgumentException("Right dimensions " + Arrays.toString(rightDimensions) + " is invalid for right shape " + Arrays.toString(rightShape));
            }
            if (leftShape[leftDimension] != rightShape[rightDimension]) {
                throw new IllegalArgumentException("Cannot tensor multiply dimension " + i + " for " + toStringArgs(leftShape, rightShape, leftDimensions, rightDimensions));
            }
        }
        return TensorShape.concat(ArrayUtils.removeAll(leftShape, leftDimensions), ArrayUtils.removeAll(rightShape, rightDimensions));
    }

    // Human-readable summary of tensor-multiply arguments, used in error messages.
    private static String toStringArgs(long[] leftShape, long[] rightShape, int[] leftDimensions, int[] rightDimensions) {
        return "left shape: " + Arrays.toString(leftShape) + " right shape: " + Arrays.toString(rightShape) + " on left dimensions " + Arrays.toString(leftDimensions) + " and right dimensions " + Arrays.toString(rightDimensions);
    }

    /**
     * Validates a matrix product of shapes {@code [a, b] x [b, c]} and returns {@code [a, c]}.
     *
     * @param left  the left matrix shape
     * @param right the right matrix shape
     * @return the product shape
     * @throws IllegalArgumentException if either shape is not rank 2, or the inner dimensions differ
     */
    public static long[] getMatrixMultiplicationResultingShape(long[] left, long[] right) {
        if (left.length != 2 || right.length != 2) {
            throw new IllegalArgumentException("Matrix multiply must be used on matrices");
        }
        if (left[1] != right[0]) {
            throw new IllegalArgumentException("Can not multiply matrices of shapes " + Arrays.toString(left) + " X " + Arrays.toString(right));
        }
        return new long[]{left[0], right[1]};
    }
}
