package io.improbable.keanu.tensor.ndj4;

import io.improbable.keanu.tensor.FloatingPointTensor;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.TensorMulByMatrixMul;
import java.lang.Number;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LogX;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Shared ND4J-backed implementation of the {@link FloatingPointTensor} operations.
 *
 * <p>The wrapped {@link INDArray} lives in the inherited {@code tensor} field. Methods suffixed
 * {@code InPlace} overwrite that array (the {@code atan2InPlace} overloads instead replace the field
 * with a newly computed array) and return {@code getThis()} so calls can be chained. The matrix
 * operations {@link #matrixInverse()}, {@link #choleskyDecomposition()} and
 * {@link #matrixMultiply} wrap a fresh ND4J result via {@code create(...)}.
 *
 * @param <T> boxed element type
 * @param <TENSOR> concrete self type returned by the operations
 */
public abstract class Nd4jFloatingPointTensor<T extends Number, TENSOR extends FloatingPointTensor<T, TENSOR>>
        extends Nd4jNumberTensor<T, TENSOR> implements FloatingPointTensor<T, TENSOR> {

    /**
     * Creates a floating-point tensor backed by the given ND4J array.
     *
     * @param tensor backing array, passed unchanged to the superclass constructor
     */
    public Nd4jFloatingPointTensor(INDArray tensor) {
        super(tensor);
    }

    /**
     * Returns the matrix inverse computed by ND4J's {@link InvertMatrix}.
     *
     * <p>{@code inPlace = false} is passed, so this tensor's array is not overwritten.
     */
    @Override
    public TENSOR matrixInverse() {
        return (TENSOR) create(InvertMatrix.invert(this.tensor, false));
    }

    /**
     * Returns the Cholesky factor of this matrix.
     *
     * <p>LAPACK {@code potrf} works in place, so it is run on a copy to leave this tensor unchanged.
     * {@code lower = true} requests the lower-triangular factor.
     */
    @Override
    public TENSOR choleskyDecomposition() {
        INDArray factor = this.tensor.dup();
        Nd4j.getBlasWrapper().lapack().potrf(factor, true);
        return (TENSOR) create(factor);
    }

    /**
     * Returns the matrix product {@code this x that} as a new tensor.
     *
     * @param that right-hand operand
     */
    @Override
    public TENSOR matrixMultiply(TENSOR that) {
        // Called only for its shape validation; the computed result shape is discarded.
        TensorShapeValidation.getMatrixMultiplicationResultingShape(this.tensor.shape(), that.getShape());
        return (TENSOR) create(this.tensor.mmul(getTensor(that)));
    }

    /**
     * Tensor-contracts this tensor with {@code that} over the given dimensions.
     *
     * @param that      right-hand operand
     * @param dimsLeft  dimensions of this tensor to contract
     * @param dimsRight dimensions of {@code that} to contract
     */
    @Override
    public TENSOR tensorMultiply(TENSOR that, int[] dimsLeft, int[] dimsRight) {
        return (TENSOR) TensorMulByMatrixMul.tensorMmul((FloatingPointTensor) getThis(), that, dimsLeft, dimsRight);
    }

    /** Replaces each element with its square root, in place. */
    @Override
    public TENSOR sqrtInPlace() {
        Transforms.sqrt(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its natural logarithm, in place. */
    @Override
    public TENSOR logInPlace() {
        Transforms.log(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element {@code x} with {@code 1 / x}, in place (reverse division by 1). */
    @Override
    public TENSOR reciprocalInPlace() {
        this.tensor.rdivi(Double.valueOf(1.0d));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its sine, in place. */
    @Override
    public TENSOR sinInPlace() {
        Transforms.sin(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its cosine, in place. */
    @Override
    public TENSOR cosInPlace() {
        Transforms.cos(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its tangent, in place (input and output arrays are the same). */
    @Override
    public TENSOR tanInPlace() {
        Nd4j.getExecutioner().exec(new Tan(this.tensor, this.tensor));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its arc tangent, in place. */
    @Override
    public TENSOR atanInPlace() {
        Transforms.atan(this.tensor, false);
        return (TENSOR) getThis();
    }

    /**
     * Computes ND4J {@code atan2} of this tensor against the scalar {@code y}.
     *
     * <p>Unlike most in-place methods, the {@code tensor} field is replaced by a newly allocated
     * result; the previous array object is not written to.
     *
     * @param y scalar broadcast to this tensor's shape
     */
    @Override
    public TENSOR atan2InPlace(T y) {
        INDArray broadcastY = Nd4j.scalar(y).broadcast(this.tensor.shape());
        this.tensor = Transforms.atan2(this.tensor, broadcastY);
        return (TENSOR) getThis();
    }

    /**
     * Computes ND4J {@code atan2} of this tensor against {@code y}.
     *
     * <p>A scalar {@code y} is broadcast to this tensor's shape; otherwise the work is delegated to
     * {@code INDArrayShim.atan2}. In both cases the {@code tensor} field is replaced by the result.
     *
     * @param y second operand
     */
    @Override
    public TENSOR atan2InPlace(TENSOR y) {
        if (y.isScalar()) {
            this.tensor = Transforms.atan2(this.tensor, getTensor(y).broadcast(this.tensor.shape()));
        } else {
            this.tensor = INDArrayShim.atan2(this.tensor, getTensor(y));
        }
        return (TENSOR) getThis();
    }

    /** Replaces each element with its arc sine, in place. */
    @Override
    public TENSOR asinInPlace() {
        Transforms.asin(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its arc cosine, in place. */
    @Override
    public TENSOR acosInPlace() {
        Transforms.acos(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its hyperbolic sine, in place. */
    @Override
    public TENSOR sinhInPlace() {
        Transforms.sinh(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its hyperbolic cosine, in place. */
    @Override
    public TENSOR coshInPlace() {
        Transforms.cosh(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its hyperbolic tangent, in place. */
    @Override
    public TENSOR tanhInPlace() {
        Transforms.tanh(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element with its inverse hyperbolic sine, in place. */
    @Override
    public TENSOR asinhInPlace() {
        // NOTE(review): relies on the single-argument ASinh op writing its result into its input array.
        Nd4j.getExecutioner().execAndReturn(new ASinh(this.tensor));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its inverse hyperbolic cosine, in place. */
    @Override
    public TENSOR acoshInPlace() {
        // NOTE(review): relies on the single-argument ACosh op writing its result into its input array.
        Nd4j.getExecutioner().execAndReturn(new ACosh(this.tensor));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its inverse hyperbolic tangent, in place. */
    @Override
    public TENSOR atanhInPlace() {
        Transforms.atanh(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element {@code x} with {@code e^x}, in place. */
    @Override
    public TENSOR expInPlace() {
        Transforms.exp(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Replaces each element {@code x} with {@code log(1 + x)}, in place. */
    @Override
    public TENSOR log1pInPlace() {
        // NOTE(review): relies on the single-argument Log1p op writing its result into its input array.
        Nd4j.getExecutioner().exec(new Log1p(this.tensor));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its base-2 logarithm, in place. */
    @Override
    public TENSOR log2InPlace() {
        Nd4j.getExecutioner().exec(new LogX(this.tensor, 2.0d));
        return (TENSOR) getThis();
    }

    /** Replaces each element with its base-10 logarithm, in place. */
    @Override
    public TENSOR log10InPlace() {
        Nd4j.getExecutioner().exec(new LogX(this.tensor, 10.0d));
        return (TENSOR) getThis();
    }

    /** Replaces each element {@code x} with {@code 2^x}, in place. */
    @Override
    public TENSOR exp2InPlace() {
        // Pairwise power: base array of 2s, exponent = this tensor, output written back into this tensor.
        INDArray twos = Nd4j.valueArrayOf(this.tensor.shape(), 2.0d);
        Nd4j.getExecutioner().exec(new PowPairwise(twos, this.tensor, this.tensor));
        return (TENSOR) getThis();
    }

    /** Replaces each element {@code x} with {@code e^x - 1}, in place. */
    @Override
    public TENSOR expM1InPlace() {
        // NOTE(review): relies on the single-argument Expm1 op writing its result into its input array.
        Nd4j.getExecutioner().exec(new Expm1(this.tensor));
        return (TENSOR) getThis();
    }

    /**
     * Subtracts the mean and divides by the standard deviation, in place.
     *
     * <p>The standard deviation is evaluated after the mean has been subtracted. Standard deviation
     * is shift-invariant, so the result matches using the original data's deviation.
     */
    @Override
    public TENSOR standardizeInPlace() {
        this.tensor.subi(average()).divi(standardDeviation());
        return (TENSOR) getThis();
    }

    /**
     * Sets every element to {@code value}, in place.
     *
     * @param value value assigned to all elements
     */
    @Override
    public TENSOR setAllInPlace(T value) {
        this.tensor.assign(value);
        return (TENSOR) getThis();
    }

    /**
     * Clamps every element into {@code [min, max]}, in place.
     *
     * @param min lower bound
     * @param max upper bound
     */
    @Override
    public TENSOR clampInPlace(TENSOR min, TENSOR max) {
        // Cap from above first, then raise from below: max(min(x, max), min).
        return (TENSOR) ((FloatingPointTensor) minInPlace(max)).maxInPlace(min);
    }

    /** Rounds each element up to the nearest integer value, in place. */
    @Override
    public TENSOR ceilInPlace() {
        Transforms.ceil(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Rounds each element down to the nearest integer value, in place. */
    @Override
    public TENSOR floorInPlace() {
        Transforms.floor(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Rounds each element to the nearest integer value, in place. */
    @Override
    public TENSOR roundInPlace() {
        Transforms.round(this.tensor, false);
        return (TENSOR) getThis();
    }

    /** Applies the logistic sigmoid to each element, in place. */
    @Override
    public TENSOR sigmoidInPlace() {
        Transforms.sigmoid(this.tensor, false);
        return (TENSOR) getThis();
    }

    /**
     * Returns the determinant of this matrix.
     *
     * <p>It is computed by Commons Math {@link LUDecomposition} on a {@code double[][]} copy of the data.
     */
    @Override
    public T determinant() {
        double[][] matrix = this.tensor.dup().toDoubleMatrix();
        return getNumber(Double.valueOf(new LUDecomposition(new Array2DRowRealMatrix(matrix)).getDeterminant()));
    }

    /** Returns the product of all elements (an empty dimension list reduces over every dimension). */
    @Override
    public T product() {
        return getNumber(this.tensor.prod(new int[0]).getNumber(0L));
    }

    /**
     * Returns a mask that is true where an element is not NaN.
     *
     * <p>Compares the tensor with itself. Under IEEE-754, {@code NaN != NaN}, so only non-NaN
     * positions compare equal. NOTE(review): assumes {@code elementwiseEquals} uses IEEE
     * comparison rather than bitwise equality; confirm.
     */
    @Override
    public BooleanTensor notNaN() {
        return elementwiseEquals((Tensor) this);
    }
}
