package io.improbable.keanu.tensor.generic;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.generic.GenericBuffer;
import io.improbable.keanu.tensor.jvm.JVMTensor;
import io.improbable.keanu.tensor.jvm.JVMTensorBroadcast;
import io.improbable.keanu.tensor.jvm.ResultWrapper;
import java.util.Arrays;
import java.util.function.BiFunction;

/**
 * A tensor whose elements are arbitrary objects of type {@code T}, stored in a
 * {@link GenericBuffer.PrimitiveGenericWrapper}.
 *
 * <p>Instances are obtained through the static factories {@link #createFilled(Object, long[])},
 * {@link #create(Object[])}, {@link #create(Object[], long[])} and {@link #scalar(Object)}.
 *
 * @param <T> element type held by the tensor
 */
public class GenericTensor<T> extends JVMTensor<T, GenericTensor<T>, GenericBuffer.PrimitiveGenericWrapper<T>> implements Tensor<T, GenericTensor<T>> {
    // Shared by every instance regardless of element type, hence used as a raw type.
    private static final GenericBuffer.GenericArrayWrapperFactory factory = new GenericBuffer.GenericArrayWrapperFactory();

    /**
     * Flattened view over a plain array. Reads and writes go straight to the array that was
     * handed to the constructor (no copy is taken here).
     */
    private static class BaseSimpleFlattenedView<T> implements Tensor.FlattenedView<T> {
        private final T[] data;

        public BaseSimpleFlattenedView(T[] data) {
            this.data = data;
        }

        @Override
        public long size() {
            return this.data.length;
        }

        /**
         * @throws IllegalArgumentException if {@code index} cannot be represented as an int
         * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the backing array
         */
        @Override
        public T get(long index) {
            return this.data[toArrayIndex(index)];
        }

        /** Returns the only element when the view holds exactly one, otherwise {@code get(index)}. */
        @Override
        public T getOrScalar(long index) {
            return this.data.length == 1 ? get(0L) : get(index);
        }

        /**
         * @throws IllegalArgumentException if {@code index} cannot be represented as an int
         * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the backing array
         */
        @Override
        public void set(long index, T value) {
            this.data[toArrayIndex(index)] = value;
        }

        /**
         * Narrows a long index to an int without silent truncation.
         *
         * <p>A plain {@code (int)} cast keeps only the low 32 bits, so a value below
         * {@code Integer.MIN_VALUE} could wrap onto a valid array position. Both directions are
         * rejected here; negative values that do fit in an int are left for the array access
         * itself to reject.
         */
        private static int toArrayIndex(long index) {
            if (index > Integer.MAX_VALUE || index < Integer.MIN_VALUE) {
                throw new IllegalArgumentException("Only integer based indexing supported for generic tensors");
            }
            return (int) index;
        }
    }

    /**
     * Creates a tensor of the given shape in which every position holds {@code value}
     * (the same reference is stored at every position).
     *
     * @param value element stored at every position
     * @param shape shape of the new tensor
     * @return the filled tensor
     */
    public static <T> GenericTensor<T> createFilled(T value, long[] shape) {
        return new GenericTensor<>(fillArray(shape, value), shape);
    }

    /**
     * Creates a rank-1 tensor of length {@code values.length} from the given values.
     *
     * @param values the elements, in order
     * @return a tensor with shape {@code [values.length]}
     */
    public static <T> GenericTensor<T> create(T... values) {
        // Two arguments never match the varargs overload before the (T[], long[]) overload,
        // so this call cannot recurse and no erasing cast is required.
        return create(values, new long[]{values.length});
    }

    /**
     * Creates a tensor from a flat array and a shape.
     *
     * @param data  flat element array
     * @param shape shape of the new tensor
     * @return the new tensor
     * @throws IllegalArgumentException if the length implied by {@code shape} differs from
     *                                  {@code data.length}
     */
    public static <T> GenericTensor<T> create(T[] data, long[] shape) {
        if (TensorShape.getLength(shape) != data.length) {
            throw new IllegalArgumentException("Shape size does not match data length");
        }
        return new GenericTensor<>(data, shape);
    }

    /**
     * Creates a tensor with empty shape and stride holding the single given value.
     *
     * @param value the value to wrap
     * @return the scalar tensor
     */
    public static <T> GenericTensor<T> scalar(T value) {
        return new GenericTensor<>(value);
    }

    private GenericTensor(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        super(buffer, shape, stride);
    }

    /** Uses the stride computed by {@code TensorShape.getRowFirstStride(shape)}. */
    private GenericTensor(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape) {
        super(buffer, shape, TensorShape.getRowFirstStride(shape));
    }

    /** Wraps {@code data} through the shared factory. */
    @SuppressWarnings("unchecked") // the shared factory is a raw type
    private GenericTensor(T[] data, long[] shape, long[] stride) {
        super(factory.create(data), shape, stride);
    }

