/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.intgr;

import io.improbable.keanu.tensor.NumberTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.dbl.Nd4jDoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.ndj4.INDArrayShim;
import io.improbable.keanu.tensor.ndj4.Nd4jFixedPointTensor;
import io.improbable.keanu.tensor.ndj4.Nd4jTensor;
import io.improbable.keanu.tensor.ndj4.TypedINDArrayFactory;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link IntegerTensor} implementation backed by an ND4J {@link INDArray}.
 *
 * <p>Most operations are inherited from {@link Nd4jFixedPointTensor}. This class supplies the
 * integer-specific factories, conversions, mask helpers and a flattened element view. Arrays
 * passed through {@link #set(INDArray)} or {@link #getAsINDArray(Tensor)} are cast to the
 * integer buffer type when they are not already of that type.
 */
public class Nd4jIntegerTensor
extends Nd4jFixedPointTensor<Integer, IntegerTensor>
implements IntegerTensor {

    static {
        // Runs once on class initialisation, before BUFFER_TYPE below is assigned.
        // NOTE(review): presumably prepares ND4J for use; see INDArrayShim.
        INDArrayShim.startNewThreadForNd4j();
    }

    /** ND4J data type used for arrays this class creates via TypedINDArrayFactory and casts inputs to. */
    private static final DataType BUFFER_TYPE = DataType.INT;

    /**
     * Creates a tensor from flat values and a shape.
     *
     * @param data  the element values, in the flat order expected by {@code TypedINDArrayFactory.create}
     * @param shape the tensor shape
     */
    public Nd4jIntegerTensor(int[] data, long[] shape) {
        this(TypedINDArrayFactory.create(data, shape));
    }

    /**
     * Wraps an existing ND4J array; the array is handed to the superclass constructor as-is.
     *
     * @param tensor the backing array
     */
    public Nd4jIntegerTensor(INDArray tensor) {
        super(tensor);
    }

    /** Converts another integer tensor into an integer-typed INDArray (see {@link #getAsINDArray}). */
    @Override
    protected INDArray getTensor(Tensor<Integer, ?> tensor) {
        return getAsINDArray(tensor);
    }

    /** Creates a scalar tensor holding {@code scalarValue}. */
    public static Nd4jIntegerTensor scalar(int scalarValue) {
        return new Nd4jIntegerTensor(Nd4j.scalar(scalarValue));
    }

    /** Creates a tensor from flat {@code values} laid out in {@code shape}. */
    public static Nd4jIntegerTensor create(int[] values, long[] shape) {
        return new Nd4jIntegerTensor(values, shape);
    }

    /** Creates a tensor of the given {@code shape} with every element set to {@code value}. */
    public static Nd4jIntegerTensor create(int value, long[] shape) {
        return new Nd4jIntegerTensor(Nd4j.valueArrayOf(shape, value));
    }

    /** Creates a tensor of the given shape filled with ones. */
    public static Nd4jIntegerTensor ones(long[] shape) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.ones(shape, BUFFER_TYPE));
    }

    /** Creates an {@code n x n} identity tensor. */
    public static Nd4jIntegerTensor eye(long n) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.eye(n, BUFFER_TYPE));
    }

    /** Creates a tensor of the given shape filled with zeros. */
    public static Nd4jIntegerTensor zeros(long[] shape) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.zeros(shape, BUFFER_TYPE));
    }

    /** Creates a tensor of the consecutive values produced by {@code TypedINDArrayFactory.arange(start, end)}. */
    public static Nd4jIntegerTensor arange(int start, int end) {
        return new Nd4jIntegerTensor(TypedINDArrayFactory.arange(start, end));
    }

    /**
     * Creates a tensor of the values {@code start, start + stepSize, start + 2 * stepSize, ...}
     * up to but excluding {@code end}. A partial final step is included, e.g.
     * {@code arange(0, 5, 2)} gives {@code [0, 2, 4]}. If {@code stepSize} points away from
     * {@code end}, the step count is zero.
     *
     * <p>NOTE(review): unlike {@link #arange(int, int)} this overload returns an
     * {@link Nd4jDoubleTensor}; the return type is kept unchanged for caller compatibility.
     *
     * @param start    first value (inclusive)
     * @param end      bound (exclusive)
     * @param stepSize distance between consecutive values; must be non-zero
     * @return the stepped range
     * @throws ArithmeticException if {@code stepSize} is zero or the step count does not fit in an int
     */
    public static Nd4jDoubleTensor arange(int start, int end, int stepSize) {
        // Ceiling division (not truncation) so the last partial step is kept. The arithmetic is in
        // long so that end - start cannot overflow. Math.floorDiv throws on a zero divisor, as the
        // former integer division did.
        long span = (long) end - start;
        long steps = -Math.floorDiv(-span, (long) stepSize);
        int stepCount = Math.toIntExact(Math.max(0L, steps));
        INDArray arangeWithStep = TypedINDArrayFactory.arange(0, stepCount).muli(stepSize).addi(start);
        return new Nd4jDoubleTensor(arangeWithStep);
    }

    /**
     * Returns an integer-typed INDArray with the contents of {@code that}.
     *
     * <p>ND4J-backed tensors have their array returned directly if it is already integer-typed,
     * otherwise the array is cast. Other {@link NumberTensor}s are converted via
     * {@code toInteger().asFlatIntegerArray()} and rebuilt with the same shape.
     *
     * @throws IllegalArgumentException if {@code that} is neither an Nd4jTensor nor a NumberTensor
     */
    static INDArray getAsINDArray(Tensor<Integer, ?> that) {
        if (that instanceof Nd4jTensor) {
            INDArray array = ((Nd4jTensor) that).getTensor();
            return array.dataType() == BUFFER_TYPE ? array : array.castTo(BUFFER_TYPE);
        }
        if (that instanceof NumberTensor) {
            return TypedINDArrayFactory.create(((NumberTensor) that).toInteger().asFlatIntegerArray(), that.getShape());
        }
        throw new IllegalArgumentException("Cannot convert " + that.getClass().getSimpleName() + " to integer INDArray");
    }

    /**
     * Checks element-wise equality with {@code o}, tolerating differences within {@code epsilon}.
     *
     * <p>ND4J-backed operands are compared with {@code INDArray.equalsWithEps}. For other operands of
     * the same shape, {@code |o - this| < epsilon} must hold everywhere. A {@code null} operand or a
     * shape mismatch yields {@code false}.
     */
    @Override
    public boolean equalsWithinEpsilon(IntegerTensor o, Integer epsilon) {
        if (this == o) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (o instanceof Nd4jTensor) {
            return this.tensor.equalsWithEps(((Nd4jTensor) o).getTensor(), (double) epsilon.intValue());
        }
        if (this.hasSameShapeAs(o)) {
            IntegerTensor difference = o.minus(this);
            return ((IntegerTensor) difference.abs()).lessThan(epsilon).allTrue();
        }
        return false;
    }

    /** Narrows an arbitrary {@link Number} to this tensor's element type via {@code intValue()}. */
    @Override
    protected Integer getNumber(Number number) {
        return number.intValue();
    }

    /** Replaces the backing array with one of the same shape filled with {@code value}. */
    @Override
    public IntegerTensor setAllInPlace(Integer value) {
        this.tensor = Nd4j.valueArrayOf(this.getShape(), (int) value);
        return this;
    }

    /** Not supported for integer tensors. */
    @Override
    public IntegerTensor safeLogTimesInPlace(IntegerTensor y) {
        throw new NotImplementedException("");
    }

    /** Returns an integer mask of where {@code this > greaterThanThis}. */
    @Override
    public IntegerTensor greaterThanMask(IntegerTensor greaterThanThis) {
        return this.greaterThan(greaterThanThis).toIntegerMask();
    }

    /** Returns an integer mask of where {@code this >= greaterThanOrEqualToThis}. */
    @Override
    public IntegerTensor greaterThanOrEqualToMask(IntegerTensor greaterThanOrEqualToThis) {
        return this.greaterThanOrEqual(greaterThanOrEqualToThis).toIntegerMask();
    }

    /** Returns an integer mask of where {@code this < lessThanThis}. */
    @Override
    public IntegerTensor lessThanMask(IntegerTensor lessThanThis) {
        return this.lessThan(lessThanThis).toIntegerMask();
    }

    /** Returns an integer mask of where {@code this <= lessThanOrEqualToThis}. */
    @Override
    public IntegerTensor lessThanOrEqualToMask(IntegerTensor lessThanOrEqualToThis) {
        return this.lessThanOrEqual(lessThanOrEqualToThis).toIntegerMask();
    }

    /** Returns a double tensor backed by a DOUBLE-cast version of this tensor's array. */
    @Override
    public DoubleTensor toDouble() {
        return new Nd4jDoubleTensor(this.tensor.castTo(DataType.DOUBLE));
    }

    /** Returns {@code duplicate()} of this tensor. */
    @Override
    public IntegerTensor toInteger() {
        return (IntegerTensor) this.duplicate();
    }

    /** Wraps {@code tensor} in a new instance of this class. */
    @Override
    protected IntegerTensor create(INDArray tensor) {
        return new Nd4jIntegerTensor(tensor);
    }

    /** Replaces the backing array, casting it to the integer buffer type if necessary. */
    @Override
    protected IntegerTensor set(INDArray tensor) {
        this.tensor = tensor.dataType() == BUFFER_TYPE ? tensor : tensor.castTo(BUFFER_TYPE);
        return this;
    }

    @Override
    protected IntegerTensor getThis() {
        return this;
    }

    /** Copies the elements into a new double array (via a duplicate of the backing array). */
    @Override
    public double[] asFlatDoubleArray() {
        return this.tensor.dup().data().asDouble();
    }

    /** Copies the elements into a new int array (via a duplicate of the backing array). */
    @Override
    public int[] asFlatIntegerArray() {
        return this.tensor.dup().data().asInt();
    }

    /** Copies the elements into a new boxed {@code Integer} array. */
    public Integer[] asFlatArray() {
        return ArrayUtils.toObject(this.asFlatIntegerArray());
    }

    /** Returns a live view that reads and writes the backing data buffer by flat index. */
    @Override
    public Tensor.FlattenedView<Integer> getFlattenedView() {
        return new Nd4jIntegerFlattenedView();
    }

    /**
     * Flat-index accessor over the enclosing tensor's data buffer. Reads and writes go straight
     * to {@code tensor.data()}, so changes are visible in the tensor.
     */
    private class Nd4jIntegerFlattenedView
    implements Tensor.FlattenedView<Integer> {
        private Nd4jIntegerFlattenedView() {
        }

        /** Returns the length of the backing data buffer. */
        @Override
        public long size() {
            return Nd4jIntegerTensor.this.tensor.data().length();
        }

        @Override
        public Integer get(long index) {
            return Nd4jIntegerTensor.this.tensor.data().getInt(index);
        }

        /** Returns element 0 for single-element tensors regardless of {@code index}, otherwise {@code get(index)}. */
        @Override
        public Integer getOrScalar(long index) {
            if (Nd4jIntegerTensor.this.tensor.length() == 1L) {
                return this.get(0L);
            }
            return this.get(index);
        }

        @Override
        public void set(long index, Integer value) {
            Nd4jIntegerTensor.this.tensor.data().put(index, value.intValue());
        }
    }
}

