package io.improbable.keanu.tensor;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.jvm.Slicer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;

/**
 * A multi-dimensional collection of elements of type {@code N}.
 *
 * <p>{@code T} is the concrete (self) tensor type; shape-manipulating operations such as
 * {@link #reshape}, {@link #slice} and {@link #permute} return it so that callers keep the
 * concrete type without casting.
 *
 * @param <N> element type
 * @param <T> concrete tensor type implementing this interface
 */
public interface Tensor<N, T extends Tensor<N, T>> {

    /** Shape of a rank-0 tensor (no dimensions). Arrays are mutable: callers must not modify it. */
    long[] SCALAR_SHAPE = new long[0];

    /** Stride of a rank-0 tensor (no dimensions). Arrays are mutable: callers must not modify it. */
    long[] SCALAR_STRIDE = new long[0];

    /** The shape {@code [1, 1]}. Arrays are mutable: callers must not modify it. */
    long[] ONE_BY_ONE_SHAPE = {1, 1};

    /**
     * Element access by a single flat (row-major position independent) index.
     *
     * <p>Used by {@link Tensor#getValue} and {@link Tensor#setValue}, which translate
     * multi-dimensional indices into a flat index before calling into this view.
     *
     * @param <N> element type
     */
    interface FlattenedView<N> {
        /** Number of elements addressable through this view. */
        long size();

        /** Returns the element at the given flat index. */
        N get(long index);

        /**
         * Returns the element at the given flat index; presumably also tolerates a single-element
         * (scalar) backing tensor for any index — NOTE(review): confirm against implementations.
         */
        N getOrScalar(long index);

        /** Stores {@code value} at the given flat index. */
        void set(long index, N value);
    }

    /**
     * Creates a single-element tensor holding {@code data}, picking the concrete tensor type from
     * the runtime class of the value: {@link DoubleTensor} for {@link Double}, {@link IntegerTensor}
     * for {@link Integer}, {@link BooleanTensor} for {@link Boolean}, otherwise {@link GenericTensor}
     * (including when {@code data} is {@code null}).
     *
     * @param data the value to wrap
     * @param <DATA> element type
     * @param <TENSOR> tensor type expected by the caller; the conversion to it is an unchecked cast
     * @return a tensor wrapping {@code data}
     */
    @SuppressWarnings("unchecked") // concrete type is chosen by instanceof; caller picks TENSOR
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR scalar(DATA data) {
        if (data instanceof Double) {
            return (TENSOR) DoubleTensor.scalar(((Double) data).doubleValue());
        } else if (data instanceof Integer) {
            return (TENSOR) IntegerTensor.scalar(((Integer) data).intValue());
        } else if (data instanceof Boolean) {
            return (TENSOR) BooleanTensor.scalar(((Boolean) data).booleanValue());
        }
        return (TENSOR) GenericTensor.scalar(data);
    }

    /**
     * Creates a tensor of the given shape with every element set to {@code data}. The concrete
     * tensor type is chosen from the runtime class of {@code data}, as in {@link #scalar(Object)}.
     *
     * @param data  fill value
     * @param shape shape of the new tensor
     * @param <DATA> element type
     * @param <TENSOR> tensor type expected by the caller; the conversion to it is an unchecked cast
     * @return the filled tensor
     */
    @SuppressWarnings("unchecked") // concrete type is chosen by instanceof; caller picks TENSOR
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR createFilled(DATA data, long[] shape) {
        if (data instanceof Double) {
            return (TENSOR) DoubleTensor.create(((Double) data).doubleValue(), shape);
        } else if (data instanceof Integer) {
            return (TENSOR) IntegerTensor.create(((Integer) data).intValue(), shape);
        } else if (data instanceof Boolean) {
            return (TENSOR) BooleanTensor.create(((Boolean) data).booleanValue(), shape);
        }
        return (TENSOR) GenericTensor.createFilled(data, shape);
    }

    /**
     * Creates a tensor of the given shape from a flat array of values. Boxed {@code Double[]},
     * {@code Integer[]} and {@code Boolean[]} arrays are unboxed into the matching primitive tensor
     * type (so they must not contain {@code null}); any other array becomes a {@link GenericTensor}.
     *
     * @param data  flat element values
     * @param shape shape of the new tensor
     * @param <DATA> element type
     * @param <TENSOR> tensor type expected by the caller; the conversion to it is an unchecked cast
     * @return the new tensor
     */
    @SuppressWarnings("unchecked") // concrete type is chosen by instanceof; caller picks TENSOR
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR create(DATA[] data, long[] shape) {
        if (data instanceof Double[]) {
            return (TENSOR) DoubleTensor.create(ArrayUtils.toPrimitive((Double[]) data), shape);
        } else if (data instanceof Integer[]) {
            return (TENSOR) IntegerTensor.create(ArrayUtils.toPrimitive((Integer[]) data), shape);
        } else if (data instanceof Boolean[]) {
            return (TENSOR) BooleanTensor.create(ArrayUtils.toPrimitive((Boolean[]) data), shape);
        }
        return (TENSOR) GenericTensor.create((Object[]) data, shape);
    }

