package io.improbable.keanu.tensor.jvm;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.buffer.JVMBuffer.PrimitiveArrayWrapper;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:io/improbable/keanu/tensor/jvm/JVMTensor.class */
public abstract class JVMTensor<T, TENSOR extends Tensor<T, TENSOR>, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> implements Tensor<T, TENSOR> {

    /** Element storage; its length is what {@link #getLength()} reports. */
    protected B buffer;

    /** Size of each dimension; {@link #getRank()} is the length of this array. */
    protected long[] shape;

    /** Per-dimension stride, passed alongside {@link #shape} to the index-mapping helpers. */
    protected long[] stride;

    /**
     * Wraps an existing buffer. The arguments are stored by reference; no defensive copies are made.
     *
     * @param buffer element storage
     * @param shape  size of each dimension
     * @param stride per-dimension stride
     */
    public JVMTensor(B buffer, long[] shape, long[] stride) {
        this.buffer = buffer;
        this.shape = shape;
        this.stride = stride;
    }

    /** Returns the number of dimensions, i.e. the length of the shape array. */
    @Override
    public int getRank() {
        return this.shape.length;
    }

    /** Returns a copy of the shape so callers cannot mutate this tensor's shape. */
    @Override
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    /** Returns a copy of the stride so callers cannot mutate this tensor's stride. */
    @Override
    public long[] getStride() {
        return Arrays.copyOf(this.stride, this.stride.length);
    }

    /** Returns the number of elements held in the backing buffer. */
    @Override
    public long getLength() {
        return this.buffer.getLength();
    }

    /**
     * Selects the elements whose flat position is {@code true} in {@code mask} and returns them,
     * in flat order, as a new rank-1 tensor.
     * <p>
     * NOTE(review): the mask is walked over its own length and its shape is not compared with this
     * tensor's; a mask longer than this tensor fails on the out-of-range buffer read.
     *
     * @param mask boolean selector, read through its flattened view
     * @return a new vector containing the selected elements
     */
    @Override
    public TENSOR get(BooleanTensor mask) {
        Tensor.FlattenedView<Boolean> flatMask = mask.getFlattenedView();
        List<Long> selected = new ArrayList<>();
        for (long index = 0; index < mask.getLength(); index++) {
            if (flatMask.get(index)) {
                selected.add(index);
            }
        }

        B result = getFactory().createNew(selected.size());
        for (int i = 0; i < selected.size(); i++) {
            long sourceIndex = selected.get(i);
            result.set(this.buffer.get(sourceIndex), i);
        }
        return create(result, new long[]{result.getLength()}, new long[]{1});
    }

    /**
     * Vector to diagonal matrix, or square matrix to its diagonal vector.
     *
     * @throws IllegalArgumentException if this tensor is neither a vector nor a square matrix
     */
    @Override
    public TENSOR diag() {
        return createFromResultWrapper(diag(this.shape.length, this.shape, this.buffer, getFactory()));
    }

    /**
     * Diagonal construction / extraction on a raw buffer.
     * <ul>
     *   <li>rank 1 of length n: returns an n x n buffer with element k written at flat position
     *       {@code k * n + k}; no other position is written.</li>
     *   <li>rank 2 with {@code shape[0] == shape[1] == n}: returns the n elements read from flat
     *       positions {@code k * n + k}.</li>
     * </ul>
     * The result's stride comes from {@code TensorShape.getRowFirstStride}. The rank-2 branch indexes
     * the source as {@code k * n + k}, so it presumably expects a row-major source buffer (TODO confirm).
     *
     * @param rank    rank of the source tensor
     * @param shape   shape of the source tensor
     * @param source  source elements
     * @param factory allocator for the result buffer
     * @return result buffer, shape and stride
     * @throws IllegalArgumentException for any other rank, or a non-square rank-2 shape
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> diag(int rank, long[] shape, B source, JVMBuffer.ArrayWrapperFactory<T, B> factory) {
        final long[] resultShape;
        final B result;
        if (rank == 1) {
            long n = source.getLength();
            result = factory.createNew(n * n);
            for (long k = 0; k < n; k++) {
                result.set(source.get(k), k * n + k);
            }
            resultShape = new long[]{n, n};
        } else if (rank == 2 && shape[0] == shape[1]) {
            long n = shape[0];
            result = factory.createNew(n);
            for (long k = 0; k < n; k++) {
                result.set(source.get(k * n + k), k);
            }
            resultShape = new long[]{n};
        } else {
            throw new IllegalArgumentException("Diag is only valid for vectors or square matrices");
        }
        return new ResultWrapper<>(result, resultShape, TensorShape.getRowFirstStride(resultShape));
    }

    /**
     * Reorders the dimensions of this tensor according to {@code rearrange}.
     *
     * @param rearrange one entry per dimension
     * @return a new tensor with the permuted shape
     */
    @Override
    public TENSOR permute(int... rearrange) {
        return createFromResultWrapper(permute(getFactory(), this.buffer, this.shape, this.stride, rearrange));
    }