    /** Uses the stride computed by {@code TensorShape.getRowFirstStride(shape)}. */
    private GenericTensor(T[] data, long[] shape) {
        this(data, shape, TensorShape.getRowFirstStride(shape));
    }

    @SuppressWarnings({"unchecked", "rawtypes"}) // wrapper is constructed as a raw type
    private GenericTensor(T value) {
        super(new GenericBuffer.GenericWrapper(value), new long[0], new long[0]);
    }

    /** Builds an array sized by {@code TensorShape.getLengthAsInt(shape)} with every slot set to {@code value}. */
    @SuppressWarnings("unchecked") // an Object[] can hold any T; runtime component type is Object
    private static <T> T[] fillArray(long[] shape, T value) {
        Object[] filled = new Object[TensorShape.getLengthAsInt(shape)];
        Arrays.fill(filled, value);
        return (T[]) filled;
    }

    /**
     * Returns a new tensor built from {@code buffer.copy()} and copies of this tensor's shape and
     * stride arrays.
     */
    @Override
    public GenericTensor<T> duplicate() {
        return new GenericTensor<>(
            this.buffer.copy(),
            Arrays.copyOf(this.shape, this.shape.length),
            Arrays.copyOf(this.stride, this.stride.length)
        );
    }

    /** Returns a flattened view over the array returned by {@code buffer.asArray()}. */
    @Override
    public Tensor.FlattenedView<T> getFlattenedView() {
        return new BaseSimpleFlattenedView<>(this.buffer.asArray());
    }

    /**
     * Compares every buffer element with {@code value}.
     *
     * <p>The comparison is null-safe: a null element matches only a null {@code value}.
     *
     * @param value the value to compare each element with
     * @return a boolean tensor with this tensor's shape holding the per-element results
     */
    @Override
    public BooleanTensor elementwiseEquals(T value) {
        final int length = Ints.checkedCast(this.buffer.getLength());
        final boolean[] matches = new boolean[length];
        for (int i = 0; i < length; i++) {
            final T element = this.buffer.get(i);
            matches[i] = element == null ? value == null : element.equals(value);
        }
        return BooleanTensor.create(matches, this.shape);
    }

    /** Returns {@code asArray()} of a copy of the buffer. */
    @Override
    public T[] asFlatArray() {
        return this.buffer.copy().asArray();
    }

    @Override
    public GenericTensor<T> create(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        return new GenericTensor<>(buffer, shape, stride);
    }

    /** Replaces this tensor's buffer, shape and stride in place and returns this tensor. */
    @Override
    public GenericTensor<T> set(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        this.buffer = buffer;
        this.shape = shape;
        this.stride = stride;
        return this;
    }

    @Override
    @SuppressWarnings("unchecked") // the shared factory is a raw type
    protected JVMBuffer.ArrayWrapperFactory<T, GenericBuffer.PrimitiveGenericWrapper<T>> getFactory() {
        return factory;
    }

    /** Returns a scalar tensor wrapping {@code getValue(index)}. */
    @Override
    public GenericTensor<T> take(long... index) {
        return scalar(getValue(index));
    }

    /**
     * Returns {@code tensor} itself when it already is a {@code GenericTensor}; otherwise builds
     * one from its flat array, shape and stride.
     */
    private static <T> GenericTensor<T> getRawBufferIfJVMTensor(Tensor<T, ?> tensor) {
        if (tensor instanceof GenericTensor) {
            return (GenericTensor<T>) tensor;
        }
        // The array constructor wraps the data with the shared factory.
        return new GenericTensor<>(tensor.asFlatArray(), tensor.getShape(), tensor.getStride());
    }

    /**
     * Combines this tensor with {@code tensor} using {@code biFunction}, delegating to
     * {@code JVMTensorBroadcast.broadcastIfNeeded}, and wraps the result in a new tensor.
     *
     * @param tensor     the right-hand operand
     * @param biFunction function applied to the operands' elements (this tensor's element first)
     * @param <R>        element type of the result
     * @return a new tensor built from the output buffer, shape and stride of the operation
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw factory makes the broadcast result raw
    public <R> GenericTensor<R> apply(Tensor<T, ?> tensor, BiFunction<T, T, R> biFunction) {
        final GenericTensor<T> right = getRawBufferIfJVMTensor(tensor);
        final ResultWrapper result = JVMTensorBroadcast.broadcastIfNeeded(
            factory,
            this.buffer, this.shape, this.stride, this.buffer.getLength(),
            right.buffer, right.shape, right.stride, right.buffer.getLength(),
            biFunction,
            false
        );
        return new GenericTensor<>(
            (GenericBuffer.PrimitiveGenericWrapper<R>) result.outputBuffer,
            result.outputShape,
            result.outputStride
        );
    }
}
