/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.intgr;

import com.google.common.primitives.Ints;
import io.improbable.keanu.kotlin.IntegerOperators;
import io.improbable.keanu.tensor.FixedPointTensor;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.intgr.Nd4jIntegerTensor;
import io.improbable.keanu.tensor.intgr.ScalarIntegerTensor;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A tensor whose elements are {@link Integer} values.
 *
 * <p>Several of the static factories return a {@link ScalarIntegerTensor} when the requested
 * shape equals {@link Tensor#SCALAR_SHAPE} and otherwise delegate to {@link Nd4jIntegerTensor}.
 * The default methods forward the primitive {@code int} overloads to the boxed {@link Integer}
 * overloads of the same name.
 */
public interface IntegerTensor
        extends FixedPointTensor<Integer, IntegerTensor>, IntegerOperators<IntegerTensor> {

    /**
     * Creates a tensor of the given shape from a single value.
     *
     * @param value the value handed to the chosen implementation
     * @param shape the requested shape
     * @return a {@link ScalarIntegerTensor} holding {@code value} when {@code shape} equals
     *     {@link Tensor#SCALAR_SHAPE}; otherwise the result of
     *     {@code Nd4jIntegerTensor.create(value, shape)}
     */
    static IntegerTensor create(int value, long[] shape) {
        if (Arrays.equals(shape, Tensor.SCALAR_SHAPE)) {
            return new ScalarIntegerTensor(value);
        }
        return Nd4jIntegerTensor.create(value, shape);
    }

    /**
     * Creates a tensor from an array of values and a shape.
     *
     * @param values the element values
     * @param shape the requested shape
     * @return a {@link ScalarIntegerTensor} holding {@code values[0]} when {@code shape} equals
     *     {@link Tensor#SCALAR_SHAPE} and exactly one value was supplied; otherwise the result
     *     of {@code Nd4jIntegerTensor.create(values, shape)}
     */
    static IntegerTensor create(int[] values, long... shape) {
        boolean singleScalar = values.length == 1 && Arrays.equals(shape, Tensor.SCALAR_SHAPE);
        if (singleScalar) {
            return new ScalarIntegerTensor(values[0]);
        }
        return Nd4jIntegerTensor.create(values, shape);
    }

    /**
     * Creates a tensor from the given values using the one-dimensional shape
     * {@code [values.length]}.
     *
     * @param values the element values
     * @return the tensor produced by {@link #create(int[], long...)}
     */
    static IntegerTensor create(int... values) {
        long[] vectorShape = {values.length};
        return IntegerTensor.create(values, vectorShape);
    }

    /**
     * Creates a tensor from {@code long} values, narrowing each one to {@code int}.
     *
     * @param values the element values; each must fit in an {@code int}
     * @param shape the requested shape
     * @return the tensor produced by {@link #create(int[], long...)}
     * @throws IllegalArgumentException if a value is outside the {@code int} range (raised by
     *     {@link Ints#checkedCast(long)})
     */
    static IntegerTensor create(long[] values, long... shape) {
        int[] narrowed = Arrays.stream(values).mapToInt(Ints::checkedCast).toArray();
        return IntegerTensor.create(narrowed, shape);
    }

    /**
     * Creates a tensor from {@code long} values using the one-dimensional shape
     * {@code [values.length]}.
     *
     * @param values the element values; each must fit in an {@code int}
     * @return the tensor produced by {@link #create(long[], long...)}
     * @throws IllegalArgumentException if a value is outside the {@code int} range
     */
    static IntegerTensor create(long... values) {
        long[] vectorShape = {values.length};
        return IntegerTensor.create(values, vectorShape);
    }

    /**
     * Creates a tensor of ones.
     *
     * @param shape the requested shape
     * @return a {@link ScalarIntegerTensor} holding {@code 1} when {@code shape} equals
     *     {@link Tensor#SCALAR_SHAPE}; otherwise the result of
     *     {@code Nd4jIntegerTensor.ones(shape)}
     */
    static IntegerTensor ones(long... shape) {
        if (Arrays.equals(shape, Tensor.SCALAR_SHAPE)) {
            return new ScalarIntegerTensor(1);
        }
        return Nd4jIntegerTensor.ones(shape);
    }

    /**
     * Creates an "eye" tensor of size {@code n} (presumably an n-by-n identity matrix; the
     * exact layout is defined by {@code Nd4jIntegerTensor.eye}).
     *
     * @param n the size passed to the backing implementation
     * @return a {@link ScalarIntegerTensor} holding {@code 1} when {@code n == 1}; otherwise
     *     the result of {@code Nd4jIntegerTensor.eye(n)}
     */
    static IntegerTensor eye(int n) {
        if (n == 1) {
            return new ScalarIntegerTensor(1);
        }
        return Nd4jIntegerTensor.eye((long) n);
    }

    /**
     * Creates a tensor of zeros.
     *
     * @param shape the requested shape
     * @return a {@link ScalarIntegerTensor} holding {@code 0} when {@code shape} equals
     *     {@link Tensor#SCALAR_SHAPE}; otherwise the result of
     *     {@code Nd4jIntegerTensor.zeros(shape)}
     */
    static IntegerTensor zeros(long... shape) {
        if (Arrays.equals(shape, Tensor.SCALAR_SHAPE)) {
            return new ScalarIntegerTensor(0);
        }
        return Nd4jIntegerTensor.zeros(shape);
    }

    /**
     * Wraps a single value in a {@link ScalarIntegerTensor}.
     *
     * @param scalarValue the value to wrap
     * @return a new {@link ScalarIntegerTensor} holding {@code scalarValue}
     */
    static IntegerTensor scalar(int scalarValue) {
        return new ScalarIntegerTensor(scalarValue);
    }

    /**
     * Creates a tensor from the given values using the one-dimensional shape
     * {@code [values.length]}. Equivalent to {@link #create(int...)}.
     *
     * @param values the element values
     * @return the tensor produced by {@link #create(int...)}
     */
    static IntegerTensor vector(int... values) {
        return IntegerTensor.create(values);
    }

    /**
     * Stacks tensors along a new dimension.
     *
     * <p>The shape of the first tensor gets a length-1 entry inserted at the resolved
     * dimension; every input is reshaped to that shape and the results are concatenated along
     * the same dimension.
     *
     * @param dimension the dimension at which to stack; it is resolved against a rank one
     *     larger than the first tensor's rank via {@code TensorShape.getAbsoluteDimension}
     *     (presumably so that negative values count from the end)
     * @param toStack the tensors to stack; must contain at least one tensor
     * @return the concatenation of the reshaped tensors
     * @throws IllegalArgumentException if {@code toStack} is empty
     */
    static IntegerTensor stack(int dimension, IntegerTensor... toStack) {
        // Fail with a descriptive error instead of an ArrayIndexOutOfBoundsException on toStack[0].
        if (toStack.length == 0) {
            throw new IllegalArgumentException("stack requires at least one tensor");
        }

        long[] firstShape = toStack[0].getShape();
        int absoluteDimension = TensorShape.getAbsoluteDimension(dimension, firstShape.length + 1);
        long[] stackedShape = ArrayUtils.insert(absoluteDimension, firstShape, 1L);

        IntegerTensor[] reshaped = new IntegerTensor[toStack.length];
        for (int i = 0; i < toStack.length; i++) {
            reshaped[i] = (IntegerTensor) toStack[i].reshape(stackedShape);
        }
        return IntegerTensor.concat(absoluteDimension, reshaped);
    }

    /**
     * Concatenates tensors along dimension {@code 0}.
     *
     * @param toConcat the tensors to concatenate
     * @return the result of {@link #concat(int, IntegerTensor...)} with dimension {@code 0}
     */
    static IntegerTensor concat(IntegerTensor... toConcat) {
        return IntegerTensor.concat(0, toConcat);
    }

    /**
     * Concatenates tensors along the given dimension using {@code Nd4j.concat}.
     *
     * <p>Each input is converted with {@code Nd4jIntegerTensor.getAsINDArray} and duplicated
     * before being passed on. The dimension is handed to {@code Nd4j.concat} as given.
     *
     * @param dimension the dimension along which to concatenate
     * @param toConcat the tensors to concatenate
     * @return a new {@link Nd4jIntegerTensor} wrapping the concatenated array
     */
    static IntegerTensor concat(int dimension, IntegerTensor... toConcat) {
        INDArray[] arrays = new INDArray[toConcat.length];
        for (int i = 0; i < toConcat.length; i++) {
            INDArray copy = Nd4jIntegerTensor.getAsINDArray(toConcat[i]).dup();
            // A rank-0 array is reshaped to a single-element rank-1 array before concatenation.
            arrays[i] = copy.shape().length == 0 ? copy.reshape(new long[]{1L}) : copy;
        }
        return new Nd4jIntegerTensor(Nd4j.concat(dimension, arrays));
    }

    /**
     * Applies {@code minInPlace(b)} to a duplicate of {@code a}.
     *
     * @param a the tensor that is duplicated and used as the receiver
     * @param b the tensor passed to {@code minInPlace}
     * @return the result of {@code minInPlace} on the duplicate of {@code a}
     */
    static IntegerTensor min(IntegerTensor a, IntegerTensor b) {
        IntegerTensor copy = (IntegerTensor) a.duplicate();
        return copy.minInPlace(b);
    }

    /**
     * Applies {@code maxInPlace(b)} to a duplicate of {@code a}.
     *
     * @param a the tensor that is duplicated and used as the receiver
     * @param b the tensor passed to {@code maxInPlace}
     * @return the result of {@code maxInPlace} on the duplicate of {@code a}
     */
    static IntegerTensor max(IntegerTensor a, IntegerTensor b) {
        IntegerTensor copy = (IntegerTensor) a.duplicate();
        return copy.maxInPlace(b);
    }

    /** Forwards to the boxed {@code plus(Integer)} overload. */
    @Override
    default IntegerTensor plus(int value) {
        // Boxing explicitly selects the Integer overload rather than recursing into this one.
        return (IntegerTensor) this.plus(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code minus(Integer)} overload. */
    @Override
    default IntegerTensor minus(int value) {
        return (IntegerTensor) this.minus(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code reverseMinus(Integer)} overload. */
    @Override
    default IntegerTensor reverseMinus(int value) {
        return (IntegerTensor) this.reverseMinus(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code times(Integer)} overload. */
    @Override
    default IntegerTensor times(int value) {
        return (IntegerTensor) this.times(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code div(Integer)} overload. */
    @Override
    default IntegerTensor div(int value) {
        return (IntegerTensor) this.div(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code reverseDiv(Integer)} overload. */
    @Override
    default IntegerTensor reverseDiv(int value) {
        return (IntegerTensor) this.reverseDiv(Integer.valueOf(value));
    }

    /** Forwards to the boxed {@code pow(Integer)} overload. */
    @Override
    default IntegerTensor pow(int exponent) {
        return (IntegerTensor) this.pow(Integer.valueOf(exponent));
    }
}

