package io.improbable.keanu.tensor.ndj4;

import io.improbable.keanu.tensor.TensorShape;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:io/improbable/keanu/tensor/ndj4/INDArrayExtensions.class */
/**
 * Static helper operations on ND4J {@link INDArray} instances: truncation of values to whole
 * numbers, splitting along a dimension and cumulative product.
 */
public class INDArrayExtensions {

    /**
     * Truncates every element toward zero, computed as {@code floor(abs(x)) * sign(x)}.
     *
     * <p>No data-type conversion call is made here; only the element values are changed.
     *
     * @param iNDArray the array whose values should be truncated
     * @param z        if {@code true} the work is done on {@code iNDArray.dup()} and the input is
     *                 not touched by this method; if {@code false} the input is modified in place
     * @return the array holding the truncated values (the duplicate, or {@code iNDArray} itself)
     */
    public static INDArray castToInteger(INDArray iNDArray, boolean z) {
        INDArray result = z ? iNDArray.dup() : iNDArray;
        // The signs must be captured in a separate array BEFORE abs/floor run in place on
        // `result`; taking the sign afterwards would only ever see non-negative values and
        // negative inputs would come out positive.
        INDArray signs = Transforms.sign(result, true);
        Transforms.floor(Transforms.abs(result, false), false).muli(signs);
        return result;
    }

    /**
     * Cuts an array into consecutive pieces along one dimension.
     *
     * <p>Piece {@code k} covers the half-open interval {@code [jArr[k-1], jArr[k])} of the chosen
     * dimension (with an implicit start of {@code 0} for the first piece) and all of every other
     * dimension. Elements located at or after the last split index are not part of any returned
     * piece. Each piece is whatever {@code iNDArray.get(...)} returns for those indices.
     *
     * @param iNDArray the array to split
     * @param i        the dimension to split along; resolved through
     *                 {@link TensorShape#getAbsoluteDimension(int, int)}
     * @param jArr     exclusive end index of each piece; must be strictly increasing, greater
     *                 than zero and no larger than the extent of the split dimension
     * @return the pieces, in the order of {@code jArr}
     * @throws IllegalArgumentException if the resolved dimension is outside the array's rank or a
     *                                  split index is not strictly greater than the previous one
     *                                  or exceeds the extent of the dimension
     */
    public static List<INDArray> split(INDArray iNDArray, int i, long... jArr) {
        long[] shape = iNDArray.shape();
        int rank = iNDArray.rank();
        int dimension = TensorShape.getAbsoluteDimension(i, rank);
        if (dimension < 0 || dimension >= shape.length) {
            throw new IllegalArgumentException("Invalid dimension to split on " + dimension);
        }
        Nd4j.getCompressor().autoDecompress(iNDArray);

        List<INDArray> pieces = new ArrayList<>(jArr.length);
        long start = 0;
        for (long end : jArr) {
            // An empty or reversed interval, or one running past the end of the dimension,
            // cannot describe a valid piece.
            if (end <= start || end > shape[dimension]) {
                throw new IllegalArgumentException("Invalid index to split on " + end + " at dimension " + dimension + " for tensor of shape " + Arrays.toString(shape));
            }
            INDArrayIndex[] indices = new INDArrayIndex[rank];
            for (int axis = 0; axis < rank; axis++) {
                indices[axis] = axis == dimension ? NDArrayIndex.interval(start, end) : NDArrayIndex.all();
            }
            pieces.add(iNDArray.get(indices));
            start = end;
        }
        return pieces;
    }

    /**
     * Replaces elements with their running (cumulative) product.
     *
     * <ul>
     *   <li>Scalar or empty arrays are returned unchanged.</li>
     *   <li>Vectors are overwritten in place, element by element.</li>
     *   <li>For {@code i == Integer.MAX_VALUE} the product runs over {@code iNDArray.ravel()} and
     *       that flattened array is what gets returned, not {@code iNDArray}.</li>
     *   <li>Otherwise every vector along dimension {@code i} is processed in turn and
     *       {@code iNDArray} is returned. NOTE(review): this relies on
     *       {@code vectorAlongDimension} handing back a view that writes through to
     *       {@code iNDArray} - confirm against the ND4J version in use.</li>
     * </ul>
     *
     * @param iNDArray the array to accumulate over
     * @param i        the dimension to accumulate along, or {@code Integer.MAX_VALUE} to
     *                 accumulate over the flattened array; ignored for vectors
     * @return {@code iNDArray}, or the flattened array in the {@code Integer.MAX_VALUE} case
     */
    public static INDArray cumProd(INDArray iNDArray, int i) {
        if (iNDArray.isScalar() || iNDArray.isEmpty()) {
            return iNDArray;
        }

        if (iNDArray.isVector()) {
            long length = iNDArray.length();
            double running = 1.0d;
            for (int index = 0; index < length; index++) {
                running *= iNDArray.getDouble(index);
                iNDArray.putScalar(index, running);
            }
            return iNDArray;
        }

        if (i == Integer.MAX_VALUE) {
            INDArray flattened = iNDArray.ravel();
            long length = flattened.length();
            double running = flattened.getDouble(0L);
            for (int index = 1; index < length; index++) {
                running *= flattened.getDouble(index);
                flattened.putScalar(index, running);
            }
            return flattened;
        }

        long vectorCount = iNDArray.vectorsAlongDimension(i);
        for (int vectorIndex = 0; vectorIndex < vectorCount; vectorIndex++) {
            // The dimension argument is irrelevant for the extracted vector.
            cumProd(iNDArray.vectorAlongDimension(vectorIndex, i), 0);
        }
        return iNDArray;
    }
}
