package io.improbable.keanu.tensor.ndj4;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.bool.JVMBooleanTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.intgr.Nd4jIntegerTensor;
import java.lang.Number;
import java.util.function.Function;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Base class for numeric tensors backed by an ND4J {@link INDArray}.
 *
 * <p>Operations follow the {@link NumberTensor} naming convention:
 * <ul>
 *   <li>a {@code fooInPlace} variant replaces this tensor's backing array (via {@code set}) and
 *       returns this tensor;</li>
 *   <li>a plain {@code foo} variant leaves this tensor untouched and returns a new one (via
 *       {@code create} or {@code duplicate}).</li>
 * </ul>
 *
 * <p>Exception for the binary in-place operators: when this tensor is a scalar and the argument is
 * not, or when the ND4J shim hands back a different array than the receiver's, a new tensor is
 * returned instead of mutating this one. The second case is presumably a broadcast that changed
 * the shape.
 *
 * @param <T>      boxed element type
 * @param <TENSOR> concrete tensor type produced by operations
 */
public abstract class Nd4jNumberTensor<T extends Number, TENSOR extends NumberTensor<T, TENSOR>> extends Nd4jTensor<T, TENSOR> implements NumberTensor<T, TENSOR> {

    /**
     * @param tensor the ND4J array backing this tensor, handed straight to {@link Nd4jTensor}
     */
    public Nd4jNumberTensor(INDArray tensor) {
        super(tensor);
    }

    /**
     * Sets every element whose mask entry is non-zero to {@code value}.
     *
     * <p>The mask is never modified. The work is done on a DOUBLE copy of it. Note that this
     * tensor's backing array is replaced by a DOUBLE-typed array when it was not DOUBLE already.
     *
     * @param mask  tensor of the same length; non-zero entries select elements to overwrite
     * @param value the value written into selected elements
     * @return this tensor
     * @throws IllegalArgumentException if the lengths of this tensor and {@code mask} differ
     */
    @Override
    public TENSOR setWithMaskInPlace(TENSOR mask, T value) {
        if (getLength() != mask.getLength()) {
            throw new IllegalArgumentException("The lengths of the tensor and mask must match, but got tensor length: " + getLength() + ", mask length: " + mask.getLength());
        }
        INDArray target = this.tensor.dataType() == DataType.DOUBLE ? this.tensor : this.tensor.castTo(DataType.DOUBLE);

        // The CompareAndSet ops below write into the mask array, so always work on a private copy;
        // otherwise a DOUBLE mask passed by the caller would be silently corrupted.
        INDArray maskArray = getTensor(mask);
        INDArray workingMask = maskArray.dataType() == DataType.DOUBLE ? maskArray.dup() : maskArray.castTo(DataType.DOUBLE);

        double newValue = value.doubleValue();
        double selectedMarker = 1.0d;
        if (newValue == 0.0d) {
            // Writing zero: flip the mask (1 - mask) so selected cells are marked 0, which the final
            // CompareAndSet then copies into the target.
            selectedMarker = 0.0d;
            workingMask.negi().addi(1.0d);
        }
        // Replace the "selected" marker in the mask with the value to write...
        Nd4j.getExecutioner().exec(new CompareAndSet(workingMask, newValue, Conditions.equals(selectedMarker)));
        // ...then copy mask entries into the target wherever the mask is not the "unselected" marker.
        Nd4j.getExecutioner().exec(new CompareAndSet(target, workingMask, Conditions.notEquals(1.0d - selectedMarker)));
        return (TENSOR) set(target);
    }

    /**
     * Sums over the given dimensions.
     *
     * @param overDimensions dimensions to reduce; when empty, an unchanged copy is returned
     * @return a new tensor
     */
    @Override
    public TENSOR sum(int... overDimensions) {
        if (overDimensions.length == 0) {
            return (TENSOR) duplicate();
        }
        return (TENSOR) create(this.tensor.sum(overDimensions));
    }

    /**
     * @return the sum of all elements, converted via {@link #getNumber(Number)}
     */
    @Override
    public T sum() {
        return getNumber(this.tensor.sumNumber());
    }

