/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.ndj4;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.bool.JVMBooleanTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.intgr.Nd4jIntegerTensor;
import io.improbable.keanu.tensor.ndj4.INDArrayExtensions;
import io.improbable.keanu.tensor.ndj4.INDArrayShim;
import io.improbable.keanu.tensor.ndj4.Nd4jTensor;
import java.util.function.Function;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Base class for numeric tensors backed by an ND4J {@link INDArray} (held in {@code tensor}).
 *
 * <p>Operations without the {@code InPlace} suffix return a new tensor and leave this one
 * untouched. {@code InPlace} operations write into this tensor where they can. When the result
 * cannot live in this tensor's array, they return a new tensor instead. That happens when the
 * receiver is a scalar, or when a broadcasting op produced a different array. Callers must
 * therefore always use the returned value.
 *
 * @param <T>      boxed element type
 * @param <TENSOR> concrete tensor type produced by operations
 */
public abstract class Nd4jNumberTensor<T extends Number, TENSOR extends NumberTensor<T, TENSOR>>
        extends Nd4jTensor<T, TENSOR>
        implements NumberTensor<T, TENSOR> {

    public Nd4jNumberTensor(INDArray tensor) {
        super(tensor);
    }

    /**
     * Sets every element selected by {@code mask} to {@code value}.
     *
     * <p>The mask is copied before use, so the caller's mask is never modified. The result keeps
     * this tensor's original data type.
     *
     * @param mask  tensor with the same length as this one; assumed to hold 1 at positions to set
     *              and 0 elsewhere (NOTE(review): other mask values are not handled specially)
     * @param value value written at the masked positions
     * @return this tensor, updated
     * @throws IllegalArgumentException if the lengths of this tensor and {@code mask} differ
     */
    @Override
    public TENSOR setWithMaskInPlace(TENSOR mask, T value) {
        if (this.getLength() != mask.getLength()) {
            throw new IllegalArgumentException("The lengths of the tensor and mask must match, but got tensor length: " + this.getLength() + ", mask length: " + mask.getLength());
        }
        DataType originalType = this.tensor.dataType();
        INDArray dblBuffer = originalType == DataType.DOUBLE ? this.tensor : this.tensor.castTo(DataType.DOUBLE);

        // The steps below write into the mask, so always work on a private copy.
        // getTensor may return the caller's own backing array, and castTo only copies
        // when the data type actually changes.
        INDArray maskINDArray = this.getTensor(mask);
        INDArray dblMask = maskINDArray.dataType() == DataType.DOUBLE
                ? maskINDArray.dup()
                : maskINDArray.castTo(DataType.DOUBLE);

        double dblValue = value.doubleValue();
        double trueValue = 1.0;
        if (dblValue == 0.0) {
            // Step 2 only copies mask entries that differ from falseValue. If 0 were written into a
            // 1/0 mask, the masked entries would look "false" and never be copied. So flip the
            // encoding: masked positions become 0, everything else becomes 1.
            trueValue = 0.0;
            dblMask.negi().addi(1.0);
        }
        double falseValue = 1.0 - trueValue;

        // Step 1: replace the "true" marker in the mask with the value to write.
        Nd4j.getExecutioner().exec((Op) new CompareAndSet(dblMask, dblValue, Conditions.equals(trueValue)));
        // Step 2: copy mask entries into the buffer wherever the mask is not the "false" marker.
        Nd4j.getExecutioner().exec((Op) new CompareAndSet(dblBuffer, dblMask, Conditions.notEquals(falseValue)));

        // Restore the original element type so e.g. an integer tensor does not silently become double.
        INDArray result = dblBuffer == this.tensor ? dblBuffer : dblBuffer.castTo(originalType);
        return this.set(result);
    }

    /**
     * Sums over the given dimensions. If no dimensions are given, returns an unreduced copy.
     */
    @Override
    public TENSOR sum(int... overDimensions) {
        if (overDimensions.length == 0) {
            return this.duplicate();
        }
        return this.create(this.tensor.sum(overDimensions));
    }

    @Override
    public T sum() {
        return this.getNumber(this.tensor.sumNumber());
    }

    /**
     * Cumulative sum along {@code requestedDimension}. Negative values count from the last dimension.
     */
    @Override
    public TENSOR cumSumInPlace(int requestedDimension) {
        int dimension = this.resolveDimension(requestedDimension);
        return this.set(this.tensor.cumsumi(dimension));
    }

    @Override
    public T product() {
        return this.getNumber(this.tensor.prodNumber());
    }

    @Override
    public TENSOR product(int... overDimensions) {
        return this.create(this.tensor.prod(overDimensions));
    }

    /**
     * Cumulative product along {@code requestedDimension}. Negative values count from the last dimension.
     */
    @Override
    public TENSOR cumProdInPlace(int requestedDimension) {
        int dimension = this.resolveDimension(requestedDimension);
        return this.set(INDArrayExtensions.cumProd(this.tensor, dimension));
    }

    /**
     * Maps a possibly negative dimension (counted from the end) to a non-negative index, then
     * validates it against this tensor's shape.
     */
    private int resolveDimension(int requestedDimension) {
        int dimension = requestedDimension >= 0 ? requestedDimension : requestedDimension + this.tensor.rank();
        TensorShapeValidation.checkDimensionExistsInShape(dimension, this.tensor.shape());
        return dimension;
    }

    /**
     * Clamps each element into [min, max]. It first takes the elementwise minimum with {@code max},
     * then the elementwise maximum with {@code min}.
     */
    @Override
    public TENSOR clampInPlace(TENSOR min, TENSOR max) {
        return this.minInPlace(max).maxInPlace(min);
    }

    @Override
    public T max() {
        return this.getNumber(this.tensor.maxNumber());
    }

    @Override
    public T min() {
        return this.getNumber(this.tensor.minNumber());
    }

    /**
     * Replaces each element with the larger of itself and the corresponding element of {@code max}.
     * {@code max} may be a scalar.
     */
    @Override
    public TENSOR maxInPlace(TENSOR max) {
        if (max.isScalar()) {
            Nd4j.getExecutioner().exec((ScalarOp) new ScalarMax(this.tensor, max.scalar()));
        } else {
            this.tensor = INDArrayShim.max(this.tensor, this.getTensor(max));
        }
        return this.set(this.tensor);
    }

    /**
     * Replaces each element with the smaller of itself and the corresponding element of {@code min}.
     * {@code min} may be a scalar.
     */
    @Override
    public TENSOR minInPlace(TENSOR min) {
        if (min.isScalar()) {
            Nd4j.getExecutioner().exec((ScalarOp) new ScalarMin(this.tensor, min.scalar()));
        } else {
            this.tensor = INDArrayShim.min(this.tensor, this.getTensor(min));
        }
        return this.set(this.tensor);
    }

    @Override
    public int argMax() {
        return this.tensor.argMax(new int[0]).getInt(new int[]{0});
    }

    /**
     * Index of the maximum along {@code axis}. The result has {@code axis} removed from its shape.
     */
    @Override
    public IntegerTensor argMax(int axis) {
        long[] shape = this.getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        INDArray max = this.tensor.argMax(new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape));
        return new Nd4jIntegerTensor(max);
    }

    @Override
    public int argMin() {
        return Nd4j.argMin(this.tensor, new int[0]).getInt(new int[]{0});
    }

    /**
     * Index of the minimum along {@code axis}. The result has {@code axis} removed from its shape.
     */
    @Override
    public IntegerTensor argMin(int axis) {
        long[] shape = this.getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        INDArray min = Nd4j.argMin(this.tensor, new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape));
        return new Nd4jIntegerTensor(min);
    }

    /** Returns {@code this - that} as a new tensor. */
    @Override
    public TENSOR minus(TENSOR that) {
        if (that.isScalar()) {
            return this.minus(that.scalar());
        }
        if (this.isScalar()) {
            return that.reverseMinus(this.scalar());
        }
        return this.duplicate().minusInPlace(that);
    }

    /** Returns {@code this + that} as a new tensor. */
    @Override
    public TENSOR plus(TENSOR that) {
        if (that.isScalar()) {
            return this.plus(that.scalar());
        }
        if (this.isScalar()) {
            return that.plus(this.scalar());
        }
        return this.duplicate().plusInPlace(that);
    }

    /** Returns {@code this * that} (elementwise) as a new tensor. */
    @Override
    public TENSOR times(TENSOR that) {
        if (that.isScalar()) {
            return this.times(that.scalar());
        }
        if (this.isScalar()) {
            return that.times(this.scalar());
        }
        return this.duplicate().timesInPlace(that);
    }

    /** Returns {@code this / that} (elementwise) as a new tensor. */
    @Override
    public TENSOR div(TENSOR that) {
        if (that.isScalar()) {
            return this.div(that.scalar());
        }
        if (this.isScalar()) {
            return that.reverseDiv(this.scalar());
        }
        return this.duplicate().divInPlace(that);
    }

    @Override
    public TENSOR minusInPlace(T value) {
        return this.set(this.tensor.subi(value));
    }

    @Override
    public TENSOR plusInPlace(T value) {
        return this.set(this.tensor.addi(value));
    }

    @Override
    public TENSOR timesInPlace(T value) {
        return this.set(this.tensor.muli(value));
    }

    @Override
    public TENSOR divInPlace(T value) {
        return this.set(this.tensor.divi(value));
    }

    /** Raises each element to the (scalar or elementwise) {@code exponent}. */
    @Override
    public TENSOR powInPlace(TENSOR exponent) {
        if (exponent.isScalar()) {
            // dup=false: ND4J applies the power directly to this.tensor.
            Transforms.pow(this.tensor, exponent.scalar(), false);
        } else {
            this.tensor = INDArrayShim.pow(this.tensor, this.getTensor(exponent));
        }
        return this.set(this.tensor);
    }

    @Override
    public TENSOR powInPlace(T exponent) {
        Transforms.pow(this.tensor, exponent, false);
        return this.set(this.tensor);
    }

    /**
     * Wraps the array produced by an INDArrayShim in-place op. If the shim wrote into this tensor's
     * array, this tensor is updated and returned. If broadcasting forced a new array, that array is
     * returned as a new tensor.
     */
    private TENSOR adoptInPlaceResult(INDArray result) {
        if (result != this.tensor) {
            return this.create(result);
        }
        return this.set(this.tensor);
    }

    /** Computes {@code this - that}, writing into this tensor where possible. */
    @Override
    public TENSOR minusInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.subi(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return this.minus(that);
        }
        return this.adoptInPlaceResult(INDArrayShim.subi(this.tensor, this.getTensor(that).dup()));
    }

    /** Computes {@code that - this}, writing into this tensor where possible. */
    @Override
    public TENSOR reverseMinusInPlace(TENSOR that) {
        if (that.isScalar()) {
            // rsubi computes (scalar - element); subi would give the opposite sign.
            this.tensor.rsubi(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return that.minus(this.getThis());
        }
        return this.adoptInPlaceResult(INDArrayShim.rsubi(this.tensor, this.getTensor(that).dup()));
    }

    @Override
    public TENSOR reverseMinusInPlace(T value) {
        return this.set(this.tensor.rsubi(value));
    }

    /** Computes {@code this + that}, writing into this tensor where possible. */
    @Override
    public TENSOR plusInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.addi(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return this.plus(that);
        }
        return this.adoptInPlaceResult(INDArrayShim.addi(this.tensor, this.getTensor(that).dup()));
    }

    /** Computes {@code this * that} (elementwise), writing into this tensor where possible. */
    @Override
    public TENSOR timesInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.muli(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return this.times(that);
        }
        return this.adoptInPlaceResult(INDArrayShim.muli(this.tensor, this.getTensor(that).dup()));
    }

    /** Computes {@code this / that} (elementwise), writing into this tensor where possible. */
    @Override
    public TENSOR divInPlace(TENSOR that) {
        if (that.isScalar()) {
            this.tensor.divi(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return this.div(that);
        }
        return this.adoptInPlaceResult(INDArrayShim.divi(this.tensor, this.getTensor(that).dup()));
    }

    @Override
    public TENSOR reverseDivInPlace(T value) {
        return this.set(this.tensor.rdivi(value));
    }

    /** Computes {@code that / this} (elementwise), writing into this tensor where possible. */
    @Override
    public TENSOR reverseDivInPlace(TENSOR that) {
        if (that.isScalar()) {
            // rdivi computes (scalar / element).
            this.tensor.rdivi(that.scalar());
            return this.set(this.tensor);
        }
        if (this.isScalar()) {
            return that.div(this.getThis());
        }
        return this.adoptInPlaceResult(INDArrayShim.rdivi(this.tensor, this.getTensor(that).dup()));
    }

    @Override
    public TENSOR unaryMinusInPlace() {
        return this.set(this.tensor.negi());
    }

    @Override
    public TENSOR absInPlace() {
        // dup=false: ND4J applies abs directly to this.tensor.
        Transforms.abs(this.tensor, false);
        return this.set(this.tensor);
    }

    @Override
    public T average() {
        return this.getNumber(this.tensor.meanNumber());
    }

    @Override
    public T standardDeviation() {
        return this.getNumber(this.tensor.stdNumber());
    }

    @Override
    public BooleanTensor greaterThan(T value) {
        return this.fromMask(this.tensor.gt(value));
    }

    @Override
    public BooleanTensor greaterThanOrEqual(T value) {
        return this.fromMask(this.tensor.gte(value));
    }

    @Override
    public BooleanTensor greaterThan(TENSOR value) {
        INDArray mask = value.isScalar()
                ? this.tensor.gt(value.scalar())
                : INDArrayShim.gt(this.tensor, this.getTensor(value));
        return this.fromMask(mask);
    }

    @Override
    public BooleanTensor greaterThanOrEqual(TENSOR value) {
        INDArray mask;
        if (value.isScalar()) {
            mask = this.tensor.gte(value.scalar());
        } else {
            // NOTE(review): a copy is passed, presumably because INDArrayShim.gte may write into
            // its first argument; confirm.
            mask = INDArrayShim.gte(this.tensor.dup(), this.getTensor(value));
        }
        return this.fromMask(mask);
    }

    @Override
    public BooleanTensor lessThan(T value) {
        return this.fromMask(this.tensor.lt(value));
    }

    @Override
    public BooleanTensor lessThanOrEqual(T value) {
        return this.fromMask(this.tensor.lte(value));
    }

    @Override
    public BooleanTensor lessThan(TENSOR value) {
        INDArray mask = value.isScalar()
                ? this.tensor.lt(value.scalar())
                : INDArrayShim.lt(this.tensor, this.getTensor(value));
        return this.fromMask(mask);
    }

    @Override
    public BooleanTensor lessThanOrEqual(TENSOR value) {
        INDArray mask;
        if (value.isScalar()) {
            mask = this.tensor.lte(value.scalar());
        } else {
            // NOTE(review): a copy is passed, presumably because INDArrayShim.lte may write into
            // its first argument; confirm.
            mask = INDArrayShim.lte(this.tensor.dup(), this.getTensor(value));
        }
        return this.fromMask(mask);
    }

    /** Replaces every element with {@code function} applied to it, in flattened order. */
    @Override
    public TENSOR applyInPlace(Function<T, T> function) {
        Tensor.FlattenedView<T> flattenedView = this.getFlattenedView();
        for (int i = 0; i < flattenedView.size(); i++) {
            flattenedView.set(i, function.apply(flattenedView.get(i)));
        }
        return this.getThis();
    }

    /**
     * Elementwise equality with another tensor. Non-numeric tensors fall back to the generic
     * {@link Tensor#elementwiseEquals(Tensor, Tensor)} comparison.
     */
    @Override
    @SuppressWarnings("unchecked")
    public BooleanTensor elementwiseEquals(Tensor that) {
        if (!(that instanceof NumberTensor)) {
            return Tensor.elementwiseEquals(this, that);
        }
        if (this.isScalar()) {
            return that.elementwiseEquals(this.scalar());
        }
        if (that.isScalar()) {
            return this.elementwiseEquals((T) that.scalar());
        }
        return this.fromMask(INDArrayShim.eq(this.tensor, this.getTensor(that)));
    }

    @Override
    public BooleanTensor elementwiseEquals(T value) {
        return this.fromMask(this.tensor.eq(value));
    }

    /**
     * Matrix product {@code this x value}, returned as a new tensor. This tensor is not modified.
     */
    @Override
    public TENSOR matrixMultiply(TENSOR value) {
        return this.create(this.tensor.mmul(this.getTensor(value)));
    }

    /**
     * Tensor contraction of this tensor with {@code value}. {@code dimLeft} of this tensor is
     * contracted against {@code dimsRight} of {@code value}. The result is returned as a new
     * tensor; this tensor is not modified.
     */
    @Override
    public TENSOR tensorMultiply(TENSOR value, int[] dimLeft, int[] dimsRight) {
        return this.create(Nd4j.tensorMmul(this.tensor, this.getTensor(value), new int[][]{dimLeft, dimsRight}));
    }

    /**
     * Converts a numeric mask into a {@link BooleanTensor} of the same shape. Entries whose integer
     * value is non-zero become {@code true}.
     *
     * <p>NOTE(review): this reads {@code mask.data()} in raw buffer order. That assumes the mask is
     * a freshly allocated, contiguous, row-major array with no offset. Confirm for any new caller.
     */
    protected final BooleanTensor fromMask(INDArray mask) {
        long[] shape = mask.shape();
        DataBuffer data = mask.data();
        boolean[] boolsFromMask = new boolean[Ints.checkedCast(mask.length())];
        for (int i = 0; i < boolsFromMask.length; ++i) {
            boolsFromMask[i] = data.getNumber((long) i).intValue() != 0;
        }
        return JVMBooleanTensor.create(boolsFromMask, shape);
    }

    /**
     * Converts a number returned by an ND4J reduction (sum, product, max, mean, ...) into this
     * tensor's element type.
     */
    protected abstract T getNumber(Number number);
}