    /**
     * Permutes the dimensions of a raw buffer into a freshly allocated buffer of the same length.
     * Each source element {@code i} is written at the position computed by
     * {@code TensorShape.convertFromFlatIndexToPermutedFlatIndex}. The result shape comes from
     * {@code TensorShape.getPermutedIndices} and the result stride from {@code TensorShape.getRowFirstStride}.
     *
     * @param factory   allocator for the result buffer
     * @param source    source elements
     * @param shape     source shape
     * @param stride    source stride
     * @param rearrange permutation; must have one entry per dimension
     * @return permuted buffer, shape and stride
     * @throws IllegalArgumentException if {@code rearrange.length != shape.length}
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> permute(JVMBuffer.ArrayWrapperFactory<T, B> factory, B source, long[] shape, long[] stride, int... rearrange) {
        Preconditions.checkArgument(rearrange.length == shape.length,
            "Permutation of length %s does not match tensor rank %s", rearrange.length, shape.length);
        long[] resultShape = TensorShape.getPermutedIndices(shape, rearrange);
        long[] resultStride = TensorShape.getRowFirstStride(resultShape);
        B result = factory.createNew(source.getLength());
        for (long from = 0; from < source.getLength(); from++) {
            result.set(source.get(from),
                TensorShape.convertFromFlatIndexToPermutedFlatIndex(from, shape, stride, resultShape, resultStride, rearrange));
        }
        return new ResultWrapper<>(result, resultShape, resultStride);
    }

    /**
     * Splits this tensor along {@code dimension} at the given indices.
     *
     * @param dimension      dimension to split on (resolved via {@code TensorShape.getAbsoluteDimension})
     * @param splitAtIndices strictly increasing end indices of each chunk
     * @return one tensor per split index
     */
    @Override
    public List<TENSOR> split(int dimension, long... splitAtIndices) {
        return split(getFactory(), this.buffer, this.shape, this.stride, dimension, splitAtIndices)
            .stream()
            .map(this::createFromResultWrapper)
            .collect(Collectors.toList());
    }

    /**
     * Splits a raw buffer along one dimension.
     * <p>
     * The split dimension is first permuted to position 0 so that each chunk occupies a contiguous
     * run of the permuted buffer. Each run is copied out and permuted back. Chunk {@code k} covers
     * indices {@code [splitAtIndices[k-1], splitAtIndices[k])} (the first chunk starts at 0). Elements
     * past the last split index are not returned.
     *
     * @param factory        allocator for intermediate and result buffers
     * @param source         source elements
     * @param shape          source shape
     * @param stride         source stride
     * @param dimension      dimension to split on; may be negative if {@code TensorShape.getAbsoluteDimension} maps it
     * @param splitAtIndices strictly increasing, each in {@code (0, shape[dim]]}
     * @return one result per split index
     * @throws IllegalArgumentException on an invalid dimension or split index
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> List<ResultWrapper<T, B>> split(JVMBuffer.ArrayWrapperFactory<T, B> factory, B source, long[] shape, long[] stride, int dimension, long... splitAtIndices) {
        int dim = TensorShape.getAbsoluteDimension(dimension, shape.length);
        if (dim < 0 || dim >= shape.length) {
            throw new IllegalArgumentException("Invalid dimension to split on " + dimension);
        }

        int[] moveDimToZero = TensorShape.slideDimension(dim, 0, shape.length);
        int[] moveZeroToDim = TensorShape.slideDimension(0, dim, shape.length);
        B permuted = permute(factory, source, shape, stride, moveDimToZero).outputBuffer;

        List<ResultWrapper<T, B>> splits = new ArrayList<>();
        long previousSplitAt = 0;
        long bufferPosition = 0;
        for (long splitAt : splitAtIndices) {
            // Validate the index itself, not just the chunk width: checking only the width would
            // accept e.g. [3, 7] on a dimension of size 5 and then read past the end of the buffer.
            if (splitAt <= previousSplitAt || splitAt > shape[dim]) {
                throw new IllegalArgumentException("Invalid index to split on " + splitAt + " at " + dim + " for tensor of shape " + Arrays.toString(shape));
            }

            long[] chunkShape = Arrays.copyOf(shape, shape.length);
            chunkShape[dim] = splitAt - previousSplitAt;
            long chunkLength = TensorShape.getLength(chunkShape);

            B chunk = factory.createNew(chunkLength);
            chunk.copyFrom2(permuted, bufferPosition, 0L, chunkLength);

            long[] permutedChunkShape = TensorShape.getPermutedIndices(chunkShape, moveDimToZero);
            splits.add(permute(factory, chunk, permutedChunkShape, TensorShape.getRowFirstStride(permutedChunkShape), moveZeroToDim));

            previousSplitAt = splitAt;
            bufferPosition += chunk.getLength();
        }
        return splits;
    }

    /** Slices this tensor at {@code index} along {@code dimension}, via {@link DimensionIndexMapper}. */
    @Override
    public TENSOR slice(int dimension, long index) {
        return createFromResultWrapper(slice(getFactory(), this.buffer, new DimensionIndexMapper(this.shape, this.stride, dimension, index)));
    }

