/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.dbl;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.FloatingPointTensor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Implements a tensor contraction ("tensor multiply") in terms of a single matrix multiplication.
 *
 * <p>Both operands are permuted so that the contracted dimensions are grouped together, each is
 * flattened to a 2-D matrix, the matrices are multiplied, and the product is reshaped to the kept
 * (non-contracted) dimensions of the left operand followed by the kept dimensions of the right
 * operand.
 */
public class TensorMulByMatrixMul {

    /**
     * Contracts {@code left} and {@code right} over the given pairs of dimensions.
     *
     * <p>Dimension {@code dimsLeft[i]} of {@code left} is paired with dimension {@code dimsRight[i]}
     * of {@code right}; paired dimensions must have equal sizes. Negative dimension indices count
     * from the end of the corresponding shape ({@code -1} is the last dimension). The given
     * dimension arrays are not modified.
     *
     * @param left      the left operand
     * @param right     the right operand
     * @param dimsLeft  dimensions of {@code left} to contract; must be distinct
     * @param dimsRight dimensions of {@code right} to contract, paired element-wise with
     *                  {@code dimsLeft}; must be distinct
     * @param <T>       the element type
     * @param <TENSOR>  the concrete tensor type
     * @return a tensor whose shape is the kept dimensions of {@code left} followed by the kept
     *     dimensions of {@code right}
     * @throws IllegalArgumentException if the two dimension arrays differ in length, contain an
     *     out-of-range or duplicate index, or pair dimensions of different sizes
     */
    public static <T extends Number, TENSOR extends FloatingPointTensor<T, TENSOR>> TENSOR tensorMmul(TENSOR left, TENSOR right, int[] dimsLeft, int[] dimsRight) {
        long[] leftShape = left.getShape();
        long[] rightShape = right.getShape();

        // Normalize into copies: negative indices must be resolved before they are used to index
        // the shapes, and the caller's arrays should not be rewritten as a side effect.
        int[] leftDims = normalizeDimensions(leftShape, dimsLeft, "left");
        int[] rightDims = normalizeDimensions(rightShape, dimsRight, "right");
        validateTensorMmul(leftShape, rightShape, leftDims, rightDims);

        List<Integer> leftDimsKept = getKeptDimensions(leftShape.length, leftDims);
        List<Integer> rightDimsKept = getKeptDimensions(rightShape.length, rightDims);

        // Left: kept dims first, contracted dims last  -> (rows x contracted) once flattened.
        int[] leftDimsPermuted = Ints.concat(Ints.toArray(leftDimsKept), leftDims);
        // Right: contracted dims first, kept dims last -> (contracted x columns) once flattened.
        int[] rightDimsPermuted = Ints.concat(rightDims, Ints.toArray(rightDimsKept));

        long dimsLength = calculateDimensionsLength(leftShape, leftDims);
        long[] leftTensorAsMatrixShape = new long[]{-1L, dimsLength};
        long[] rightTensorAsMatrixShape = new long[]{dimsLength, -1L};

        TENSOR leftTensorAsMatrix = left.permute(leftDimsPermuted).reshape(leftTensorAsMatrixShape);
        TENSOR rightTensorAsMatrix = right.permute(rightDimsPermuted).reshape(rightTensorAsMatrixShape);
        TENSOR resultAsMatrix = leftTensorAsMatrix.matrixMultiply(rightTensorAsMatrix);

        long[] leftKeptShape = getKeptShape(leftShape, leftDimsKept);
        long[] rightKeptShape = getKeptShape(rightShape, rightDimsKept);
        long[] resultShape = Longs.concat(leftKeptShape, rightKeptShape);
        return resultAsMatrix.reshape(resultShape);
    }

    /**
     * Returns a copy of {@code dims} with negative indices resolved against {@code shape}'s rank.
     *
     * @param shape the shape the dimensions refer to
     * @param dims  dimension indices, possibly negative
     * @param side  "left" or "right", used in error messages
     * @return the normalized dimension indices, each in {@code [0, shape.length)}
     * @throws IllegalArgumentException if an index is out of range or repeated
     */
    private static int[] normalizeDimensions(long[] shape, int[] dims, String side) {
        int[] normalized = new int[dims.length];
        boolean[] seen = new boolean[shape.length];
        for (int i = 0; i < dims.length; ++i) {
            int dim = dims[i] < 0 ? dims[i] + shape.length : dims[i];
            if (dim < 0 || dim >= shape.length) {
                throw new IllegalArgumentException("Invalid " + side + " dimension " + dims[i] + " for shape " + Arrays.toString(shape));
            }
            if (seen[dim]) {
                throw new IllegalArgumentException("Duplicate " + side + " dimension " + dims[i] + " for shape " + Arrays.toString(shape));
            }
            seen[dim] = true;
            normalized[i] = dim;
        }
        return normalized;
    }

    /**
     * Checks that the dimension arrays pair up and that paired dimensions have equal sizes.
     * Expects dimensions already normalized by {@link #normalizeDimensions}.
     */
    private static void validateTensorMmul(long[] leftShape, long[] rightShape, int[] dimsLeft, int[] dimsRight) {
        if (dimsLeft.length != dimsRight.length) {
            throw new IllegalArgumentException("Number of left dimensions (" + dimsLeft.length + ") must match number of right dimensions (" + dimsRight.length + ")");
        }
        for (int i = 0; i < dimsLeft.length; ++i) {
            if (leftShape[dimsLeft[i]] != rightShape[dimsRight[i]]) {
                throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
            }
        }
    }

    /** Returns, in ascending order, every dimension index in {@code [0, shapeLength)} not in {@code dims}. */
    private static List<Integer> getKeptDimensions(int shapeLength, int[] dims) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < shapeLength; ++i) {
            if (!Ints.contains(dims, i)) {
                result.add(i);
            }
        }
        return result;
    }

    /** Returns the sizes of {@code shape} at the given dimension indices, in the given order. */
    private static long[] getKeptShape(long[] shape, List<Integer> keptDimensions) {
        long[] keptShape = new long[keptDimensions.size()];
        for (int i = 0; i < keptShape.length; ++i) {
            keptShape[i] = shape[keptDimensions.get(i)];
        }
        return keptShape;
    }

    /**
     * Returns the product of {@code shape}'s sizes at {@code dims} (1 when {@code dims} is empty),
     * i.e. the length of the contracted axis once the contracted dimensions are flattened.
     */
    private static long calculateDimensionsLength(long[] shape, int[] dims) {
        long length = 1L;
        for (int dim : dims) {
            length *= shape[dim];
        }
        return length;
    }
}

