/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.jvm.Slicer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;

/**
 * A multi-dimensional array of elements of type {@code N}, where {@code T} is the concrete tensor
 * type returned by shape-manipulating operations.
 *
 * @param <N> element type
 * @param <T> concrete tensor type (self type)
 */
public interface Tensor<N, T extends Tensor<N, T>> {

    /** Shape of a rank-0 tensor. Shared mutable array: callers must not modify it. */
    long[] SCALAR_SHAPE = new long[0];

    /** Stride of a rank-0 tensor. Shared mutable array: callers must not modify it. */
    long[] SCALAR_STRIDE = new long[0];

    /** The shape {@code [1, 1]}. Shared mutable array: callers must not modify it. */
    long[] ONE_BY_ONE_SHAPE = new long[] {1L, 1L};

    /**
     * Creates a scalar tensor holding {@code data}, choosing the concrete tensor class from the
     * runtime type of {@code data}.
     *
     * <p>{@code Double}, {@code Integer} and {@code Boolean} map to {@link DoubleTensor},
     * {@link IntegerTensor} and {@link BooleanTensor}. Anything else maps to {@link GenericTensor}.
     *
     * @param data the single value
     * @return a scalar tensor, unchecked-cast to {@code TENSOR}
     */
    @SuppressWarnings("unchecked")
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR scalar(DATA data) {
        if (data instanceof Double) {
            return (TENSOR) DoubleTensor.scalar((Double) data);
        }
        if (data instanceof Integer) {
            return (TENSOR) IntegerTensor.scalar((Integer) data);
        }
        if (data instanceof Boolean) {
            return (TENSOR) BooleanTensor.scalar((Boolean) data);
        }
        return (TENSOR) GenericTensor.scalar(data);
    }

    /**
     * Creates a tensor of the given shape with every element set to {@code data}. The concrete
     * tensor class is chosen from the runtime type of {@code data}, as in {@link #scalar(Object)}.
     *
     * @param data  the fill value
     * @param shape the shape of the result
     * @return the filled tensor, unchecked-cast to {@code TENSOR}
     */
    @SuppressWarnings("unchecked")
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR createFilled(DATA data, long[] shape) {
        // Explicit primitive casts select the primitive-valued overloads of create(...).
        if (data instanceof Double) {
            return (TENSOR) DoubleTensor.create((double) (Double) data, shape);
        }
        if (data instanceof Integer) {
            return (TENSOR) IntegerTensor.create((int) (Integer) data, shape);
        }
        if (data instanceof Boolean) {
            return (TENSOR) BooleanTensor.create((boolean) (Boolean) data, shape);
        }
        return (TENSOR) GenericTensor.createFilled(data, shape);
    }

    /**
     * Creates a tensor from a flat array of values and a shape. The concrete tensor class is
     * chosen from the runtime array type. Boxed {@code Double[]}, {@code Integer[]} and
     * {@code Boolean[]} are unboxed via {@link ArrayUtils#toPrimitive}.
     *
     * @param data  flat element values
     * @param shape the shape of the result
     * @return the tensor, unchecked-cast to {@code TENSOR}
     */
    @SuppressWarnings("unchecked")
    static <DATA, TENSOR extends Tensor<DATA, TENSOR>> TENSOR create(DATA[] data, long[] shape) {
        if (data instanceof Double[]) {
            return (TENSOR) DoubleTensor.create(ArrayUtils.toPrimitive((Double[]) data), shape);
        }
        if (data instanceof Integer[]) {
            return (TENSOR) IntegerTensor.create(ArrayUtils.toPrimitive((Integer[]) data), shape);
        }
        if (data instanceof Boolean[]) {
            return (TENSOR) BooleanTensor.create(ArrayUtils.toPrimitive((Boolean[]) data), shape);
        }
        return (TENSOR) GenericTensor.create(data, shape);
    }

    /**
     * Compares two tensors element by element.
     *
     * <p>Elements are compared with {@link Objects#equals}, so {@code null} elements are handled
     * rather than throwing.
     *
     * @param a first tensor
     * @param b second tensor
     * @return a boolean tensor of the same shape as {@code a}, true where elements are equal
     * @throws IllegalArgumentException if the shapes differ
     */
    static BooleanTensor elementwiseEquals(Tensor<?, ?> a, Tensor<?, ?> b) {
        if (!a.hasSameShapeAs(b)) {
            throw new IllegalArgumentException(String.format(
                "Cannot compare tensors of different shapes %s and %s",
                Arrays.toString(a.getShape()), Arrays.toString(b.getShape())));
        }
        // Static context: the interface type parameter N is not in scope here, so use Object[].
        Object[] aArray = a.asFlatArray();
        Object[] bArray = b.asFlatArray();
        boolean[] equality = new boolean[aArray.length];
        for (int i = 0; i < aArray.length; ++i) {
            equality[i] = Objects.equals(aArray[i], bArray[i]);
        }
        long[] shape = a.getShape();
        return BooleanTensor.create(equality, Arrays.copyOf(shape, shape.length));
    }