    /** Slices this tensor as described by {@code slicer}, via {@link SlicerIndexMapper}. */
    @Override
    public TENSOR slice(Slicer slicer) {
        return createFromResultWrapper(slice(getFactory(), this.buffer, new SlicerIndexMapper(slicer, this.shape, this.stride)));
    }

    /**
     * Gathers a slice from a raw buffer. The result shape and stride come from the mapper. Each result
     * position is filled from the source position returned by
     * {@code indexMapper.getSourceIndexFromResultIndex}.
     *
     * @param factory     allocator for the result buffer
     * @param source      source elements
     * @param indexMapper maps result positions back to source positions
     * @return sliced buffer, shape and stride
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> slice(JVMBuffer.ArrayWrapperFactory<T, B> factory, B source, IndexMapper indexMapper) {
        long[] resultShape = indexMapper.getResultShape();
        long[] resultStride = indexMapper.getResultStride();
        B result = factory.createNew(TensorShape.getLength(resultShape));
        for (long to = 0; to < result.getLength(); to++) {
            result.set(source.get(indexMapper.getSourceIndexFromResultIndex(to)), to);
        }
        return new ResultWrapper<>(result, resultShape, resultStride);
    }

    /**
     * Concatenates buffers along {@code dimension}.
     * <p>
     * Dimension 0 is a plain end-to-end copy. For any other dimension, each input is permuted so that
     * {@code dimension} comes first, the permuted buffers are joined, and the result is permuted back
     * with the inverted permutation.
     *
     * @param factory   allocator for intermediate and result buffers
     * @param toConcat  tensors supplying shape and stride for each buffer, index-aligned with {@code buffers}
     * @param dimension dimension to concatenate along
     * @param buffers   element buffers to join, in order
     * @return concatenated buffer, shape and stride
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> concat(JVMBuffer.ArrayWrapperFactory<T, B> factory, Tensor[] toConcat, int dimension, List<B> buffers) {
        long[] resultShape = TensorShape.getConcatResultShape(dimension, toConcat);
        if (dimension == 0) {
            return new ResultWrapper<>(concatOnDimensionZero(factory, resultShape, buffers), resultShape, TensorShape.getRowFirstStride(resultShape));
        }

        int[] toDimZero = TensorShape.getPermutationForDimensionToDimensionZero(dimension, resultShape);
        List<B> permutedBuffers = new ArrayList<>(buffers.size());
        for (int k = 0; k < buffers.size(); k++) {
            permutedBuffers.add(permute(factory, buffers.get(k), toConcat[k].getShape(), toConcat[k].getStride(), toDimZero).outputBuffer);
        }

        long[] permutedShape = TensorShape.getPermutedIndices(resultShape, toDimZero);
        B joined = concatOnDimensionZero(factory, permutedShape, permutedBuffers);
        return permute(factory, joined, permutedShape, TensorShape.getRowFirstStride(permutedShape), TensorShape.invertedPermute(toDimZero));
    }

    /**
     * Copies the given buffers end-to-end into a new buffer sized {@code TensorShape.getLength(shape)}.
     *
     * @param factory allocator for the result buffer
     * @param shape   shape of the joined result
     * @param buffers parts to join, in order
     * @return the joined buffer
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> B concatOnDimensionZero(JVMBuffer.ArrayWrapperFactory<T, B> factory, long[] shape, List<B> buffers) {
        B result = factory.createNew(TensorShape.getLength(shape));
        // Track the write offset as a long: buffer lengths are long, and an int offset would
        // silently wrap once the parts exceed Integer.MAX_VALUE elements.
        long position = 0;
        for (B part : buffers) {
            result.copyFrom2(part, 0L, position, part.getLength());
            position += part.getLength();
        }
        return result;
    }

    /**
     * Applies {@code op} element-wise between this tensor and {@code that}.
     * <p>
     * Shape handling is delegated to {@code JVMTensorBroadcast.broadcastIfNeeded}, called with a
     * trailing {@code true} flag (presumably "in place" — TODO confirm). The output is stored into this
     * tensor via {@link #set}.
     *
     * @param op   element-wise operation
     * @param that right-hand operand
     * @return the result of {@link #set} on the computed buffer, shape and stride
     */
    public TENSOR broadcastableBinaryOpWithAutoBroadcast(BiFunction<T, T, T> op, JVMTensor<T, TENSOR, B> that) {
        ResultWrapper<T, B> result = JVMTensorBroadcast.broadcastIfNeeded(
            getFactory(),
            this.buffer, this.shape, this.stride, this.buffer.getLength(),
            that.buffer, that.shape, that.stride, that.buffer.getLength(),
            op, true);
        return set(result.outputBuffer, result.outputShape, result.outputStride);
    }