    /**
     * Compares two same-shaped tensors element by element using {@link Object#equals}
     * semantics; two {@code null} elements compare equal.
     *
     * @param left  first tensor
     * @param right second tensor
     * @return a boolean tensor with the shape of {@code left}, {@code true} where elements match
     * @throws IllegalArgumentException if the shapes differ
     */
    static BooleanTensor elementwiseEquals(Tensor<?, ?> left, Tensor<?, ?> right) {
        if (!left.hasSameShapeAs(right)) {
            throw new IllegalArgumentException(String.format("Cannot compare tensors of different shapes %s and %s", Arrays.toString(left.getShape()), Arrays.toString(right.getShape())));
        }
        Object[] leftValues = left.asFlatArray();
        Object[] rightValues = right.asFlatArray();
        boolean[] matches = new boolean[leftValues.length];
        for (int i = 0; i < leftValues.length; i++) {
            // Objects.equals avoids an NPE when a (generic) tensor holds null elements.
            matches[i] = Objects.equals(leftValues[i], rightValues[i]);
        }
        long[] shape = left.getShape();
        // Copy so the result does not share a shape array with the input tensor.
        return BooleanTensor.create(matches, Arrays.copyOf(shape, shape.length));
    }

    /** Returns the number of dimensions (0 for a scalar, 1 for a vector, 2 for a matrix). */
    int getRank();

    /** Returns the length of each dimension. */
    long[] getShape();

    /** Returns the stride of each dimension, used with the shape to compute flat indices. */
    long[] getStride();

    /** Returns the total number of elements. */
    long getLength();

    /**
     * Returns the element at the given position. A single index is treated as a flat index;
     * otherwise the indices are converted to a flat index with
     * {@link TensorShape#getFlatIndex(long[], long[], long[])} using this tensor's shape and stride.
     *
     * @param index a flat index, or one index per dimension
     * @return the element at that position
     */
    default N getValue(long... index) {
        long flatIndex = index.length == 1
            ? index[0]
            : TensorShape.getFlatIndex(getShape(), getStride(), index);
        return getFlattenedView().get(flatIndex);
    }

    /**
     * Returns a tensor selected by a boolean mask; presumably the elements where the mask is
     * {@code true} — NOTE(review): confirm against implementations.
     */
    T get(BooleanTensor mask);

    /**
     * Stores {@code value} at the given position. Index interpretation matches
     * {@link #getValue(long...)}: a single index is flat, otherwise one index per dimension.
     *
     * @param value the value to store
     * @param index a flat index, or one index per dimension
     */
    default void setValue(N value, long... index) {
        long flatIndex = index.length == 1
            ? index[0]
            : TensorShape.getFlatIndex(getShape(), getStride(), index);
        getFlattenedView().set(flatIndex, value);
    }

    /**
     * Returns the only element of a single-element tensor.
     *
     * @return the sole element
     * @throws IllegalArgumentException if the tensor does not contain exactly one element
     */
    default N scalar() {
        // Reject empty tensors too: getValue(0) would otherwise fail with an unrelated error.
        if (getLength() != 1) {
            throw new IllegalArgumentException("Not a scalar");
        }
        return getValue(0);
    }

    /** Returns a copy of this tensor. NOTE(review): depth of the copy is implementation-defined. */
    T duplicate();

    /** Returns the slice at {@code index} along {@code dimension}. */
    T slice(int dimension, long index);

    /**
     * Slices using a string specification parsed by {@link Slicer#fromString(String)}.
     *
     * @param slice the slice specification
     * @return the sliced tensor
     */
    default T slice(String slice) {
        return slice(Slicer.fromString(slice));
    }

    /** Slices according to the given {@link Slicer}. */
    T slice(Slicer slicer);

    /** Selects from this tensor using the given indices. NOTE(review): confirm exact semantics. */
    T take(long... index);

    /** Splits along {@code dimension} at the given indices. NOTE(review): confirm exact semantics. */
    List<T> split(int dimension, long... splitAtIndices);

    /**
     * Returns one slice per index in {@code [indexStart, indexEnd)} along {@code dimension},
     * in increasing index order. An empty range yields an empty list.
     *
     * @param dimension  dimension to slice along
     * @param indexStart first index (inclusive)
     * @param indexEnd   last index (exclusive)
     * @return the slices
     */
    default List<T> sliceAlongDimension(int dimension, long indexStart, long indexEnd) {
        List<T> slices = new ArrayList<>();
        for (long index = indexStart; index < indexEnd; index++) {
            slices.add(slice(dimension, index));
        }
        return slices;
    }

