package io.improbable.keanu.tensor.dbl;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.intgr.Nd4jIntegerTensor;
import io.improbable.keanu.tensor.ndj4.INDArrayExtensions;
import io.improbable.keanu.tensor.ndj4.INDArrayShim;
import io.improbable.keanu.tensor.ndj4.Nd4jFloatingPointTensor;
import io.improbable.keanu.tensor.ndj4.Nd4jTensor;
import io.improbable.keanu.tensor.ndj4.TypedINDArrayFactory;
import io.improbable.keanu.tensor.validate.TensorValidator;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.special.Gamma;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:io/improbable/keanu/tensor/dbl/Nd4jDoubleTensor.class */
/**
 * A {@link DoubleTensor} backed by an ND4J {@link INDArray}.
 *
 * <p>The backing array is kept in {@link DataType#DOUBLE}: {@link #set(INDArray)} and
 * {@link #getAsINDArray(Tensor)} cast any other element type to DOUBLE.
 */
public class Nd4jDoubleTensor extends Nd4jFloatingPointTensor<Double, DoubleTensor> implements DoubleTensor {

    /** Element type requested by the static factories (ones/eye/zeros/linspace). */
    private static final DataType BUFFER_TYPE;

    /**
     * Flat view over the backing ND4J data buffer; indices address the raw buffer directly.
     */
    private class Nd4jDoubleFlattenedView implements Tensor.FlattenedView<Double> {

        private Nd4jDoubleFlattenedView() {
        }

        /** Number of elements in the backing data buffer. */
        @Override
        public long size() {
            return Nd4jDoubleTensor.this.tensor.data().length();
        }

        /** Reads buffer element {@code index}. */
        @Override
        public Double get(long index) {
            return Double.valueOf(Nd4jDoubleTensor.this.tensor.data().getDouble(index));
        }

        /** For a length-1 tensor returns its only element whatever {@code index} is; otherwise {@code get(index)}. */
        @Override
        public Double getOrScalar(long index) {
            return Nd4jDoubleTensor.this.tensor.length() == 1 ? get(0L) : get(index);
        }

        /** Writes {@code value} into buffer element {@code index}. */
        @Override
        public void set(long index, Double value) {
            Nd4jDoubleTensor.this.tensor.data().put(index, value.doubleValue());
        }
    }

    /**
     * Creates a tensor from a flat buffer and a shape.
     *
     * @param buffer element values
     * @param shape  tensor shape
     */
    public Nd4jDoubleTensor(double[] buffer, long[] shape) {
        this(TypedINDArrayFactory.create(buffer, shape));
    }

    /**
     * Wraps an existing ND4J array (passed straight to the superclass).
     *
     * @param tensor backing array
     */
    public Nd4jDoubleTensor(INDArray tensor) {
        super(tensor);
    }

    /** Converts {@code tensor} to a DOUBLE {@link INDArray}; see {@link #getAsINDArray(Tensor)}. */
    @Override
    public INDArray getTensor(Tensor tensor) {
        return getAsINDArray(tensor);
    }

    /** Wraps {@code tensor} in a new {@link Nd4jDoubleTensor}. */
    @Override
    public DoubleTensor create(INDArray tensor) {
        return new Nd4jDoubleTensor(tensor);
    }

    /**
     * Replaces the backing array of this tensor, casting it to DOUBLE when necessary.
     *
     * @return this tensor
     */
    @Override
    public DoubleTensor set(INDArray tensor) {
        this.tensor = tensor.dataType() == DataType.DOUBLE ? tensor : tensor.castTo(DataType.DOUBLE);
        return this;
    }

    @Override
    public DoubleTensor getThis() {
        return this;
    }

    /** Creates a scalar tensor holding {@code value}. */
    public static Nd4jDoubleTensor scalar(double value) {
        return new Nd4jDoubleTensor(Nd4j.scalar(value));
    }

