/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.dbl;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.dbl.JVMDoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.intgr.Nd4jIntegerTensor;
import io.improbable.keanu.tensor.ndj4.INDArrayExtensions;
import io.improbable.keanu.tensor.ndj4.INDArrayShim;
import io.improbable.keanu.tensor.ndj4.Nd4jFloatingPointTensor;
import io.improbable.keanu.tensor.ndj4.Nd4jTensor;
import io.improbable.keanu.tensor.ndj4.TypedINDArrayFactory;
import io.improbable.keanu.tensor.validate.TensorValidator;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.special.Gamma;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class Nd4jDoubleTensor
extends Nd4jFloatingPointTensor<Double, DoubleTensor>
implements DoubleTensor {
    private static final DataType BUFFER_TYPE;

    public Nd4jDoubleTensor(double[] data, long[] shape) {
        this(TypedINDArrayFactory.create(data, shape));
    }

    public Nd4jDoubleTensor(INDArray tensor) {
        super(tensor);
    }

    @Override
    protected INDArray getTensor(Tensor tensor) {
        return Nd4jDoubleTensor.getAsINDArray(tensor);
    }

    @Override
    protected DoubleTensor create(INDArray tensor) {
        return new Nd4jDoubleTensor(tensor);
    }

    @Override
    protected DoubleTensor set(INDArray tensor) {
        this.tensor = tensor.dataType() == DataType.DOUBLE ? tensor : tensor.castTo(DataType.DOUBLE);
        return this;
    }

    @Override
    protected DoubleTensor getThis() {
        return this;
    }

    public static Nd4jDoubleTensor scalar(double scalarValue) {
        return new Nd4jDoubleTensor(Nd4j.scalar((double)scalarValue));
    }

    public static Nd4jDoubleTensor create(double[] values, long ... shape) {
        long length = TensorShape.getLength(shape);
        if ((long)values.length != length) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(shape) + " does not match buffer size " + values.length);
        }
        return new Nd4jDoubleTensor(values, shape);
    }

    public static Nd4jDoubleTensor create(double value, long ... shape) {
        return new Nd4jDoubleTensor(Nd4j.valueArrayOf((long[])shape, (double)value));
    }

    public static Nd4jDoubleTensor create(double[] values) {
        return Nd4jDoubleTensor.create(values, new long[]{values.length});
    }

    public static Nd4jDoubleTensor ones(long ... shape) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.ones(shape, BUFFER_TYPE));
    }

    public static Nd4jDoubleTensor eye(long n) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.eye(n, BUFFER_TYPE));
    }

    public static Nd4jDoubleTensor zeros(long[] shape) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.zeros(shape, BUFFER_TYPE));
    }

    public static Nd4jDoubleTensor linspace(double start, double end, int numberOfPoints) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.linspace(start, end, numberOfPoints, BUFFER_TYPE));
    }

    public static Nd4jDoubleTensor arange(double start, double end) {
        return new Nd4jDoubleTensor(TypedINDArrayFactory.arange(start, end));
    }

    public static Nd4jDoubleTensor arange(double start, double end, double stepSize) {
        double stepCount = Math.ceil((end - start) / stepSize);
        INDArray arangeWithStep = TypedINDArrayFactory.arange(0.0, stepCount).muli((Number)stepSize).addi((Number)start);
        return new Nd4jDoubleTensor(arangeWithStep);
    }

    static INDArray getAsINDArray(Tensor that) {
        if (that instanceof Nd4jTensor) {
            INDArray array = ((Nd4jTensor)that).getTensor();
            if (array.dataType() == DataType.DOUBLE) {
                return array;
            }
            return array.castTo(DataType.DOUBLE);
        }
        if (that instanceof NumberTensor) {
            return TypedINDArrayFactory.create(((NumberTensor)that).toDouble().asFlatDoubleArray(), that.getShape());
        }
        throw new IllegalArgumentException("Cannot convert " + that.getClass().getSimpleName() + " to double INDArray/");
    }

    @Override
    public int nanArgMax() {
        return this.tensor.argMax(new int[0]).getInt(new int[]{0});
    }

    @Override
    public int argMax() {
        return ((DoubleTensor)((DoubleTensor)this.duplicate()).replaceNaNInPlace(Double.MAX_VALUE)).nanArgMax();
    }

    @Override
    public IntegerTensor nanArgMax(int axis) {
        long[] shape = this.getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        INDArray max = this.tensor.argMax(new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape));
        return new Nd4jIntegerTensor(max);
    }

    @Override
    public IntegerTensor argMax(int axis) {
        return ((DoubleTensor)((DoubleTensor)this.duplicate()).replaceNaNInPlace(Double.MAX_VALUE)).nanArgMax(axis);
    }

    @Override
    public int nanArgMin() {
        return Nd4j.argMin((INDArray)this.tensor, (int[])new int[0]).getInt(new int[]{0});
    }

    @Override
    public int argMin() {
        return ((DoubleTensor)((DoubleTensor)this.duplicate()).replaceNaNInPlace(-1.7976931348623157E308)).nanArgMin();
    }

    @Override
    public IntegerTensor nanArgMin(int axis) {
        long[] shape = this.getShape();
        TensorShapeValidation.checkDimensionExistsInShape(axis, shape);
        return new Nd4jIntegerTensor(Nd4j.argMin((INDArray)this.tensor, (int[])new int[]{axis}).reshape(TensorShape.removeDimension(axis, shape)));
    }

    @Override
    public IntegerTensor argMin(int axis) {
        return ((DoubleTensor)((DoubleTensor)this.duplicate()).replaceNaNInPlace(-1.7976931348623157E308)).nanArgMin(axis);
    }

    @Override
    protected Double getNumber(Number number) {
        return number.doubleValue();
    }

    @Override
    public boolean equalsWithinEpsilon(DoubleTensor o, Double epsilon) {
        if (this == o) {
            return true;
        }
        if (o instanceof Nd4jTensor) {
            return this.tensor.equalsWithEps((Object)((Nd4jTensor)((Object)o)).getTensor(), epsilon.doubleValue());
        }
        if (this.hasSameShapeAs(o)) {
            DoubleTensor difference = o.minus(this);
            return ((DoubleTensor)difference.abs()).lessThan(epsilon).allTrue();
        }
        return false;
    }

    @Override
    public DoubleTensor greaterThanMask(DoubleTensor greaterThanThis) {
        return this.greaterThan(greaterThanThis).toDoubleMask();
    }

    @Override
    public DoubleTensor greaterThanOrEqualToMask(DoubleTensor greaterThanOrEqualToThis) {
        return this.greaterThanOrEqual(greaterThanOrEqualToThis).toDoubleMask();
    }

    @Override
    public DoubleTensor lessThanMask(DoubleTensor lessThanThis) {
        return this.lessThan(lessThanThis).toDoubleMask();
    }

    @Override
    public DoubleTensor lessThanOrEqualToMask(DoubleTensor lessThanOrEqualToThis) {
        return this.lessThanOrEqual(lessThanOrEqualToThis).toDoubleMask();
    }

    @Override
    public DoubleTensor safeLogTimesInPlace(DoubleTensor y) {
        TensorValidator.NAN_CATCHER.validate(this.getThis());
        TensorValidator.NAN_CATCHER.validate(y);
        DoubleTensor result = ((DoubleTensor)this.logInPlace()).timesInPlace(y);
        return TensorValidator.NAN_FIXER.validate(result);
    }

    @Override
    public DoubleTensor logGammaInPlace() {
        return (DoubleTensor)this.applyInPlace(Gamma::logGamma);
    }

    @Override
    public DoubleTensor digammaInPlace() {
        return (DoubleTensor)this.applyInPlace(Gamma::digamma);
    }

    @Override
    public DoubleTensor logAddExp2InPlace(DoubleTensor that) {
        JVMDoubleTensor asJVM = JVMDoubleTensor.create(this.tensor.toDoubleVector(), this.tensor.shape());
        DoubleTensor result = asJVM.logAddExp2InPlace(that);
        return Nd4jDoubleTensor.create(result.asFlatDoubleArray(), result.getShape());
    }

    @Override
    public DoubleTensor logAddExpInPlace(DoubleTensor that) {
        JVMDoubleTensor asJVM = JVMDoubleTensor.create(this.tensor.toDoubleVector(), this.tensor.shape());
        DoubleTensor result = asJVM.logAddExpInPlace(that);
        return Nd4jDoubleTensor.create(result.asFlatDoubleArray(), result.getShape());
    }

    @Override
    public DoubleTensor replaceNaNInPlace(Double value) {
        Nd4j.getExecutioner().exec((ScalarOp)new ReplaceNans(this.tensor, value.doubleValue()));
        return this;
    }

    @Override
    public BooleanTensor isFinite() {
        INDArray result = Nd4j.getExecutioner().exec((Op)new MatchConditionTransform(this.tensor, Nd4j.createUninitialized((DataType)DataType.BOOL, (long[])this.tensor.shape(), (char)this.tensor.ordering()), Conditions.isFinite()));
        return BooleanTensor.create(this.asBoolean(result), this.tensor.shape());
    }

    @Override
    public BooleanTensor isInfinite() {
        INDArray result = this.tensor.isInfinite();
        return BooleanTensor.create(this.asBoolean(result), this.tensor.shape());
    }

    private boolean[] asBoolean(INDArray array) {
        Preconditions.checkArgument((array.dataType() == DataType.BOOL ? 1 : 0) != 0);
        boolean[] buffer = new boolean[Ints.checkedCast((long)array.length())];
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = ((BooleanIndexer)array.data().indexer()).get((long)i);
        }
        return buffer;
    }

    @Override
    public BooleanTensor isNegativeInfinity() {
        INDArray result = Nd4j.getExecutioner().exec((Op)new MatchConditionTransform(this.tensor, Nd4j.createUninitialized((DataType)DataType.BOOL, (long[])this.tensor.shape(), (char)this.tensor.ordering()), Conditions.equals((Number)Double.NEGATIVE_INFINITY)));
        return BooleanTensor.create(this.asBoolean(result), this.tensor.shape());
    }

    @Override
    public BooleanTensor isPositiveInfinity() {
        INDArray result = Nd4j.getExecutioner().exec((Op)new MatchConditionTransform(this.tensor, Nd4j.createUninitialized((DataType)DataType.BOOL, (long[])this.tensor.shape(), (char)this.tensor.ordering()), Conditions.equals((Number)Double.POSITIVE_INFINITY)));
        return BooleanTensor.create(this.asBoolean(result), this.tensor.shape());
    }

    @Override
    public DoubleTensor toDouble() {
        return (DoubleTensor)this.duplicate();
    }

    @Override
    public IntegerTensor toInteger() {
        return new Nd4jIntegerTensor(INDArrayExtensions.castToInteger(this.tensor, true));
    }

    @Override
    public double[] asFlatDoubleArray() {
        return this.tensor.dup().data().asDouble();
    }

    @Override
    public int[] asFlatIntegerArray() {
        return this.tensor.dup().data().asInt();
    }

    public Double[] asFlatArray() {
        return ArrayUtils.toObject((double[])this.asFlatDoubleArray());
    }

    @Override
    public Tensor.FlattenedView<Double> getFlattenedView() {
        return new Nd4jDoubleFlattenedView();
    }

    static {
        INDArrayShim.startNewThreadForNd4j();
        BUFFER_TYPE = DataType.DOUBLE;
    }

    private class Nd4jDoubleFlattenedView
    implements Tensor.FlattenedView<Double> {
        private Nd4jDoubleFlattenedView() {
        }

        @Override
        public long size() {
            return Nd4jDoubleTensor.this.tensor.data().length();
        }

        @Override
        public Double get(long index) {
            return Nd4jDoubleTensor.this.tensor.data().getDouble(index);
        }

        @Override
        public Double getOrScalar(long index) {
            if (Nd4jDoubleTensor.this.tensor.length() == 1L) {
                return this.get(0L);
            }
            return this.get(index);
        }

        @Override
        public void set(long index, Double value) {
            Nd4jDoubleTensor.this.tensor.data().put(index, value.doubleValue());
        }
    }
}