    /**
     * Replaces this tensor with its cumulative sum along one dimension.
     *
     * @param requestedDimension dimension to accumulate along; negative values count from the end
     * @return this tensor
     */
    @Override
    public TENSOR cumSumInPlace(int requestedDimension) {
        int dimension = toNonNegativeDimension(requestedDimension);
        TensorShapeValidation.checkDimensionExistsInShape(dimension, this.tensor.shape());
        return (TENSOR) set(this.tensor.cumsumi(dimension));
    }

    /**
     * @return the product of all elements, converted via {@link #getNumber(Number)}
     */
    public T product() {
        return getNumber(this.tensor.prodNumber());
    }

    /**
     * Multiplies elements together over the given dimensions.
     *
     * @param overDimensions dimensions to reduce
     * @return a new tensor
     */
    @Override
    public TENSOR product(int... overDimensions) {
        return (TENSOR) create(this.tensor.prod(overDimensions));
    }

    /**
     * Replaces this tensor with its cumulative product along one dimension.
     *
     * @param requestedDimension dimension to accumulate along; negative values count from the end
     * @return this tensor
     */
    @Override
    public TENSOR cumProdInPlace(int requestedDimension) {
        int dimension = toNonNegativeDimension(requestedDimension);
        TensorShapeValidation.checkDimensionExistsInShape(dimension, this.tensor.shape());
        return (TENSOR) set(INDArrayExtensions.cumProd(this.tensor, dimension));
    }

    /**
     * Maps a possibly-negative dimension index onto {@code [0, rank)} by adding the rank to
     * negative values. Range checking is left to the caller.
     */
    private int toNonNegativeDimension(int dimension) {
        return dimension >= 0 ? dimension : dimension + this.tensor.rank();
    }

    /**
     * Clamps every element into {@code [lowerBound, upperBound]}. The upper bound is applied
     * first, then the lower bound.
     *
     * @param lowerBound tensor (or scalar) lower bound
     * @param upperBound tensor (or scalar) upper bound
     * @return the result of the in-place min/max chain
     */
    public TENSOR clampInPlace(TENSOR lowerBound, TENSOR upperBound) {
        return (TENSOR) minInPlace(upperBound).maxInPlace(lowerBound);
    }

    /**
     * @return the largest element, converted via {@link #getNumber(Number)}
     */
    @Override
    public T max() {
        return getNumber(this.tensor.maxNumber());
    }

    /**
     * @return the smallest element, converted via {@link #getNumber(Number)}
     */
    @Override
    public T min() {
        return getNumber(this.tensor.minNumber());
    }

    /**
     * Element-wise maximum with {@code that}. A scalar argument is applied with ND4J's
     * {@code ScalarMax} op; otherwise the shim's array result replaces the backing array.
     *
     * @return this tensor
     */
    @Override
    public TENSOR maxInPlace(TENSOR that) {
        if (that.isScalar()) {
            Nd4j.getExecutioner().exec(new ScalarMax(this.tensor, that.scalar()));
        } else {
            this.tensor = INDArrayShim.max(this.tensor, getTensor(that));
        }
        return (TENSOR) set(this.tensor);
    }

    /**
     * Element-wise minimum with {@code that}. A scalar argument is applied with ND4J's
     * {@code ScalarMin} op; otherwise the shim's array result replaces the backing array.
     *
     * @return this tensor
     */
    @Override
    public TENSOR minInPlace(TENSOR that) {
        if (that.isScalar()) {
            Nd4j.getExecutioner().exec(new ScalarMin(this.tensor, that.scalar()));
        } else {
            this.tensor = INDArrayShim.min(this.tensor, getTensor(that));
        }
        return (TENSOR) set(this.tensor);
    }

    /**
     * @return index of the largest element over the whole tensor, as reported by ND4J's argMax
     */
    public int argMax() {
        return this.tensor.argMax(new int[0]).getInt(new int[]{0});
    }

