/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.jvm;

import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.buffer.PrimitiveNumberWrapper;
import io.improbable.keanu.tensor.jvm.JVMTensor;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.commons.lang3.ArrayUtils;

/**
 * Base class for JVM-backed tensors whose elements are numbers.
 *
 * <p>Adds whole-tensor and per-dimension {@code sum}/{@code product} reductions and in-place
 * cumulative sum/product on top of {@link JVMTensor}. Element storage and element-wise
 * arithmetic are delegated to a {@link PrimitiveNumberWrapper} buffer.
 *
 * @param <T>      boxed element type
 * @param <TENSOR> concrete tensor type returned by tensor operations
 * @param <B>      buffer type that stores the elements
 */
public abstract class JVMNumberTensor<T extends Number, TENSOR extends NumberTensor<T, TENSOR>, B extends PrimitiveNumberWrapper<T, B>>
extends JVMTensor<T, TENSOR, B>
implements NumberTensor<T, TENSOR> {
    protected JVMNumberTensor(B buffer, long[] shape, long[] stride) {
        super(buffer, shape, stride);
    }

    /**
     * Returns the sum of every element in this tensor, as computed by the buffer.
     */
    @Override
    public T sum() {
        return buffer.sum();
    }

    /**
     * Returns the product of every element in this tensor, as computed by the buffer.
     */
    @Override
    public T product() {
        return buffer.product();
    }

    /**
     * Sums over the given dimensions, removing them from the result shape.
     *
     * @param overDimensions dimensions to sum over; the array is not modified
     * @return a new tensor holding the sums
     */
    @Override
    public TENSOR sum(int ... overDimensions) {
        return reduceOverDimensions(PrimitiveNumberWrapper::plus, JVMBuffer.PrimitiveNumberWrapperFactory::zeroes, PrimitiveNumberWrapper::sum, overDimensions);
    }

    /**
     * Multiplies over the given dimensions, removing them from the result shape.
     *
     * @param overDimensions dimensions to multiply over; the array is not modified
     * @return a new tensor holding the products
     */
    @Override
    public TENSOR product(int ... overDimensions) {
        return reduceOverDimensions(PrimitiveNumberWrapper::times, JVMBuffer.PrimitiveNumberWrapperFactory::ones, PrimitiveNumberWrapper::product, overDimensions);
    }

    /**
     * Reduces this tensor over the given dimensions.
     *
     * @param combine        folds one source element into the result buffer at a flat index
     * @param init           creates the result buffer of the given length, pre-filled with the
     *                       reduction's starting value (zeroes for sum, ones for product)
     * @param totalReduction reduces a whole buffer to a single value; used for vector tensors
     * @param overDimensions dimensions to reduce over, normalised via
     *                       {@code TensorShape.setToAbsoluteDimensions}; the caller's array is
     *                       not modified
     * @return a duplicate of this tensor when it is a scalar or no dimensions are given; a scalar
     *         tensor when this tensor is a vector; otherwise a new tensor whose shape is
     *         {@code TensorShape.getReductionResultShape(shape, dimensions)}
     */
    private TENSOR reduceOverDimensions(BufferOp<T, B> combine,
                                        BiFunction<JVMBuffer.PrimitiveNumberWrapperFactory<T, B>, Long, B> init,
                                        Function<B, T> totalReduction,
                                        int ... overDimensions) {
        // setToAbsoluteDimensions appears to normalise its argument in place (its result is not
        // used), so operate on a copy rather than rewriting an array the caller passed explicitly.
        int[] dimensions = overDimensions.clone();
        TensorShape.setToAbsoluteDimensions(shape.length, dimensions);

        if (isScalar() || dimensions.length == 0) {
            return duplicate();
        }

        if (isVector()) {
            // Vector fast path: reduce the entire buffer straight to a scalar.
            B reduced = getFactory().createNew(totalReduction.apply(buffer));
            return create(reduced, new long[0], new long[0]);
        }

        long[] resultShape = TensorShape.getReductionResultShape(shape, dimensions);
        long[] resultStride = TensorShape.getRowFirstStride(resultShape);
        B result = init.apply(getFactory(), TensorShape.getLength(resultShape));

        for (int i = 0; i < buffer.getLength(); i++) {
            // Dropping the reduced dimensions from this element's index gives its output slot.
            long[] resultIndex = ArrayUtils.removeAll(TensorShape.getShapeIndices(shape, stride, i), dimensions);
            long j = TensorShape.getFlatIndex(resultShape, resultStride, resultIndex);
            combine.apply(result, j, buffer.get(i));
        }

        return create(result, resultShape, resultStride);
    }

    /**
     * Replaces each element with the running sum along {@code requestedDimension}.
     *
     * @param requestedDimension dimension to accumulate along (converted with
     *                           {@code TensorShape.getAbsoluteDimension})
     * @return this tensor's result of {@code set} on the mutated buffer
     */
    @Override
    public TENSOR cumSumInPlace(int requestedDimension) {
        return cumulativeInPlace(PrimitiveNumberWrapper::plus, requestedDimension);
    }

    /**
     * Replaces each element with the running product along {@code requestedDimension}.
     *
     * @param requestedDimension dimension to accumulate along (converted with
     *                           {@code TensorShape.getAbsoluteDimension})
     * @return this tensor's result of {@code set} on the mutated buffer
     */
    @Override
    public TENSOR cumProdInPlace(int requestedDimension) {
        return cumulativeInPlace(PrimitiveNumberWrapper::times, requestedDimension);
    }

    /**
     * Accumulates {@code combine} along one dimension, rewriting this tensor's buffer.
     *
     * @param combine            folds the running value into the element at a flat index
     * @param requestedDimension dimension to accumulate along
     * @return the result of {@code set(buffer, shape, stride)} after accumulation
     */
    private TENSOR cumulativeInPlace(BufferOp<T, B> combine, int requestedDimension) {
        int dimension = TensorShape.getAbsoluteDimension(requestedDimension, shape.length);
        TensorShapeValidation.checkDimensionExistsInShape(dimension, shape);

        if (buffer.getLength() == 0) {
            // A zero-length dimension elsewhere in the shape leaves no elements; walking the
            // index space below would try to read elements of the empty buffer.
            return set(buffer, shape, stride);
        }

        // Visit every combination of the other dimensions; for each, walk along `dimension`.
        int[] dimensionOrder = ArrayUtils.remove(TensorShape.dimensionRange(0, shape.length), dimension);
        long[] index = new long[shape.length];
        do {
            T previous = null;
            for (long i = 0L; i < shape[dimension]; i++) {
                index[dimension] = i;
                long j = TensorShape.getFlatIndex(shape, stride, index);
                if (i > 0L) {
                    // combine is expected to update buffer[j] in place, so reading it back
                    // below yields the running value for the next step.
                    combine.apply(buffer, j, previous);
                }
                previous = buffer.get(j);
            }
        } while (TensorShape.incrementIndexByShape(shape, index, dimensionOrder));

        return set(buffer, shape, stride);
    }

    /**
     * Returns the factory used to create number buffers ({@code createNew}, {@code zeroes},
     * {@code ones}) for this tensor type.
     */
    @Override
    protected abstract JVMBuffer.PrimitiveNumberWrapperFactory<T, B> getFactory();

    /**
     * Folds {@code value} into the element of {@code target} at flat {@code index}; implemented
     * here by {@code PrimitiveNumberWrapper::plus} and {@code PrimitiveNumberWrapper::times}.
     */
    @FunctionalInterface
    interface BufferOp<T extends Number, B extends PrimitiveNumberWrapper<T, B>> {
        void apply(PrimitiveNumberWrapper<T, B> target, long index, T value);
    }
}