    /** @return the number of dimensions (0 for a scalar, 1 for a vector, 2 for a matrix) */
    int getRank();

    /** @return the length of each dimension */
    long[] getShape();

    /** @return the stride of each dimension, as used by {@link TensorShape#getFlatIndex} */
    long[] getStride();

    /** @return the total number of elements */
    long getLength();

    /**
     * Reads one element.
     *
     * <p>A single index is passed straight to the flattened view, i.e. it is a flat index.
     * Otherwise the index is converted to a flat index using this tensor's shape and stride.
     *
     * @param index either one flat index or one index per dimension
     * @return the element
     */
    default N getValue(long... index) {
        if (index.length == 1) {
            return getFlattenedView().get(index[0]);
        }
        return getFlattenedView().get(TensorShape.getFlatIndex(getShape(), getStride(), index));
    }

    /**
     * Selects elements using a boolean tensor. Exact semantics are defined by implementations
     * (presumably a boolean mask; verify against implementors).
     */
    T get(BooleanTensor booleanIndex);

    /**
     * Writes one element. Index interpretation matches {@link #getValue(long...)}.
     *
     * @param value the new element value
     * @param index either one flat index or one index per dimension
     */
    default void setValue(N value, long... index) {
        if (index.length == 1) {
            getFlattenedView().set(index[0], value);
        } else {
            getFlattenedView().set(TensorShape.getFlatIndex(getShape(), getStride(), index), value);
        }
    }

    /**
     * Returns the sole element of a length-one tensor.
     *
     * @return the single element
     * @throws IllegalArgumentException if the tensor does not have exactly one element
     */
    default N scalar() {
        // Reject length 0 too; otherwise getValue(0) fails with an unrelated indexing error.
        if (getLength() != 1L) {
            throw new IllegalArgumentException("Not a scalar");
        }
        return getValue(0L);
    }

    /** @return a copy of this tensor */
    T duplicate();

    /**
     * Takes a single index along one dimension (semantics defined by implementations).
     *
     * @param dimension the dimension to slice
     * @param index     the index within that dimension
     */
    T slice(int dimension, long index);

    /**
     * Slices using a string specification parsed by {@link Slicer#fromString(String)}.
     *
     * @param sliceArg the slice specification
     * @return the sliced tensor
     */
    default T slice(String sliceArg) {
        return slice(Slicer.fromString(sliceArg));
    }

    /** Slices using a {@link Slicer} (semantics defined by implementations). */
    T slice(Slicer slicer);

    /** Takes the element(s) at the given index (semantics defined by implementations). */
    T take(long... index);

    /** Splits along a dimension at the given indices (semantics defined by implementations). */
    List<T> split(int dimension, long... splitAtIndices);

    /**
     * Slices this tensor at each index in {@code [indexStart, indexEnd)} along {@code dimension}.
     *
     * @param dimension  the dimension to slice along
     * @param indexStart first index, inclusive
     * @param indexEnd   last index, exclusive
     * @return one slice per index, in order; empty if {@code indexEnd <= indexStart}
     */
    default List<T> sliceAlongDimension(int dimension, long indexStart, long indexEnd) {
        List<T> slicedTensors = new ArrayList<>();
        for (long i = indexStart; i < indexEnd; ++i) {
            slicedTensors.add(slice(dimension, i));
        }
        return slicedTensors;
    }

    /** Diagonal operation (semantics defined by implementations). */
    T diag();

    /**
     * Swaps the two axes of a rank-2 tensor.
     *
     * @return {@code permute(1, 0)}
     * @throws IllegalArgumentException if the rank is not 2
     */
    default T transpose() {
        Preconditions.checkArgument(getRank() == 2,
            "Can only transpose rank 2. Use permute(...) for higher rank transpose.");
        return permute(1, 0);
    }

    /** @return all elements as a flat array */
    N[] asFlatArray();

    /** Reshapes to {@code newShape} (semantics defined by implementations). */
    T reshape(long... newShape);

