/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.ndj4;

import io.improbable.keanu.tensor.TensorShape;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

public class INDArrayExtensions {

    /**
     * Drops the fractional part of every element, rounding towards zero
     * (computed as {@code sign(x) * floor(|x|)}, e.g. {@code 2.7 -> 2}, {@code -2.7 -> -2}).
     *
     * @param tensor    the values to truncate
     * @param duplicate when {@code true} a copy of {@code tensor} is truncated and {@code tensor} is left
     *                  untouched; when {@code false} {@code tensor} itself is overwritten
     * @return the truncated array ({@code tensor} itself when {@code duplicate} is {@code false})
     */
    public static INDArray castToInteger(INDArray tensor, boolean duplicate) {
        INDArray tensorToDropFractionOn = duplicate ? tensor.dup() : tensor;
        // Capture the signs first: the abs/floor below run in place (dup == false) and would lose them.
        // Relies on the single-argument sign(...) overload returning a separate array.
        INDArray sign = Transforms.sign(tensorToDropFractionOn);
        // |x| -> floor(|x|) -> sign * floor(|x|), all written into tensorToDropFractionOn.
        Transforms.floor(Transforms.abs(tensorToDropFractionOn, false), false).muli(sign);
        return tensorToDropFractionOn;
    }

    /**
     * Splits {@code tensor} along {@code dimension} into consecutive slices.
     *
     * <p>Each entry of {@code splitAtIndices} is the exclusive end of a slice: slice {@code i} covers
     * {@code [splitAtIndices[i - 1], splitAtIndices[i])}, and the first slice starts at 0. Elements after the
     * last split index are not returned, so a caller wanting the whole tensor passes the size of
     * {@code dimension} as the final index. Each slice is obtained via {@code tensor.get(...)}.
     *
     * @param tensor         the array to split; it is passed through ND4J's {@code autoDecompress} first
     * @param dimension      the dimension to split along, normalised with
     *                       {@code TensorShape.getAbsoluteDimension} (presumably so negative values count
     *                       from the end — confirm)
     * @param splitAtIndices strictly increasing exclusive end indices, each in {@code (0, shape[dimension]]}
     * @return one slice per split index, in order
     * @throws IllegalArgumentException if {@code dimension} is out of range, or if a split index is not
     *                                  strictly greater than the previous one (0 for the first) or exceeds
     *                                  the size of {@code dimension}
     */
    public static List<INDArray> split(INDArray tensor, int dimension, long... splitAtIndices) {
        long[] shape = tensor.shape();
        int rank = tensor.rank();
        dimension = TensorShape.getAbsoluteDimension(dimension, rank);
        if (dimension < 0 || dimension >= shape.length) {
            throw new IllegalArgumentException("Invalid dimension to split on " + dimension);
        }

        Nd4j.getCompressor().autoDecompress(tensor);

        List<INDArray> splits = new ArrayList<>(splitAtIndices.length);
        long previousSplitIndex = 0L;
        for (long splitIndex : splitAtIndices) {
            // Every slice must be non-empty and inside the tensor. Reject repeated, decreasing, non-positive
            // and out-of-bounds indices here with a descriptive message instead of leaving it to ND4J.
            if (splitIndex <= previousSplitIndex || splitIndex > shape[dimension]) {
                throw new IllegalArgumentException("Invalid index to split on " + splitIndex
                    + " at dimension " + dimension + " for tensor of shape " + Arrays.toString(shape));
            }

            INDArrayIndex[] indices = new INDArrayIndex[rank];
            for (int j = 0; j < rank; j++) {
                // A fresh index object per dimension: ND4J index objects may be initialised per dimension,
                // so one instance is not shared across slots.
                indices[j] = j == dimension
                    ? NDArrayIndex.interval(previousSplitIndex, splitIndex)
                    : NDArrayIndex.all();
            }
            splits.add(tensor.get(indices));
            previousSplitIndex = splitIndex;
        }
        return splits;
    }

    /**
     * Computes the cumulative product of {@code array} along {@code dimension}.
     *
     * <ul>
     *   <li>Scalars and empty arrays are returned unchanged.</li>
     *   <li>A vector is multiplied cumulatively over all of its elements and updated in place. The exception
     *       is when {@code dimension} names one of its axes of size 1: the cumulative product along that
     *       axis is the identity, so the vector is returned unchanged.</li>
     *   <li>{@code dimension == Integer.MAX_VALUE} means "over all elements". The array is flattened with
     *       {@code ravel()} and the flattened result is returned.</li>
     *   <li>Otherwise each vector along {@code dimension} is updated in place and {@code array} is returned.
     *       This relies on {@code vectorAlongDimension} returning views of {@code array}.</li>
     * </ul>
     *
     * @param array     the values to accumulate
     * @param dimension the axis to accumulate along, or {@link Integer#MAX_VALUE} for all elements
     *                  (presumably already non-negative — TODO confirm callers normalise negative axes)
     * @return {@code array}, or the flattened result when {@code dimension == Integer.MAX_VALUE}
     */
    public static INDArray cumProd(INDArray array, int dimension) {
        if (array.isScalar() || array.isEmpty()) {
            return array;
        }
        if (array.isVector()) {
            // Accumulating along a size-1 axis changes nothing. Without this check, e.g. a [1, n] row vector
            // asked for dimension 0 would wrongly be multiplied along its other axis.
            boolean alongUnitAxis = dimension >= 0 && dimension < array.rank() && array.size(dimension) == 1;
            if (!alongUnitAxis) {
                cumProdOverAllElementsInPlace(array);
            }
            return array;
        }
        if (dimension == Integer.MAX_VALUE) {
            // NOTE(review): the result is returned flattened. `array` itself only reflects it if ravel()
            // returned a view rather than a copy, so callers must use the return value on this path.
            INDArray flattened = array.ravel();
            cumProdOverAllElementsInPlace(flattened);
            return flattened;
        }
        long vectorCount = array.vectorsAlongDimension(dimension);
        for (int i = 0; i < vectorCount; i++) {
            cumProdOverAllElementsInPlace(array.vectorAlongDimension(i, dimension));
        }
        return array;
    }

    /**
     * Replaces each element of {@code values}, in linear index order, with the product of itself and all
     * preceding elements.
     */
    private static void cumProdOverAllElementsInPlace(INDArray values) {
        double runningProduct = 1.0;
        long length = values.length();
        for (long i = 0; i < length; i++) {
            runningProduct *= values.getDouble(i);
            values.putScalar(i, runningProduct);
        }
    }
}

