package io.improbable.keanu.tensor.intgr;

import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.dbl.Nd4jDoubleTensor;
import io.improbable.keanu.tensor.ndj4.INDArrayShim;
import io.improbable.keanu.tensor.ndj4.Nd4jFixedPointTensor;
import io.improbable.keanu.tensor.ndj4.Nd4jTensor;
import io.improbable.keanu.tensor.ndj4.TypedINDArrayFactory;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link IntegerTensor} backed by an ND4J {@link INDArray}.
 *
 * <p>The shape-based factories {@link #ones}, {@link #eye} and {@link #zeros} allocate
 * {@link DataType#INT} buffers. {@link #set(INDArray)} and {@link #getAsINDArray(Tensor)} cast
 * any non-INT array to INT.
 */
public class Nd4jIntegerTensor extends Nd4jFixedPointTensor<Integer, IntegerTensor> implements IntegerTensor {

    /** Buffer data type passed to {@link TypedINDArrayFactory} by the shape-based factories. */
    private static final DataType BUFFER_TYPE;

    static {
        // This is the class's only static initialiser, so the shim call runs before any static
        // factory below executes.
        INDArrayShim.startNewThreadForNd4j();
        BUFFER_TYPE = DataType.INT;
    }

    /**
     * Flat, write-through view of this tensor's data, indexed by position in {@code tensor.data()}.
     *
     * <p>NOTE(review): this indexes the raw data buffer directly. If {@code tensor} could ever be a
     * view onto a larger buffer, {@link #size()} and the indices would not match the logical
     * elements. Confirm that tensors here always own a compact buffer.
     */
    private class Nd4jIntegerFlattenedView implements Tensor.FlattenedView<Integer> {

        private Nd4jIntegerFlattenedView() {
        }

        /** Number of entries in the underlying data buffer. */
        @Override
        public long size() {
            return Nd4jIntegerTensor.this.tensor.data().length();
        }

        /** Reads buffer entry {@code index} as an int. */
        @Override
        public Integer get(long index) {
            return Integer.valueOf(Nd4jIntegerTensor.this.tensor.data().getInt(index));
        }

        /**
         * Returns element 0 when the tensor has exactly one element (so a scalar acts as a
         * broadcast value); otherwise returns element {@code index}.
         */
        @Override
        public Integer getOrScalar(long index) {
            return Nd4jIntegerTensor.this.tensor.length() == 1 ? get(0L) : get(index);
        }

        /** Writes {@code value} into buffer entry {@code index}, mutating this tensor. */
        @Override
        public void set(long index, Integer value) {
            Nd4jIntegerTensor.this.tensor.data().put(index, value.intValue());
        }
    }

    /**
     * Creates a tensor from flat {@code values} arranged into {@code shape}.
     *
     * @param values flat element values
     * @param shape  dimensions of the resulting tensor
     */
    public Nd4jIntegerTensor(int[] values, long[] shape) {
        this(TypedINDArrayFactory.create(values, shape));
    }

    /**
     * Wraps an existing {@link INDArray}. This constructor itself does not cast the array to INT.
     *
     * @param array the backing array
     */
    public Nd4jIntegerTensor(INDArray array) {
        super(array);
    }

    /** Returns {@code tensor} as an INT {@link INDArray}; see {@link #getAsINDArray(Tensor)}. */
    @Override
    protected INDArray getTensor(Tensor<Integer, ?> tensor) {
        return getAsINDArray(tensor);
    }

    /** Creates a scalar tensor holding {@code value}. */
    public static Nd4jIntegerTensor scalar(int value) {
        return new Nd4jIntegerTensor(Nd4j.scalar(value));
    }

    /** Creates a tensor from flat {@code values} arranged into {@code shape}. */
    public static Nd4jIntegerTensor create(int[] values, long[] shape) {
        return new Nd4jIntegerTensor(values, shape);
    }

    /** Creates a tensor of the given {@code shape} with every element set to {@code value}. */
    public static Nd4jIntegerTensor create(int value, long[] shape) {
        return new Nd4jIntegerTensor(Nd4j.valueArrayOf(shape, value));
    }

    /** Creates an INT tensor of ones with the given {@code shape}. */
    public static Nd4jIntegerTensor ones(long[] shape) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.ones(shape, BUFFER_TYPE));
    }

    /** Creates an INT identity matrix of size {@code size}. */
    public static Nd4jIntegerTensor eye(long size) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.eye(size, BUFFER_TYPE));
    }

    /** Creates an INT tensor of zeros with the given {@code shape}. */
    public static Nd4jIntegerTensor zeros(long[] shape) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.zeros(shape, BUFFER_TYPE));
    }

    /** Creates the range from {@code start} to {@code end} by delegating to {@link TypedINDArrayFactory}. */
    public static Nd4jIntegerTensor arange(int start, int end) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.arange(start, end));
    }

    /**
     * Creates {@code start, start + stepSize, start + 2 * stepSize, ...} with
     * {@code ceil((end - start) / stepSize)} elements. This assumes
     * {@code TypedINDArrayFactory.arange(0, n)} yields {@code 0 .. n-1}.
     *
     * <p>NOTE(review): this method lives on the integer tensor but returns an
     * {@link Nd4jDoubleTensor}. The return type is kept unchanged for existing callers.
     *
     * @param start    first value
     * @param end      bound the sequence stops before
     * @param stepSize non-zero increment (may be negative)
     * @return the stepped sequence
     * @throws IllegalArgumentException if {@code stepSize} is zero
     * @throws ArithmeticException      if the element count does not fit in an int
     */
    public static Nd4jDoubleTensor arange(int start, int end, int stepSize) {
        if (stepSize == 0) {
            throw new IllegalArgumentException("Step size cannot be zero");
        }
        // Ceiling division (-floor(-a / b)) keeps a trailing partial step: arange(0, 5, 2) has
        // 3 elements [0, 2, 4], whereas truncating '/' gave 2. Longs avoid overflow in end - start.
        int stepCount = Math.toIntExact(-Math.floorDiv((long) start - end, (long) stepSize));
        return new Nd4jDoubleTensor(
            TypedINDArrayFactory.arange(0, stepCount)
                .muli(Integer.valueOf(stepSize))
                .addi(Integer.valueOf(start)));
    }

    /**
     * Extracts an INT {@link INDArray} from any integer tensor.
     *
     * <p>For an {@link Nd4jTensor}, the backing array is returned as-is if already INT, otherwise
     * it is cast to INT. Other {@link NumberTensor}s are rebuilt from their flat integer values
     * and shape.
     *
     * @throws IllegalArgumentException if {@code tensor} is neither an {@link Nd4jTensor} nor a
     *                                  {@link NumberTensor}
     */
    public static INDArray getAsINDArray(Tensor<Integer, ?> tensor) {
        if (tensor instanceof Nd4jTensor) {
            INDArray array = ((Nd4jTensor) tensor).getTensor();
            return array.dataType() == DataType.INT ? array : array.castTo(DataType.INT);
        }
        if (tensor instanceof NumberTensor) {
            return TypedINDArrayFactory.create(((NumberTensor) tensor).toInteger().asFlatIntegerArray(), tensor.getShape());
        }
        throw new IllegalArgumentException("Cannot convert " + tensor.getClass().getSimpleName() + " to integer INDArray");
    }

    /**
     * Returns whether every element of {@code integerTensor} is within {@code num} of the
     * corresponding element of this tensor.
     *
     * <p>ND4J-backed arguments use {@link INDArray#equalsWithEps}. Other tensors must have the
     * same shape, and each element's absolute difference must be strictly less than {@code num}.
     * A {@code null} argument is never equal.
     */
    @Override
    public boolean equalsWithinEpsilon(IntegerTensor integerTensor, Integer num) {
        if (this == integerTensor) {
            return true;
        }
        if (integerTensor == null) {
            return false;
        }
        if (integerTensor instanceof Nd4jTensor) {
            return this.tensor.equalsWithEps(((Nd4jTensor) integerTensor).getTensor(), num.intValue());
        }
        if (hasSameShapeAs(integerTensor)) {
            return integerTensor.minus(this).abs().lessThan(num).allTrue();
        }
        return false;
    }

    /** Converts {@code number} to an {@link Integer} via {@link Number#intValue()}. */
    @Override
    public Integer getNumber(Number number) {
        return Integer.valueOf(number.intValue());
    }

    /**
     * Replaces this tensor's contents with {@code num} in every position, keeping the shape.
     * The new array goes through {@link #set(INDArray)} so the INT data type is enforced.
     */
    @Override
    public IntegerTensor setAllInPlace(Integer num) {
        return set(Nd4j.valueArrayOf(getShape(), num.intValue()));
    }

    /** Not supported for integer tensors. */
    @Override
    public IntegerTensor safeLogTimesInPlace(IntegerTensor integerTensor) {
        throw new NotImplementedException("safeLogTimesInPlace is not implemented for integer tensors");
    }

    /** Returns {@code greaterThan(integerTensor)} converted to an integer mask. */
    @Override
    public IntegerTensor greaterThanMask(IntegerTensor integerTensor) {
        return greaterThan(integerTensor).toIntegerMask();
    }

    /** Returns {@code greaterThanOrEqual(integerTensor)} converted to an integer mask. */
    @Override
    public IntegerTensor greaterThanOrEqualToMask(IntegerTensor integerTensor) {
        return greaterThanOrEqual(integerTensor).toIntegerMask();
    }

    /** Returns {@code lessThan(integerTensor)} converted to an integer mask. */
    @Override
    public IntegerTensor lessThanMask(IntegerTensor integerTensor) {
        return lessThan(integerTensor).toIntegerMask();
    }

    /** Returns {@code lessThanOrEqual(integerTensor)} converted to an integer mask. */
    @Override
    public IntegerTensor lessThanOrEqualToMask(IntegerTensor integerTensor) {
        return lessThanOrEqual(integerTensor).toIntegerMask();
    }

    /** Returns a double tensor built from this tensor's array cast to {@link DataType#DOUBLE}. */
    @Override
    public DoubleTensor toDouble() {
        return new Nd4jDoubleTensor(this.tensor.castTo(DataType.DOUBLE));
    }

    /** Returns a duplicate of this tensor. */
    @Override
    public IntegerTensor toInteger() {
        return (IntegerTensor) duplicate();
    }

    /** Wraps {@code iNDArray} in a new {@link Nd4jIntegerTensor}. */
    @Override
    public IntegerTensor create(INDArray iNDArray) {
        return new Nd4jIntegerTensor(iNDArray);
    }

    /**
     * Replaces the backing array with {@code iNDArray}, casting it to INT when it is not already
     * INT.
     *
     * @return this tensor
     */
    @Override
    public IntegerTensor set(INDArray iNDArray) {
        this.tensor = iNDArray.dataType() == DataType.INT ? iNDArray : iNDArray.castTo(DataType.INT);
        return this;
    }

    /** Returns this instance typed as {@link IntegerTensor}. */
    @Override
    public IntegerTensor getThis() {
        return this;
    }

    /**
     * Returns the elements as a new double array. The {@code dup()} call presumably yields a
     * compact copy, so only this tensor's elements are read from the buffer.
     */
    @Override
    public double[] asFlatDoubleArray() {
        return this.tensor.dup().data().asDouble();
    }

    /** Returns the elements as a new int array (read from a {@code dup()} of the backing array). */
    @Override
    public int[] asFlatIntegerArray() {
        return this.tensor.dup().data().asInt();
    }

    /** Returns the elements as boxed {@link Integer}s. */
    @Override
    public Integer[] asFlatArray() {
        return ArrayUtils.toObject(asFlatIntegerArray());
    }

    /** Returns a new write-through flat view over this tensor's data buffer. */
    @Override
    public Tensor.FlattenedView<Integer> getFlattenedView() {
        return new Nd4jIntegerFlattenedView();
    }
}
