package io.improbable.keanu.tensor.jvm;

import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.buffer.PrimitiveNumberWrapper;
import java.lang.Number;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:io/improbable/keanu/tensor/jvm/JVMNumberTensor.class */
/**
 * Base class for JVM-backed tensors whose elements are numbers.
 *
 * <p>Provides whole-tensor and per-dimension {@code sum}/{@code product} reductions and the
 * in-place cumulative sum/product along a dimension. All of them are built on the
 * {@link PrimitiveNumberWrapper} buffer and the factory supplied by the concrete subclass via
 * {@link #getFactory()}.
 *
 * @param <T>      boxed element type
 * @param <TENSOR> concrete tensor type produced by operations
 * @param <B>      buffer type holding the elements
 */
public abstract class JVMNumberTensor<T extends Number, TENSOR extends NumberTensor<T, TENSOR>, B extends PrimitiveNumberWrapper<T, B>>
    extends JVMTensor<T, TENSOR, B> implements NumberTensor<T, TENSOR> {

    /**
     * Element-wise update applied to a buffer at a flat index with a single value.
     *
     * <p>The operations in this class delegate to the buffer's {@code plus(index, value)} or
     * {@code times(index, value)}.
     */
    @FunctionalInterface
    public interface BufferOp<T extends Number, B extends PrimitiveNumberWrapper<T, B>> {
        void apply(PrimitiveNumberWrapper<T, B> buffer, long index, T value);
    }

    /**
     * @param buffer element storage
     * @param shape  tensor shape
     * @param stride tensor stride
     */
    public JVMNumberTensor(B buffer, long[] shape, long[] stride) {
        super(buffer, shape, stride);
    }

    /**
     * @return the sum of every element in the buffer
     */
    @Override
    public T sum() {
        return buffer.sum();
    }

    /**
     * @return the product of every element in the buffer
     */
    @Override
    public T product() {
        return buffer.product();
    }

    /**
     * Sums over the given dimensions. Negative dimensions are resolved against this tensor's rank.
     * The caller's array is not modified.
     *
     * @param overDimensions dimensions to reduce over
     * @return the reduced tensor (a duplicate when this tensor is a scalar or no dimensions are given)
     */
    @Override
    public TENSOR sum(int... overDimensions) {
        return reduceOverDimensions(
            (target, index, value) -> target.plus(index, value),
            (factory, length) -> factory.zeroes(length),
            all -> all.sum(),
            overDimensions
        );
    }

    /**
     * Multiplies over the given dimensions. Negative dimensions are resolved against this tensor's
     * rank. The caller's array is not modified.
     *
     * @param overDimensions dimensions to reduce over
     * @return the reduced tensor (a duplicate when this tensor is a scalar or no dimensions are given)
     */
    @Override
    public TENSOR product(int... overDimensions) {
        return reduceOverDimensions(
            (target, index, value) -> target.times(index, value),
            (factory, length) -> factory.ones(length),
            all -> all.product(),
            overDimensions
        );
    }

    /**
     * Shared implementation of the per-dimension reductions.
     *
     * @param accumulate     folds a source element into the result buffer at a flat index
     * @param identityBuffer creates the initial result buffer of the given length
     * @param reduceAll      reduces an entire buffer to one value (used for vectors)
     * @param overDimensions dimensions to reduce over; never modified
     * @return the reduced tensor
     */
    private TENSOR reduceOverDimensions(BufferOp<T, B> accumulate,
                                        BiFunction<JVMBuffer.PrimitiveNumberWrapperFactory<T, B>, Long, B> identityBuffer,
                                        Function<B, T> reduceAll,
                                        int... overDimensions) {
        // setToAbsoluteDimensions' result is not used, so it is relied on to rewrite the array in
        // place. Work on a copy so the caller's (possibly reused) array is left untouched.
        int[] dimensions = overDimensions.clone();
        TensorShape.setToAbsoluteDimensions(shape.length, dimensions);

        if (isScalar() || dimensions.length == 0) {
            return duplicate();
        }

        if (isVector()) {
            // Rank-1 input: collapse the whole buffer into a scalar tensor.
            return create(getFactory().createNew(reduceAll.apply(buffer)), new long[0], new long[0]);
        }

        long[] resultShape = TensorShape.getReductionResultShape(shape, dimensions);
        long[] resultStride = TensorShape.getRowFirstStride(resultShape);
        B result = identityBuffer.apply(getFactory(), TensorShape.getLength(resultShape));

        long length = buffer.getLength();
        for (int i = 0; i < length; i++) {
            // Drop the reduced dimensions from the source index to find the target slot.
            long[] sourceIndex = TensorShape.getShapeIndices(shape, stride, i);
            long[] targetIndex = ArrayUtils.removeAll(sourceIndex, dimensions);
            long targetFlatIndex = TensorShape.getFlatIndex(resultShape, resultStride, targetIndex);
            accumulate.apply(result, targetFlatIndex, buffer.get(i));
        }

        return create(result, resultShape, resultStride);
    }

    /**
     * Replaces each element with the running sum along {@code requestedDimension}, in place.
     *
     * @param requestedDimension dimension to accumulate along (negative values count from the end)
     * @return this tensor's updated state, as returned by {@code set}
     */
    @Override
    public TENSOR cumSumInPlace(int requestedDimension) {
        return cumulativeInPlace((target, index, value) -> target.plus(index, value), requestedDimension);
    }

    /**
     * Replaces each element with the running product along {@code requestedDimension}, in place.
     *
     * @param requestedDimension dimension to accumulate along (negative values count from the end)
     * @return this tensor's updated state, as returned by {@code set}
     */
    @Override
    public TENSOR cumProdInPlace(int requestedDimension) {
        return cumulativeInPlace((target, index, value) -> target.times(index, value), requestedDimension);
    }

    /**
     * Shared implementation of the in-place cumulative operations. For every combination of indices
     * over the other dimensions, this walks along {@code requestedDimension} and folds the previous
     * (already accumulated) element into the current one.
     *
     * @param op                 folds the previous element into the element at a flat index
     * @param requestedDimension dimension to accumulate along
     * @return the result of {@code set(buffer, shape, stride)}
     */
    private TENSOR cumulativeInPlace(BufferOp<T, B> op, int requestedDimension) {
        int dimension = TensorShape.getAbsoluteDimension(requestedDimension, shape.length);
        TensorShapeValidation.checkDimensionExistsInShape(dimension, shape);

        // The index is advanced over every dimension except the one being accumulated along.
        int[] otherDimensions = ArrayUtils.remove(TensorShape.dimensionRange(0, shape.length), dimension);
        long[] index = new long[shape.length];

        do {
            T previous = null;
            for (long position = 0; position < shape[dimension]; position++) {
                index[dimension] = position;
                long flatIndex = TensorShape.getFlatIndex(shape, stride, index);
                if (position > 0) {
                    op.apply(buffer, flatIndex, previous);
                }
                previous = buffer.get(flatIndex);
            }
        } while (TensorShape.incrementIndexByShape(shape, index, otherDimensions));

        return set(buffer, shape, stride);
    }

    /**
     * @return the factory used to create buffers of this tensor's element type
     */
    @Override
    public abstract JVMBuffer.PrimitiveNumberWrapperFactory<T, B> getFactory();
}