    /**
     * Returns a new tensor over a copy of this buffer with the requested shape.
     * The shape is resolved by {@code TensorShape.getReshapeAllowingWildcard}, and the new stride comes
     * from {@code TensorShape.getRowFirstStride}.
     *
     * @param newShape requested shape
     * @return reshaped copy
     */
    @Override
    public TENSOR reshape(long... newShape) {
        long[] resolvedShape = TensorShape.getReshapeAllowingWildcard(this.shape, this.buffer.getLength(), newShape);
        return create(this.buffer.copy(), resolvedShape, TensorShape.getRowFirstStride(resolvedShape));
    }

    /**
     * Broadcasts this tensor into a new buffer of the given shape via {@code JVMTensorBroadcast.broadcast}.
     *
     * @param toShape target shape
     * @return a new tensor of shape {@code toShape}
     */
    @Override
    public TENSOR broadcast(long... toShape) {
        long length = TensorShape.getLength(toShape);
        long[] resultStride = TensorShape.getRowFirstStride(toShape);
        B result = getFactory().createNew(length);
        JVMTensorBroadcast.broadcast(this.buffer, this.shape, this.stride, result, resultStride);
        return create(result, toShape, resultStride);
    }

    /**
     * Reduces along {@code axis}, returning the index of the element selected by {@code compareOp}.
     *
     * @param compareOp returns true when its first argument should replace its second as the current pick
     * @param axis      axis to reduce
     * @return indices along {@code axis}, shaped as this tensor with {@code axis} removed
     */
    public IntegerTensor argCompare(BiFunction<T, T, Boolean> compareOp, int axis) {
        return argCompare(getFactory(), this.buffer, compareOp, this.shape, this.stride, axis);
    }

    /**
     * Reduces a raw buffer along {@code axis}, recording for each remaining position the index along
     * {@code axis} of the element chosen by {@code compareOp}.
     * <p>
     * The first element seen for a position is always taken. A later element replaces it only when
     * {@code compareOp.apply(candidate, current)} is true, so with a strict comparison the earliest of
     * several equal values wins.
     * <p>
     * NOTE(review): a zero-length {@code axis} divides by zero when computing the result length.
     *
     * @param factory   allocator for the scratch buffer
     * @param source    source elements
     * @param compareOp selection predicate
     * @param shape     source shape
     * @param stride    source stride
     * @param axis      axis to reduce, in {@code [0, shape.length)}
     * @return index tensor shaped as {@code shape} with {@code axis} removed
     * @throws IllegalArgumentException if {@code axis} is out of range, or the reduced length exceeds int range
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> IntegerTensor argCompare(JVMBuffer.ArrayWrapperFactory<T, B> factory, B source, BiFunction<T, T, Boolean> compareOp, long[] shape, long[] stride, int axis) {
        if (axis < 0 || axis >= shape.length) {
            throw new IllegalArgumentException("Cannot take arg max of axis " + axis + " on a " + shape.length + " rank tensor.");
        }

        // Move the reduced axis to the front so that consecutive blocks of resultLength elements
        // are successive slices along that axis.
        B permuted = permute(factory, source, shape, stride, TensorShape.getPermutationForDimensionToDimensionZero(axis, shape)).outputBuffer;
        int resultLength = Ints.checkedCast(source.getLength() / shape[axis]);

        B best = factory.createNew(resultLength);
        int[] bestIndex = new int[resultLength];
        Arrays.fill(bestIndex, -1);

        for (long flat = 0; flat < permuted.getLength(); flat++) {
            int slot = (int) (flat % resultLength);
            T candidate = permuted.get(flat);
            if (bestIndex[slot] < 0 || compareOp.apply(candidate, best.get(slot))) {
                best.set(candidate, slot);
                bestIndex[slot] = (int) (flat / resultLength);
            }
        }
        return IntegerTensor.create(bestIndex, ArrayUtils.remove(shape, axis));
    }

    /**
     * Returns the flat index of the element selected by {@code compareOp} across the whole buffer.
     *
     * @param compareOp returns true when its first argument should replace its second as the current pick
     * @return flat index of the pick, or -1 if the buffer is empty
     */
    public int argCompare(BiFunction<T, T, Boolean> compareOp) {
        return argCompare(this.buffer, compareOp);
    }

