package io.improbable.keanu.tensor.dbl;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.FloatingPointTensor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:io/improbable/keanu/tensor/dbl/TensorMulByMatrixMul.class */
/**
 * Implements a generalised tensor contraction by rearranging both operands into
 * matrices, performing one matrix multiplication, and reshaping the product.
 */
public class TensorMulByMatrixMul {

    /**
     * Contracts {@code tensor} with {@code tensor2} over the given axes.
     *
     * <p>The left operand is permuted so its kept (non-contracted) axes come first and
     * its contracted axes last; the right operand is permuted the other way round. Both
     * are reshaped to matrices, multiplied, and the product is reshaped to the kept
     * shape of the left operand followed by the kept shape of the right operand.
     *
     * <p>Negative axes count back from the end of the corresponding shape. The supplied
     * axis arrays are not modified.
     *
     * @param tensor  left operand
     * @param tensor2 right operand
     * @param iArr    axes of {@code tensor} to contract over
     * @param iArr2   axes of {@code tensor2} to contract over, paired by position with {@code iArr}
     * @param <T>     element type
     * @param <TENSOR> concrete tensor type
     * @return the contracted tensor
     * @throws IllegalArgumentException if a paired axis is out of range for its shape, or
     *                                  if two paired axes have different sizes
     */
    // The casts to TENSOR are unchecked; they are kept because the declared return types of
    // permute/reshape/matrixMultiply are not visible from this class. NOTE(review): if those
    // methods already return TENSOR the casts are redundant and can be removed.
    @SuppressWarnings("unchecked")
    public static <T extends Number, TENSOR extends FloatingPointTensor<T, TENSOR>> TENSOR tensorMmul(TENSOR tensor, TENSOR tensor2, int[] iArr, int[] iArr2) {
        long[] leftShape = tensor.getShape();
        long[] rightShape = tensor2.getShape();

        // Work on copies so that resolving negative axes never writes into the caller's arrays.
        int[] leftAxes = iArr.clone();
        int[] rightAxes = iArr2.clone();
        validateTensorMmul(leftShape, rightShape, leftAxes, rightAxes);

        List<Integer> leftKept = getKeptDimensions(leftShape.length, leftAxes);
        List<Integer> rightKept = getKeptDimensions(rightShape.length, rightAxes);

        // Left: kept axes first, contracted axes last. Right: contracted axes first, kept axes last.
        int[] leftOrder = Ints.concat(Ints.toArray(leftKept), leftAxes);
        int[] rightOrder = Ints.concat(rightAxes, Ints.toArray(rightKept));

        long contractedLength = calculateDimensionsLength(leftShape, leftAxes);

        TENSOR leftPermuted = (TENSOR) tensor.permute(leftOrder);
        TENSOR leftMatrix = (TENSOR) leftPermuted.reshape(-1, contractedLength);

        TENSOR rightPermuted = (TENSOR) tensor2.permute(rightOrder);
        TENSOR rightMatrix = (TENSOR) rightPermuted.reshape(contractedLength, -1);

        TENSOR product = (TENSOR) leftMatrix.matrixMultiply(rightMatrix);

        long[] resultShape = Longs.concat(getKeptShape(leftShape, leftKept), getKeptShape(rightShape, rightKept));
        return (TENSOR) product.reshape(resultShape);
    }

    /**
     * Resolves negative axes in place and validates each positional pair of axes.
     *
     * <p>Only the first {@code min(iArr.length, iArr2.length)} entries of each array are
     * resolved and checked; any further entries are left untouched.
     *
     * @param jArr  shape of the left operand
     * @param jArr2 shape of the right operand
     * @param iArr  left axes; negative entries are replaced by their non-negative equivalent
     * @param iArr2 right axes; negative entries are replaced by their non-negative equivalent
     * @throws IllegalArgumentException if an axis is out of range or paired axes differ in size
     */
    private static void validateTensorMmul(long[] jArr, long[] jArr2, int[] iArr, int[] iArr2) {
        int pairCount = Math.min(iArr.length, iArr2.length);
        for (int i = 0; i < pairCount; i++) {
            // Resolve before indexing: a negative value must never be used as an array index.
            iArr[i] = resolveAxis(iArr[i], jArr, "left");
            iArr2[i] = resolveAxis(iArr2[i], jArr2, "right");
            if (jArr[iArr[i]] != jArr2[iArr2[i]]) {
                throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
            }
        }
    }

    /**
     * Converts a possibly negative axis into an index into {@code shape}.
     *
     * @param axis  axis as supplied by the caller; negative values count from the end
     * @param shape shape the axis refers to
     * @param side  {@code "left"} or {@code "right"}, used only in the error message
     * @return the axis as an index in {@code [0, shape.length)}
     * @throws IllegalArgumentException if the axis is outside {@code [-shape.length, shape.length)}
     */
    private static int resolveAxis(int axis, long[] shape, String side) {
        int resolved = axis < 0 ? axis + shape.length : axis;
        if (resolved < 0 || resolved >= shape.length) {
            throw new IllegalArgumentException("Invalid " + side + " dimension " + axis + " for shape " + Arrays.toString(shape));
        }
        return resolved;
    }

    /**
     * Lists, in ascending order, the axes in {@code [0, i)} that do not appear in {@code iArr}.
     *
     * @param i    rank of the tensor
     * @param iArr axes being contracted
     * @return the axes that are kept
     */
    private static List<Integer> getKeptDimensions(int i, int[] iArr) {
        List<Integer> kept = new ArrayList<>();
        for (int axis = 0; axis < i; axis++) {
            if (!Ints.contains(iArr, axis)) {
                kept.add(axis);
            }
        }
        return kept;
    }

    /**
     * Looks up the size of each listed axis in {@code jArr}.
     *
     * @param jArr full shape
     * @param list axes to look up
     * @return the sizes of those axes, in the order listed
     */
    private static long[] getKeptShape(long[] jArr, List<Integer> list) {
        long[] keptShape = new long[list.size()];
        for (int i = 0; i < keptShape.length; i++) {
            keptShape[i] = jArr[list.get(i)];
        }
        return keptShape;
    }

    /**
     * Multiplies together the sizes of the given axes of {@code jArr}.
     *
     * <p>At most {@code jArr.length} axes are read, so surplus entries in {@code iArr}
     * are ignored.
     *
     * @param jArr shape
     * @param iArr axes whose sizes are multiplied
     * @return the product of the selected sizes, or 1 when no axes are given
     */
    private static long calculateDimensionsLength(long[] jArr, int[] iArr) {
        long length = 1;
        int count = Math.min(jArr.length, iArr.length);
        for (int i = 0; i < count; i++) {
            length *= jArr[iArr[i]];
        }
        return length;
    }
}