    /** Diagonal operation. NOTE(review): confirm whether this extracts or builds a diagonal. */
    T diag();

    /**
     * Swaps the two dimensions of a rank-2 tensor.
     *
     * @return the transposed tensor
     * @throws IllegalArgumentException if the rank is not 2
     */
    default T transpose() {
        Preconditions.checkArgument(getRank() == 2, "Can only transpose rank 2. Use permute(...) for higher rank transpose.");
        return permute(1, 0);
    }

    /** Returns all elements as a flat array. */
    N[] asFlatArray();

    /** Returns a tensor with the same elements and the given shape. */
    T reshape(long... newShape);

    /**
     * Removes every dimension of length 1. Dimensions of length 0 are kept so the element count
     * is preserved.
     *
     * @return the reshaped tensor
     */
    default T squeeze() {
        List<Long> squeezedShape = new ArrayList<>();
        for (long dimLength : getShape()) {
            if (dimLength != 1) {
                squeezedShape.add(dimLength);
            }
        }
        return reshape(Longs.toArray(squeezedShape));
    }

    /**
     * Inserts a dimension of length 1 at position {@code dimension} of the shape.
     *
     * @param dimension insertion position in the shape array
     * @return the reshaped tensor
     */
    default T expandDims(int dimension) {
        return reshape(ArrayUtils.insert(dimension, getShape(), new long[]{1}));
    }

    /**
     * Moves dimension {@code source} to position {@code destination}, keeping the relative order
     * of the remaining dimensions. Both arguments are normalised with
     * {@link TensorShape#getAbsoluteDimension(int, int)}.
     *
     * @param source      dimension to move
     * @param destination position it should end up at
     * @return the permuted tensor
     */
    default T moveAxis(int source, int destination) {
        int[] dimensions = TensorShape.dimensionRange(0, getRank());
        int absoluteSource = TensorShape.getAbsoluteDimension(source, dimensions.length);
        int absoluteDestination = TensorShape.getAbsoluteDimension(destination, dimensions.length);
        // Remove the moved dimension from the identity order, then re-insert it at the destination.
        int[] withoutSource = ArrayUtils.remove(dimensions, absoluteSource);
        return permute(ArrayUtils.insert(absoluteDestination, withoutSource, new int[]{absoluteSource}));
    }

    /**
     * Exchanges two dimensions. Both arguments are normalised with
     * {@link TensorShape#getAbsoluteDimension(int, int)}.
     *
     * @param axis1 first dimension
     * @param axis2 second dimension
     * @return the permuted tensor
     */
    default T swapAxis(int axis1, int axis2) {
        int[] permutation = TensorShape.dimensionRange(0, getRank());
        int first = TensorShape.getAbsoluteDimension(axis1, permutation.length);
        int second = TensorShape.getAbsoluteDimension(axis2, permutation.length);
        int held = permutation[first];
        permutation[first] = permutation[second];
        permutation[second] = held;
        return permute(permutation);
    }

    /** Reorders dimensions according to {@code rearrange}, a permutation of dimension indices. */
    T permute(int... rearrange);

    /** Broadcasts to the given shape. NOTE(review): confirm broadcasting rules in implementations. */
    T broadcast(long... toShape);

    /** Returns a view addressing elements by flat index. */
    FlattenedView<N> getFlattenedView();

    /** Returns all elements as a fixed-size list backed by {@link #asFlatArray()}'s result. */
    default List<N> asFlatList() {
        return Arrays.asList(asFlatArray());
    }

    /** Returns whether the tensor contains exactly one element (of any rank). */
    default boolean isLengthOne() {
        return getLength() == 1;
    }

    /** Returns whether the rank is 0. */
    default boolean isScalar() {
        return getRank() == 0;
    }

    /** Returns whether the rank is 1. */
    default boolean isVector() {
        return getRank() == 1;
    }

    /** Returns whether the rank is 2. */
    default boolean isMatrix() {
        return getRank() == 2;
    }

    /** Returns whether {@code tensor} has exactly the same shape as this tensor. */
    default boolean hasSameShapeAs(Tensor tensor) {
        return hasSameShapeAs(tensor.getShape());
    }

    /** Returns whether this tensor's shape equals {@code shape} element by element. */
    default boolean hasSameShapeAs(long[] shape) {
        return Arrays.equals(getShape(), shape);
    }

    /**
     * Compares this tensor with {@code tensor} element by element; see
     * {@link #elementwiseEquals(Tensor, Tensor)}.
     */
    default BooleanTensor elementwiseEquals(Tensor tensor) {
        return elementwiseEquals(this, tensor);
    }

    /** Compares every element with {@code value}. NOTE(review): equality semantics per implementation. */
    BooleanTensor elementwiseEquals(N value);
}