    /**
     * Scans {@code source} in flat order and returns the index of the element chosen by {@code compareOp}.
     * Index 0 is taken first. A later element replaces the pick only when
     * {@code compareOp.apply(candidate, current)} is true.
     *
     * @param source    elements to scan
     * @param compareOp selection predicate
     * @return index of the pick, or -1 if {@code source} is empty
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> int argCompare(B source, BiFunction<T, T, Boolean> compareOp) {
        T best = null;
        int bestIndex = -1;
        for (int k = 0; k < source.getLength(); k++) {
            T candidate = source.get(k);
            if (k == 0 || compareOp.apply(candidate, best)) {
                best = candidate;
                bestIndex = k;
            }
        }
        return bestIndex;
    }

    /**
     * Evaluates {@code predicate} on every element and returns the results as a boolean tensor of the
     * same shape.
     *
     * @param predicate element test
     * @return element-wise results
     * @throws IllegalArgumentException if the length exceeds int range
     */
    public BooleanTensor isApply(Function<T, Boolean> predicate) {
        boolean[] results = new boolean[Ints.checkedCast(this.buffer.getLength())];
        for (int i = 0; i < this.buffer.getLength(); i++) {
            results[i] = predicate.apply(this.buffer.get(i));
        }
        return BooleanTensor.create(results, Arrays.copyOf(this.shape, this.shape.length));
    }

    /**
     * Two tensors are equal when they have the same runtime class, equal shapes and equal buffers.
     * Stride is not compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        JVMTensor<?, ?, ?> other = (JVMTensor<?, ?, ?>) obj;
        return Arrays.equals(this.shape, other.shape) && this.buffer.equals(other.buffer);
    }

    /**
     * Renders the shape and data. Buffers longer than 20 elements show only the first and last 10,
     * separated by "...".
     */
    @Override
    public String toString() {
        StringBuilder data = new StringBuilder();
        if (this.buffer.getLength() > 20) {
            data.append(Arrays.toString(this.buffer.asArray(0L, 10L)));
            data.append("...");
            data.append(Arrays.toString(this.buffer.asArray(this.buffer.getLength() - 10, this.buffer.getLength())));
        } else {
            data.append(Arrays.toString(this.buffer.asArray()));
        }
        return "{\nshape = " + Arrays.toString(this.shape) + "\ndata = " + data.toString() + "\n}";
    }

    /** Combines the shape and buffer hashes, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return (31 * Arrays.hashCode(this.shape)) + this.buffer.hashCode();
    }

    /** Builds a tensor of the concrete subtype from a helper's buffer/shape/stride triple. */
    private TENSOR createFromResultWrapper(ResultWrapper<T, B> result) {
        return create(result.outputBuffer, result.outputShape, result.outputStride);
    }

    /** Creates a new tensor of the concrete subtype over the given buffer, shape and stride. */
    public abstract TENSOR create(B buffer, long[] shape, long[] stride);

    /**
     * Stores the given buffer, shape and stride as this tensor's contents; presumably returns this
     * tensor (implementation is in the subclass — TODO confirm).
     */
    public abstract TENSOR set(B buffer, long[] shape, long[] stride);

    /** Supplies the allocator used for every new buffer created by this class. */
    protected abstract JVMBuffer.ArrayWrapperFactory<T, B> getFactory();
}