    /**
     * Removes every dimension of length 1.
     *
     * <p>Dimensions of length 0 are kept, so the element count is unchanged and the result is a
     * valid reshape.
     *
     * @return the reshaped tensor
     */
    default T squeeze() {
        List<Long> squeezedShape = new ArrayList<>();
        for (long length : getShape()) {
            if (length != 1L) {
                squeezedShape.add(length);
            }
        }
        return reshape(Longs.toArray(squeezedShape));
    }

    /**
     * Inserts a new dimension of length 1 at position {@code axis} of the shape.
     *
     * @param axis insertion position in the shape array (passed directly to
     *             {@link ArrayUtils#insert}, so negative values are not normalized)
     * @return the reshaped tensor
     */
    default T expandDims(int axis) {
        long[] shape = getShape();
        return reshape(ArrayUtils.insert(axis, shape, new long[] {1L}));
    }

    /**
     * Moves dimension {@code source} to position {@code destination}, keeping the relative order
     * of the remaining dimensions.
     *
     * @param source      dimension to move; normalized via {@link TensorShape#getAbsoluteDimension}
     * @param destination target position; normalized the same way
     * @return the permuted tensor
     */
    default T moveAxis(int source, int destination) {
        int[] dimensionRange = TensorShape.dimensionRange(0, getRank());
        int from = TensorShape.getAbsoluteDimension(source, dimensionRange.length);
        int to = TensorShape.getAbsoluteDimension(destination, dimensionRange.length);
        int[] withoutSource = ArrayUtils.remove(dimensionRange, from);
        int[] rearrange = ArrayUtils.insert(to, withoutSource, new int[] {from});
        return permute(rearrange);
    }

    /**
     * Exchanges two dimensions.
     *
     * @param axis1 first dimension; normalized via {@link TensorShape#getAbsoluteDimension}
     * @param axis2 second dimension; normalized the same way
     * @return the permuted tensor
     */
    default T swapAxis(int axis1, int axis2) {
        int[] rearrange = TensorShape.dimensionRange(0, getRank());
        int first = TensorShape.getAbsoluteDimension(axis1, rearrange.length);
        int second = TensorShape.getAbsoluteDimension(axis2, rearrange.length);
        int temp = rearrange[first];
        rearrange[first] = second;
        rearrange[second] = temp;
        return permute(rearrange);
    }

    /** Reorders dimensions (semantics defined by implementations). */
    T permute(int... rearrange);

    /** Broadcasts to the given shape (semantics defined by implementations). */
    T broadcast(long... toShape);

    /** @return a flat-indexed view over this tensor's elements */
    FlattenedView<N> getFlattenedView();

    /** @return the elements as a fixed-size list backed by {@link #asFlatArray()} */
    default List<N> asFlatList() {
        return Arrays.asList(asFlatArray());
    }

    /** @return true if the tensor has exactly one element (of any rank) */
    default boolean isLengthOne() {
        return getLength() == 1L;
    }

    /** @return true if the rank is 0 */
    default boolean isScalar() {
        return getRank() == 0;
    }

    /** @return true if the rank is 1 */
    default boolean isVector() {
        return getRank() == 1;
    }

    /** @return true if the rank is 2 */
    default boolean isMatrix() {
        return getRank() == 2;
    }

    /**
     * @param that the tensor to compare shapes with
     * @return true if both shapes are element-wise equal
     */
    default boolean hasSameShapeAs(Tensor<?, ?> that) {
        return hasSameShapeAs(that.getShape());
    }

    /**
     * @param shape the shape to compare with
     * @return true if this tensor's shape equals {@code shape}
     */
    default boolean hasSameShapeAs(long[] shape) {
        return Arrays.equals(getShape(), shape);
    }

    /**
     * Element-wise equality against another tensor of the same shape.
     *
     * @see #elementwiseEquals(Tensor, Tensor)
     */
    default BooleanTensor elementwiseEquals(Tensor<?, ?> that) {
        return Tensor.elementwiseEquals(this, that);
    }

    /** Element-wise equality against a single value (semantics defined by implementations). */
    BooleanTensor elementwiseEquals(N value);

    /** Flat, index-addressable access to a tensor's elements. */
    interface FlattenedView<N> {
        /** @return the number of addressable elements */
        long size();

        /** @return the element at flat index {@code index} */
        N get(long index);

        /** Reads an element; exact behaviour for scalars is defined by implementations. */
        N getOrScalar(long index);

        /** Writes {@code value} at flat index {@code index}. */
        void set(long index, N value);
    }
}