    /**
     * Index of the largest element along {@code axis}.
     *
     * @param axis dimension to search along; must exist in this tensor's shape
     * @return integer tensor whose shape is this shape with {@code axis} removed
     */
    public IntegerTensor argMax(int axis) {
        long[] shape = getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        return new Nd4jIntegerTensor(this.tensor.argMax(new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape)));
    }

    /**
     * @return index of the smallest element over the whole tensor, as reported by ND4J's argMin
     */
    public int argMin() {
        return Nd4j.argMin(this.tensor, new int[0]).getInt(new int[]{0});
    }

    /**
     * Index of the smallest element along {@code axis}.
     *
     * @param axis dimension to search along; must exist in this tensor's shape
     * @return integer tensor whose shape is this shape with {@code axis} removed
     */
    public IntegerTensor argMin(int axis) {
        long[] shape = getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        return new Nd4jIntegerTensor(Nd4j.argMin(this.tensor, new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape)));
    }

    /**
     * Returns {@code this - that} as a new tensor.
     *
     * <p>A scalar argument uses the scalar overload. A scalar receiver delegates to
     * {@code that.reverseMinus(s)} so the result takes the argument's shape.
     */
    @Override
    public TENSOR minus(TENSOR that) {
        if (that.isScalar()) {
            return (TENSOR) minus(that.scalar());
        }
        if (isScalar()) {
            return (TENSOR) that.reverseMinus(scalar());
        }
        return (TENSOR) duplicate().minusInPlace(that);
    }

    /**
     * Returns {@code this + that} as a new tensor. Addition commutes, so a scalar receiver
     * delegates to {@code that.plus(s)}.
     */
    @Override
    public TENSOR plus(TENSOR that) {
        if (that.isScalar()) {
            return (TENSOR) plus(that.scalar());
        }
        if (isScalar()) {
            return (TENSOR) that.plus(scalar());
        }
        return (TENSOR) duplicate().plusInPlace(that);
    }

    /**
     * Returns {@code this * that} (element-wise) as a new tensor. Multiplication commutes, so a
     * scalar receiver delegates to {@code that.times(s)}.
     */
    @Override
    public TENSOR times(TENSOR that) {
        if (that.isScalar()) {
            return (TENSOR) times(that.scalar());
        }
        if (isScalar()) {
            return (TENSOR) that.times(scalar());
        }
        return (TENSOR) duplicate().timesInPlace(that);
    }

    /**
     * Returns {@code this / that} (element-wise) as a new tensor. A scalar receiver delegates to
     * {@code that.reverseDiv(s)}, i.e. {@code s / that}.
     */
    @Override
    public TENSOR div(TENSOR that) {
        if (that.isScalar()) {
            return (TENSOR) div(that.scalar());
        }
        if (isScalar()) {
            return (TENSOR) that.reverseDiv(scalar());
        }
        return (TENSOR) duplicate().divInPlace(that);
    }

    /** Subtracts {@code value} from every element. @return this tensor */
    @Override
    public TENSOR minusInPlace(T value) {
        return (TENSOR) set(this.tensor.subi(value));
    }

    /** Adds {@code value} to every element. @return this tensor */
    @Override
    public TENSOR plusInPlace(T value) {
        return (TENSOR) set(this.tensor.addi(value));
    }

    /** Multiplies every element by {@code value}. @return this tensor */
    @Override
    public TENSOR timesInPlace(T value) {
        return (TENSOR) set(this.tensor.muli(value));
    }

    /** Divides every element by {@code value}. @return this tensor */
    @Override
    public TENSOR divInPlace(T value) {
        return (TENSOR) set(this.tensor.divi(value));
    }

    /**
     * Raises each element to the matching element of {@code exponent} (or to its single value
     * when it is a scalar).
     *
     * @return this tensor
     */
    @Override
    public TENSOR powInPlace(TENSOR exponent) {
        if (exponent.isScalar()) {
            // dup = false: Transforms.pow writes into this.tensor.
            Transforms.pow(this.tensor, exponent.scalar(), false);
        } else {
            this.tensor = INDArrayShim.pow(this.tensor, getTensor(exponent));
        }
        return (TENSOR) set(this.tensor);
    }

    /** Raises every element to {@code exponent}. @return this tensor */
    @Override
    public TENSOR powInPlace(T exponent) {
        // dup = false: Transforms.pow writes into this.tensor.
        Transforms.pow(this.tensor, exponent, false);
        return (TENSOR) set(this.tensor);
    }

    /**
     * Computes {@code this - that}.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR minusInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.subi(that.scalar());
        } else {
            if (isScalar()) {
                return (TENSOR) minus(that);
            }
            INDArray result = INDArrayShim.subi(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /**
     * Computes {@code that - this}.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR reverseMinusInPlace(TENSOR that) {
        if (that.isScalar()) {
            // s - this (previously subi, which computed this - s)
            this.tensor.rsubi(that.scalar());
        } else {
            if (isScalar()) {
                // that - s (previously this.minus(that), which computed s - that)
                return (TENSOR) that.minus(scalar());
            }
            INDArray result = INDArrayShim.rsubi(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /** Replaces every element {@code x} with {@code value - x}. @return this tensor */
    @Override
    public TENSOR reverseMinusInPlace(T value) {
        return (TENSOR) set(this.tensor.rsubi(value));
    }

    /**
     * Computes {@code this + that}.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR plusInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.addi(that.scalar());
        } else {
            if (isScalar()) {
                return (TENSOR) plus(that);
            }
            INDArray result = INDArrayShim.addi(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /**
     * Computes {@code this * that} element-wise.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR timesInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.muli(that.scalar());
        } else {
            if (isScalar()) {
                return (TENSOR) times(that);
            }
            INDArray result = INDArrayShim.muli(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /**
     * Computes {@code this / that} element-wise.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR divInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.divi(that.scalar());
        } else {
            if (isScalar()) {
                return (TENSOR) div(that);
            }
            INDArray result = INDArrayShim.divi(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /** Replaces every element {@code x} with {@code value / x}. @return this tensor */
    @Override
    public TENSOR reverseDivInPlace(T value) {
        return (TENSOR) set(this.tensor.rdivi(value));
    }

    /**
     * Computes {@code that / this} element-wise.
     *
     * @return this tensor, or a new tensor when this is a scalar or when the shim returns an array
     *     other than this tensor's own
     */
    @Override
    public TENSOR reverseDivInPlace(TENSOR that) {
        if (that.isScalar()) {
            // s / this (previously subi, which computed this - s)
            this.tensor.rdivi(that.scalar());
        } else {
            if (isScalar()) {
                // that / s (previously this.minus(that))
                return (TENSOR) that.div(scalar());
            }
            INDArray result = INDArrayShim.rdivi(this.tensor, getTensor(that).dup());
            if (result != this.tensor) {
                return (TENSOR) create(result);
            }
        }
        return (TENSOR) set(this.tensor);
    }

    /** Negates every element. @return this tensor */
    @Override
    public TENSOR unaryMinusInPlace() {
        return (TENSOR) set(this.tensor.negi());
    }

    /** Replaces every element with its absolute value. @return this tensor */
    @Override
    public TENSOR absInPlace() {
        // dup = false: Transforms.abs writes into this.tensor.
        Transforms.abs(this.tensor, false);
        return (TENSOR) set(this.tensor);
    }

    /** @return the mean of all elements, converted via {@link #getNumber(Number)} */
    @Override
    public T average() {
        return getNumber(this.tensor.meanNumber());
    }

    /** @return ND4J's standard deviation of all elements, converted via {@link #getNumber(Number)} */
    @Override
    public T standardDeviation() {
        return getNumber(this.tensor.stdNumber());
    }

    /** @return element-wise {@code this > value} */
    @Override
    public BooleanTensor greaterThan(T value) {
        return fromMask(this.tensor.gt(value));
    }

    /** @return element-wise {@code this >= value} */
    @Override
    public BooleanTensor greaterThanOrEqual(T value) {
        return fromMask(this.tensor.gte(value));
    }

    /** @return element-wise {@code this > that} */
    @Override
    public BooleanTensor greaterThan(TENSOR that) {
        INDArray mask = that.isScalar()
            ? this.tensor.gt(that.scalar())
            : INDArrayShim.gt(this.tensor, getTensor(that));
        return fromMask(mask);
    }

    /** @return element-wise {@code this >= that} */
    @Override
    public BooleanTensor greaterThanOrEqual(TENSOR that) {
        // NOTE(review): unlike greaterThan, this passes a dup() of the receiver to the shim, which
        // suggests INDArrayShim.gte may write into its first argument — confirm.
        INDArray mask = that.isScalar()
            ? this.tensor.gte(that.scalar())
            : INDArrayShim.gte(this.tensor.dup(), getTensor(that));
        return fromMask(mask);
    }

    /** @return element-wise {@code this < value} */
    @Override
    public BooleanTensor lessThan(T value) {
        return fromMask(this.tensor.lt(value));
    }

    /** @return element-wise {@code this <= value} */
    @Override
    public BooleanTensor lessThanOrEqual(T value) {
        return fromMask(this.tensor.lte(value));
    }

    /** @return element-wise {@code this < that} */
    @Override
    public BooleanTensor lessThan(TENSOR that) {
        INDArray mask = that.isScalar()
            ? this.tensor.lt(that.scalar())
            : INDArrayShim.lt(this.tensor, getTensor(that));
        return fromMask(mask);
    }

    /** @return element-wise {@code this <= that} */
    @Override
    public BooleanTensor lessThanOrEqual(TENSOR that) {
        // NOTE(review): the receiver is dup()'d here but not in lessThan — see greaterThanOrEqual.
        INDArray mask = that.isScalar()
            ? this.tensor.lte(that.scalar())
            : INDArrayShim.lte(this.tensor.dup(), getTensor(that));
        return fromMask(mask);
    }

    /**
     * Applies {@code function} to every element, writing each result back through this tensor's
     * flattened view.
     *
     * @return this tensor
     */
    @Override
    public TENSOR applyInPlace(Function<T, T> function) {
        Tensor.FlattenedView<T> flattened = getFlattenedView();
        for (int i = 0; i < flattened.size(); i++) {
            flattened.set(i, function.apply(flattened.get(i)));
        }
        return (TENSOR) getThis();
    }

    /**
     * Element-wise equality against any tensor.
     *
     * <p>Non-numeric tensors fall back to {@link Tensor#elementwiseEquals(Tensor, Tensor)}. A scalar
     * on either side is compared against every element of the other side. Otherwise the two
     * backing arrays are compared by the ND4J shim.
     */
    @Override
    public BooleanTensor elementwiseEquals(Tensor that) {
        if (!(that instanceof NumberTensor)) {
            return Tensor.elementwiseEquals(this, that);
        }
        if (isScalar()) {
            return that.elementwiseEquals(scalar());
        }
        if (that.isScalar()) {
            return elementwiseEquals((T) that.scalar());
        }
        return fromMask(INDArrayShim.eq(this.tensor, getTensor((TENSOR) that)));
    }

    /** @return element-wise {@code this == value} */
    @Override
    public BooleanTensor elementwiseEquals(T value) {
        return fromMask(this.tensor.eq(value));
    }

    /**
     * Matrix product {@code this · that}.
     *
     * @return a new tensor; this tensor is left unchanged (previously its backing array was
     *     overwritten with the product)
     */
    public TENSOR matrixMultiply(TENSOR that) {
        return (TENSOR) create(this.tensor.mmul(getTensor(that)));
    }

    /**
     * Tensor contraction of this and {@code that} over the paired dimensions, via
     * {@code Nd4j.tensorMmul}.
     *
     * @param dimsLeft  dimensions of this tensor to contract
     * @param dimsRight matching dimensions of {@code that} to contract
     * @return a new tensor; this tensor is left unchanged
     */
    public TENSOR tensorMultiply(TENSOR that, int[] dimsLeft, int[] dimsRight) {
        return (TENSOR) create(Nd4j.tensorMmul(this.tensor, getTensor(that), new int[][]{dimsLeft, dimsRight}));
    }

    /**
     * Converts an ND4J 0/1 mask array into a {@link BooleanTensor} of the same shape. An element is
     * {@code true} when its integer value is non-zero.
     */
    protected final BooleanTensor fromMask(INDArray mask) {
        long[] shape = mask.shape();
        DataBuffer buffer = mask.data();
        // NOTE(review): reads the raw DataBuffer in storage order, which assumes the mask is a
        // contiguous, non-view array — confirm for every caller.
        boolean[] values = new boolean[Ints.checkedCast(mask.length())];
        for (int i = 0; i < values.length; i++) {
            values[i] = buffer.getNumber((long) i).intValue() != 0;
        }
        return JVMBooleanTensor.create(values, shape);
    }

    /**
     * Converts a raw ND4J result, such as those of {@code sumNumber()} and {@code prodNumber()} used
     * above, into this tensor's element type.
     */
    // NOTE(review): the decompiler reported this member was originally protected; it is left public
    // so existing subclass overrides stay valid — confirm before narrowing.
    public abstract T getNumber(Number number);
}