    /**
     * Creates a tensor from a flat buffer and a shape.
     *
     * @throws IllegalArgumentException if the buffer length does not equal the shape's length
     */
    public static Nd4jDoubleTensor create(double[] buffer, long... shape) {
        if (buffer.length != TensorShape.getLength(shape)) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(shape) + " does not match buffer size " + buffer.length);
        }
        return new Nd4jDoubleTensor(buffer, shape);
    }

    /** Creates a tensor of the given shape with every element equal to {@code value}. */
    public static Nd4jDoubleTensor create(double value, long... shape) {
        return new Nd4jDoubleTensor(Nd4j.valueArrayOf(shape, value));
    }

    /** Creates a rank-1 tensor from {@code buffer}. */
    public static Nd4jDoubleTensor create(double[] buffer) {
        return create(buffer, buffer.length);
    }

    /** Creates a tensor of ones with the given shape. */
    public static Nd4jDoubleTensor ones(long... shape) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.ones(shape, BUFFER_TYPE));
    }

    /** Creates an identity matrix of size {@code n}. */
    public static Nd4jDoubleTensor eye(long n) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.eye(n, BUFFER_TYPE));
    }

    /**
     * Creates a tensor of zeros with the given shape. Varargs for consistency with {@link #ones(long...)};
     * existing callers passing a {@code long[]} are unaffected.
     */
    public static Nd4jDoubleTensor zeros(long... shape) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.zeros(shape, BUFFER_TYPE));
    }

    /** Creates {@code numberOfPoints} points spaced evenly from {@code start} to {@code end}. */
    public static Nd4jDoubleTensor linspace(double start, double end, int numberOfPoints) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.linspace(start, end, numberOfPoints, BUFFER_TYPE));
    }

    /** Creates a range from {@code start} towards {@code end} via {@code TypedINDArrayFactory.arange}. */
    public static Nd4jDoubleTensor arange(double start, double end) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.arange(start, end));
    }

    /**
     * Creates {@code start, start + step, ...} with {@code ceil((end - start) / step)} elements.
     */
    public static Nd4jDoubleTensor arange(double start, double end, double stepSize) {
        return new Nd4jDoubleTensor(
            TypedINDArrayFactory.arange(0.0d, Math.ceil((end - start) / stepSize))
                .muli(Double.valueOf(stepSize))
                .addi(Double.valueOf(start))
        );
    }

    /**
     * Returns a DOUBLE {@link INDArray} for {@code tensor}.
     *
     * <p>ND4J-backed tensors reuse (or cast) their backing array. Other number tensors are copied
     * through their flat double buffer.
     *
     * @throws IllegalArgumentException if {@code tensor} is neither ND4J-backed nor a {@link NumberTensor}
     */
    public static INDArray getAsINDArray(Tensor tensor) {
        if (tensor instanceof Nd4jTensor) {
            INDArray backing = ((Nd4jTensor) tensor).getTensor();
            return backing.dataType() == DataType.DOUBLE ? backing : backing.castTo(DataType.DOUBLE);
        }
        if (tensor instanceof NumberTensor) {
            return TypedINDArrayFactory.create(((NumberTensor) tensor).toDouble().asFlatDoubleArray(), tensor.getShape());
        }
        throw new IllegalArgumentException("Cannot convert " + tensor.getClass().getSimpleName() + " to double INDArray");
    }

    /** Flat index of the maximum as computed by ND4J's {@code argMax} (NaNs not pre-processed). */
    @Override
    public int nanArgMax() {
        return this.tensor.argMax(new int[0]).getInt(new int[]{0});
    }

    /** Flat index of the maximum, computed on a copy in which NaNs are replaced by {@link Double#MAX_VALUE}. */
    @Override
    public int argMax() {
        return ((DoubleTensor) duplicate()).replaceNaNInPlace(Double.valueOf(Double.MAX_VALUE)).nanArgMax();
    }

    /**
     * Indices of the maxima along {@code axis}, with that dimension removed from the result shape.
     */
    @Override
    public IntegerTensor nanArgMax(int axis) {
        long[] shape = getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        return new Nd4jIntegerTensor(this.tensor.argMax(new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape)));
    }

    /** As {@link #nanArgMax(int)}, on a copy in which NaNs are replaced by {@link Double#MAX_VALUE}. */
    @Override
    public IntegerTensor argMax(int axis) {
        return ((DoubleTensor) duplicate()).replaceNaNInPlace(Double.valueOf(Double.MAX_VALUE)).nanArgMax(axis);
    }

    /** Flat index of the minimum as computed by ND4J's {@code argMin} (NaNs not pre-processed). */
    @Override
    public int nanArgMin() {
        return Nd4j.argMin(this.tensor, new int[0]).getInt(new int[]{0});
    }

    /** Flat index of the minimum, computed on a copy in which NaNs are replaced by {@code -Double.MAX_VALUE}. */
    @Override
    public int argMin() {
        return ((DoubleTensor) duplicate()).replaceNaNInPlace(Double.valueOf(-Double.MAX_VALUE)).nanArgMin();
    }

    /**
     * Indices of the minima along {@code axis}, with that dimension removed from the result shape.
     */
    @Override
    public IntegerTensor nanArgMin(int axis) {
        long[] shape = getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        return new Nd4jIntegerTensor(Nd4j.argMin(this.tensor, new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape)));
    }

    /** As {@link #nanArgMin(int)}, on a copy in which NaNs are replaced by {@code -Double.MAX_VALUE}. */
    @Override
    public IntegerTensor argMin(int axis) {
        return ((DoubleTensor) duplicate()).replaceNaNInPlace(Double.valueOf(-Double.MAX_VALUE)).nanArgMin(axis);
    }

    /** Converts an arbitrary {@link Number} to this tensor's element type. */
    @Override
    public Double getNumber(Number number) {
        return Double.valueOf(number.doubleValue());
    }

    /**
     * Checks element-wise equality within {@code epsilon}.
     *
     * <p>If {@code that} is ND4J-backed, delegates to {@link INDArray#equalsWithEps}. Otherwise it requires
     * an identical shape and {@code |that - this| < epsilon} for every element.
     */
    @Override
    public boolean equalsWithinEpsilon(DoubleTensor that, Double epsilon) {
        if (this == that) {
            return true;
        }
        if (that instanceof Nd4jTensor) {
            return this.tensor.equalsWithEps(((Nd4jTensor) that).getTensor(), epsilon.doubleValue());
        }
        if (!hasSameShapeAs(that)) {
            return false;
        }
        // Scalar overload of lessThan: compare every |difference| against epsilon.
        DoubleTensor absoluteDifference = (DoubleTensor) ((DoubleTensor) that.minus((DoubleTensor) this)).abs();
        return absoluteDifference.lessThan(epsilon).allTrue();
    }

    /** 1.0 where this &gt; {@code that}, else 0.0. Accepts any {@link DoubleTensor} implementation. */
    @Override
    public DoubleTensor greaterThanMask(DoubleTensor that) {
        return greaterThan(that).toDoubleMask();
    }

    /** 1.0 where this &gt;= {@code that}, else 0.0. Accepts any {@link DoubleTensor} implementation. */
    @Override
    public DoubleTensor greaterThanOrEqualToMask(DoubleTensor that) {
        return greaterThanOrEqual(that).toDoubleMask();
    }

    /** 1.0 where this &lt; {@code that}, else 0.0. Accepts any {@link DoubleTensor} implementation. */
    @Override
    public DoubleTensor lessThanMask(DoubleTensor that) {
        return lessThan(that).toDoubleMask();
    }

    /** 1.0 where this &lt;= {@code that}, else 0.0. Accepts any {@link DoubleTensor} implementation. */
    @Override
    public DoubleTensor lessThanOrEqualToMask(DoubleTensor that) {
        return lessThanOrEqual(that).toDoubleMask();
    }

    /**
     * In place: {@code log(this) * that}.
     *
     * <p>Both operands are run through {@code NAN_CATCHER} first, and the result through {@code NAN_FIXER}.
     */
    @Override
    public DoubleTensor safeLogTimesInPlace(DoubleTensor that) {
        TensorValidator.NAN_CATCHER.validate(getThis());
        TensorValidator.NAN_CATCHER.validate(that);
        return TensorValidator.NAN_FIXER.validate((DoubleTensor) logInPlace().timesInPlace(that));
    }

    /** Applies {@link Gamma#logGamma(double)} element-wise, in place. */
    @Override
    public DoubleTensor logGammaInPlace() {
        return (DoubleTensor) applyInPlace(value -> Gamma.logGamma(value));
    }

    /** Applies {@link Gamma#digamma(double)} element-wise, in place. */
    @Override
    public DoubleTensor digammaInPlace() {
        return (DoubleTensor) applyInPlace(value -> Gamma.digamma(value));
    }

    /**
     * In place: base-2 log-add-exp of this and {@code that}, delegated to {@link JVMDoubleTensor}.
     *
     * <p>The result is written back into this tensor, so it really is in place.
     *
     * @return this tensor
     */
    @Override
    public DoubleTensor logAddExp2InPlace(DoubleTensor that) {
        DoubleTensor result = toJvmTensor().logAddExp2InPlace(that);
        return set(TypedINDArrayFactory.create(result.asFlatDoubleArray(), result.getShape()));
    }

    /**
     * In place: natural log-add-exp of this and {@code that}, delegated to {@link JVMDoubleTensor}.
     *
     * <p>The result is written back into this tensor, so it really is in place.
     *
     * @return this tensor
     */
    @Override
    public DoubleTensor logAddExpInPlace(DoubleTensor that) {
        DoubleTensor result = toJvmTensor().logAddExpInPlace(that);
        return set(TypedINDArrayFactory.create(result.asFlatDoubleArray(), result.getShape()));
    }

    /**
     * Copies this tensor into a {@link JVMDoubleTensor} of the same shape.
     *
     * <p>Unlike {@code INDArray.toDoubleVector()}, which only accepts vectors and scalars, this works for
     * any rank.
     */
    private DoubleTensor toJvmTensor() {
        return JVMDoubleTensor.create(asFlatDoubleArray(), getShape());
    }

    /** Replaces every NaN with {@code value} using ND4J's {@link ReplaceNans} op; returns this tensor. */
    @Override
    public DoubleTensor replaceNaNInPlace(Double value) {
        Nd4j.getExecutioner().exec(new ReplaceNans(this.tensor, value.doubleValue()));
        return this;
    }

    /** Element-wise {@code Conditions.isFinite()} mask. */
    @Override
    public BooleanTensor isFinite() {
        return toBooleanTensor(Nd4j.getExecutioner().exec(
            new MatchConditionTransform(this.tensor, newBooleanMask(), Conditions.isFinite())
        ));
    }

    /** Element-wise infinity mask from {@link INDArray#isInfinite()}. */
    @Override
    public BooleanTensor isInfinite() {
        return toBooleanTensor(this.tensor.isInfinite());
    }

    /**
     * Copies a BOOL {@link INDArray} into a {@code boolean[]} by reading its data buffer sequentially.
     *
     * @throws IllegalArgumentException if {@code mask} is not of type BOOL
     */
    private boolean[] asBoolean(INDArray mask) {
        Preconditions.checkArgument(mask.dataType() == DataType.BOOL);
        boolean[] values = new boolean[Ints.checkedCast(mask.length())];
        for (int i = 0; i < values.length; i++) {
            values[i] = mask.data().indexer().get(i);
        }
        return values;
    }

    /** Allocates an uninitialised BOOL array with this tensor's shape and ordering, used as an op output. */
    private INDArray newBooleanMask() {
        return Nd4j.createUninitialized(DataType.BOOL, this.tensor.shape(), this.tensor.ordering());
    }

    /** Wraps a BOOL mask as a {@link BooleanTensor} with this tensor's shape. */
    private BooleanTensor toBooleanTensor(INDArray mask) {
        return BooleanTensor.create(asBoolean(mask), this.tensor.shape());
    }

    /** Element-wise mask of entries equal to {@link Double#NEGATIVE_INFINITY}. */
    @Override
    public BooleanTensor isNegativeInfinity() {
        return toBooleanTensor(Nd4j.getExecutioner().exec(
            new MatchConditionTransform(this.tensor, newBooleanMask(), Conditions.equals(Double.valueOf(Double.NEGATIVE_INFINITY)))
        ));
    }

    /** Element-wise mask of entries equal to {@link Double#POSITIVE_INFINITY}. */
    @Override
    public BooleanTensor isPositiveInfinity() {
        return toBooleanTensor(Nd4j.getExecutioner().exec(
            new MatchConditionTransform(this.tensor, newBooleanMask(), Conditions.equals(Double.valueOf(Double.POSITIVE_INFINITY)))
        ));
    }

    /** Returns a copy of this tensor (via {@code duplicate()}). */
    @Override
    public DoubleTensor toDouble() {
        return (DoubleTensor) duplicate();
    }

    /** Converts to an integer tensor via {@code INDArrayExtensions.castToInteger(tensor, true)}. */
    @Override
    public IntegerTensor toInteger() {
        return new Nd4jIntegerTensor(INDArrayExtensions.castToInteger(this.tensor, true));
    }

    /** Copies the elements of a duplicate of the backing array into a new {@code double[]}. */
    @Override
    public double[] asFlatDoubleArray() {
        return this.tensor.dup().data().asDouble();
    }

    /** Copies the elements of a duplicate of the backing array into a new {@code int[]} via {@code DataBuffer.asInt()}. */
    @Override
    public int[] asFlatIntegerArray() {
        return this.tensor.dup().data().asInt();
    }

    /** Boxed copy of {@link #asFlatDoubleArray()}. */
    @Override
    public Double[] asFlatArray() {
        return ArrayUtils.toObject(asFlatDoubleArray());
    }

    /** Returns a live flat view that reads and writes the backing data buffer. */
    @Override
    public Tensor.FlattenedView<Double> getFlattenedView() {
        return new Nd4jDoubleFlattenedView();
    }

    static {
        // Kept before BUFFER_TYPE's assignment to preserve the original initialisation order.
        INDArrayShim.startNewThreadForNd4j();
        BUFFER_TYPE = DataType.DOUBLE;
    }
}
