/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.jvm;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.tensor.jvm.DimensionIndexMapper;
import io.improbable.keanu.tensor.jvm.IndexMapper;
import io.improbable.keanu.tensor.jvm.JVMTensorBroadcast;
import io.improbable.keanu.tensor.jvm.ResultWrapper;
import io.improbable.keanu.tensor.jvm.Slicer;
import io.improbable.keanu.tensor.jvm.SlicerIndexMapper;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;

/**
 * Base implementation of {@link Tensor} for tensors whose elements are held in a JVM
 * primitive-array wrapper buffer ({@code B}) together with an explicit shape and stride.
 *
 * <p>Most operations are exposed twice: as a {@code public static} helper that works on a raw
 * buffer/shape/stride triple and returns a {@link ResultWrapper}, and as an instance method that
 * converts that result back into a {@code TENSOR} through {@link #create}. Several helpers
 * ({@code diag}, {@code permute}, {@code concat}) build their results with
 * {@link TensorShape#getRowFirstStride}.
 *
 * @param <T>      boxed element type
 * @param <TENSOR> concrete tensor type returned by operations
 * @param <B>      buffer type that stores the elements
 */
public abstract class JVMTensor<T, TENSOR extends Tensor<T, TENSOR>, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>>
implements Tensor<T, TENSOR> {
    protected B buffer;
    protected long[] shape;
    protected long[] stride;

    /**
     * Wraps an existing buffer. The shape and stride arrays are stored as-is (not copied).
     */
    protected JVMTensor(B buffer, long[] shape, long[] stride) {
        this.buffer = buffer;
        this.shape = shape;
        this.stride = stride;
    }

    /** @return the number of dimensions, i.e. the length of the shape array */
    @Override
    public int getRank() {
        return this.shape.length;
    }

    /** @return a defensive copy of the shape */
    @Override
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    /** @return a defensive copy of the stride */
    @Override
    public long[] getStride() {
        return Arrays.copyOf(this.stride, this.stride.length);
    }

    /** @return the number of elements in the backing buffer */
    @Override
    public long getLength() {
        return this.buffer.getLength();
    }

    /**
     * Selects the elements whose flattened position is {@code true} in {@code booleanIndex}.
     *
     * <p>NOTE(review): the mask length is not checked against this tensor's length; a longer mask
     * with a {@code true} past the end will fail inside the buffer access.
     *
     * @param booleanIndex mask read through its flattened view
     * @return a rank-1 tensor of the selected elements, in flattened order
     */
    @Override
    public TENSOR get(BooleanTensor booleanIndex) {
        List<Long> selected = new ArrayList<>();
        Tensor.FlattenedView<Boolean> mask = booleanIndex.getFlattenedView();
        for (long i = 0L; i < booleanIndex.getLength(); ++i) {
            if (mask.get(i)) {
                selected.add(i);
            }
        }
        B newBuffer = this.getFactory().createNew((long) selected.size());
        for (int i = 0; i < selected.size(); ++i) {
            long sourceIndex = selected.get(i);
            newBuffer.set(this.buffer.get(sourceIndex), i);
        }
        return this.create(newBuffer, new long[]{newBuffer.getLength()}, new long[]{1L});
    }

    /** Instance form of {@link #diag(int, long[], JVMBuffer.PrimitiveArrayWrapper, JVMBuffer.ArrayWrapperFactory)}. */
    @Override
    public TENSOR diag() {
        return this.createFromResultWrapper(JVMTensor.diag(this.shape.length, this.shape, this.buffer, this.getFactory()));
    }

    /**
     * Builds a diagonal matrix from a vector, or extracts the diagonal of a square matrix.
     *
     * <p>For a rank-1 input of length n the result is an n x n buffer whose diagonal holds the
     * vector; the off-diagonal entries are whatever {@code factory.createNew} initialises them to
     * (presumably zero — confirm against the factory). For a square rank-2 input the result is the
     * length-n diagonal.
     *
     * @throws IllegalArgumentException for any other rank or a non-square matrix
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> diag(int rank, long[] shape, B buffer, JVMBuffer.ArrayWrapperFactory<T, B> factory) {
        final long[] newShape;
        final B newBuffer;
        if (rank == 1) {
            long n = buffer.getLength();
            newBuffer = factory.createNew(n * n);
            for (long i = 0L; i < n; ++i) {
                newBuffer.set(buffer.get(i), i * n + i);
            }
            newShape = new long[]{n, n};
        } else if (rank == 2 && shape[0] == shape[1]) {
            long n = shape[0];
            newBuffer = factory.createNew(n);
            for (long i = 0L; i < n; ++i) {
                newBuffer.set(buffer.get(i * n + i), i);
            }
            newShape = new long[]{n};
        } else {
            throw new IllegalArgumentException("Diag is only valid for vectors or square matrices");
        }
        return new ResultWrapper<>(newBuffer, newShape, TensorShape.getRowFirstStride(newShape));
    }

    /** Instance form of the static {@code permute}. */
    @Override
    public TENSOR permute(int ... rearrange) {
        return this.createFromResultWrapper(JVMTensor.permute(this.getFactory(), this.buffer, this.shape, this.stride, rearrange));
    }

    /**
     * Reorders dimensions: the result shape is {@code TensorShape.getPermutedIndices(shape, rearrange)}
     * and every element is copied to its permuted flat position in a new buffer with row-first stride.
     *
     * @throws IllegalArgumentException if {@code rearrange} does not have one entry per dimension
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> permute(JVMBuffer.ArrayWrapperFactory<T, B> factory, B buffer, long[] shape, long[] stride, int ... rearrange) {
        Preconditions.checkArgument(rearrange.length == shape.length,
            "Permutation of length %s does not match tensor rank %s", rearrange.length, shape.length);
        long[] resultShape = TensorShape.getPermutedIndices(shape, rearrange);
        long[] resultStride = TensorShape.getRowFirstStride(resultShape);
        long length = buffer.getLength();
        B newBuffer = factory.createNew(length);
        for (long flatIndex = 0L; flatIndex < length; ++flatIndex) {
            long permutedFlatIndex = TensorShape.convertFromFlatIndexToPermutedFlatIndex(flatIndex, shape, stride, resultShape, resultStride, rearrange);
            newBuffer.set(buffer.get(flatIndex), permutedFlatIndex);
        }
        return new ResultWrapper<>(newBuffer, resultShape, resultStride);
    }

    /** Instance form of the static {@code split}. */
    @Override
    public List<TENSOR> split(int dimension, long ... splitAtIndices) {
        return JVMTensor.split(this.getFactory(), this.buffer, this.shape, this.stride, dimension, splitAtIndices).stream().map(this::createFromResultWrapper).collect(Collectors.toList());
    }

    /**
     * Splits a tensor along {@code dimension} at the given cumulative end indices.
     *
     * <p>Piece k spans {@code [splitAtIndices[k-1], splitAtIndices[k])} along the dimension (the
     * first piece starts at 0). Elements after the last split index are not returned.
     *
     * @param dimension      dimension to split; negative values are resolved with
     *                       {@link TensorShape#getAbsoluteDimension}
     * @param splitAtIndices strictly increasing end indices, each at most the dimension's length
     * @throws IllegalArgumentException for an invalid dimension or split index
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> List<ResultWrapper<T, B>> split(JVMBuffer.ArrayWrapperFactory<T, B> factory, B fromBuffer, long[] shape, long[] stride, int dimension, long ... splitAtIndices) {
        int absoluteDimension = TensorShape.getAbsoluteDimension(dimension, shape.length);
        if (absoluteDimension < 0 || absoluteDimension >= shape.length) {
            throw new IllegalArgumentException("Invalid dimension to split on " + absoluteDimension);
        }
        // Bring the split dimension to the front so each piece is one contiguous run of the permuted
        // buffer; each piece is then permuted back with the inverse slide.
        int[] moveDimToZero = TensorShape.slideDimension(absoluteDimension, 0, shape.length);
        int[] moveZeroToDim = TensorShape.slideDimension(0, absoluteDimension, shape.length);
        B dimensionFirst = JVMTensor.permute(factory, fromBuffer, shape, stride, moveDimToZero).outputBuffer;

        long dimensionLength = shape[absoluteDimension];
        List<ResultWrapper<T, B>> pieces = new ArrayList<>(splitAtIndices.length);
        long previousSplitAtIndex = 0L;
        long rawBufferPosition = 0L;
        for (long splitAtIndex : splitAtIndices) {
            long subTensorLengthInDimension = splitAtIndex - previousSplitAtIndex;
            // Checking the cumulative index (not just the piece length) stops the copy below from
            // reading past the end of the permuted buffer.
            if (subTensorLengthInDimension <= 0L || splitAtIndex > dimensionLength) {
                throw new IllegalArgumentException("Invalid index to split on " + splitAtIndex + " at " + absoluteDimension + " for tensor of shape " + Arrays.toString(shape));
            }
            long[] subTensorShape = Arrays.copyOf(shape, shape.length);
            subTensorShape[absoluteDimension] = subTensorLengthInDimension;
            long subTensorLength = TensorShape.getLength(subTensorShape);
            B subBuffer = factory.createNew(subTensorLength);
            subBuffer.copyFrom(dimensionFirst, rawBufferPosition, 0L, subTensorLength);
            long[] subTensorPermutedShape = TensorShape.getPermutedIndices(subTensorShape, moveDimToZero);
            pieces.add(JVMTensor.permute(factory, subBuffer, subTensorPermutedShape, TensorShape.getRowFirstStride(subTensorPermutedShape), moveZeroToDim));
            previousSplitAtIndex = splitAtIndex;
            rawBufferPosition += subBuffer.getLength();
        }
        return pieces;
    }

    /** Takes index {@code index} of {@code dimension}, via {@link DimensionIndexMapper}. */
    @Override
    public TENSOR slice(int dimension, long index) {
        return this.createFromResultWrapper(JVMTensor.slice(this.getFactory(), this.buffer, new DimensionIndexMapper(this.shape, this.stride, dimension, index)));
    }

    /** Slices according to {@code slicer}, via {@link SlicerIndexMapper}. */
    @Override
    public TENSOR slice(Slicer slicer) {
        return this.createFromResultWrapper(JVMTensor.slice(this.getFactory(), this.buffer, new SlicerIndexMapper(slicer, this.shape, this.stride)));
    }

    /**
     * Copies into a new buffer every element that {@code indexMapper} maps a result index to; the
     * result shape and stride are taken from the mapper.
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> slice(JVMBuffer.ArrayWrapperFactory<T, B> factory, B buffer, IndexMapper indexMapper) {
        long[] resultShape = indexMapper.getResultShape();
        long[] resultStride = indexMapper.getResultStride();
        B newBuffer = factory.createNew(TensorShape.getLength(resultShape));
        for (long i = 0L; i < newBuffer.getLength(); ++i) {
            newBuffer.set(buffer.get(indexMapper.getSourceIndexFromResultIndex(i)), i);
        }
        return new ResultWrapper<>(newBuffer, resultShape, resultStride);
    }

    /**
     * Concatenates buffers along {@code dimension}.
     *
     * <p>{@code toConcat.get(i)} is assumed to be the buffer of {@code tensors[i]}; that tensor's
     * shape and stride are used to permute it. For a non-zero dimension, every input is permuted so
     * the dimension comes first, the inputs are joined end to end, and the result is permuted back.
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> ResultWrapper<T, B> concat(JVMBuffer.ArrayWrapperFactory<T, B> factory, Tensor[] tensors, int dimension, List<B> toConcat) {
        long[] concatShape = TensorShape.getConcatResultShape(dimension, tensors);
        if (dimension == 0) {
            B concatenated = JVMTensor.concatOnDimensionZero(factory, concatShape, toConcat);
            return new ResultWrapper<>(concatenated, concatShape, TensorShape.getRowFirstStride(concatShape));
        }
        int[] rearrange = TensorShape.getPermutationForDimensionToDimensionZero(dimension, concatShape);
        List<B> toConcatOnDimensionZero = new ArrayList<>(toConcat.size());
        for (int i = 0; i < toConcat.size(); ++i) {
            toConcatOnDimensionZero.add(JVMTensor.permute(factory, toConcat.get(i), tensors[i].getShape(), tensors[i].getStride(), rearrange).outputBuffer);
        }
        long[] permutedConcatShape = TensorShape.getPermutedIndices(concatShape, rearrange);
        B concatOnDimZero = JVMTensor.concatOnDimensionZero(factory, permutedConcatShape, toConcatOnDimensionZero);
        return JVMTensor.permute(factory, concatOnDimZero, permutedConcatShape, TensorShape.getRowFirstStride(permutedConcatShape), TensorShape.invertedPermute(rearrange));
    }

    /**
     * Copies the buffers in {@code toConcat} one after another into a new buffer sized for
     * {@code concatShape}.
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> B concatOnDimensionZero(JVMBuffer.ArrayWrapperFactory<T, B> factory, long[] concatShape, List<B> toConcat) {
        B concatBuffer = factory.createNew(TensorShape.getLength(concatShape));
        // long, not int: the write position must not overflow for buffers beyond 2^31 elements.
        long bufferPosition = 0L;
        for (B part : toConcat) {
            long partLength = part.getLength();
            concatBuffer.copyFrom(part, 0L, bufferPosition, partLength);
            bufferPosition += partLength;
        }
        return concatBuffer;
    }

    /**
     * Applies {@code op} element-wise against {@code right} through
     * {@link JVMTensorBroadcast#broadcastIfNeeded}, then stores the result in this tensor via
     * {@link #set}.
     */
    protected TENSOR broadcastableBinaryOpWithAutoBroadcast(BiFunction<T, T, T> op, JVMTensor<T, TENSOR, B> right) {
        ResultWrapper<T, B> result = JVMTensorBroadcast.broadcastIfNeeded(this.getFactory(), this.buffer, this.shape, this.stride, this.buffer.getLength(), right.buffer, right.shape, right.stride, right.buffer.getLength(), op, true);
        return this.set(result.outputBuffer, result.outputShape, result.outputStride);
    }

    /**
     * Returns a copy of this tensor with a new shape. Wildcards are resolved by
     * {@link TensorShape#getReshapeAllowingWildcard}.
     */
    @Override
    public TENSOR reshape(long ... newShape) {
        long[] normalizedShape = TensorShape.getReshapeAllowingWildcard(this.shape, this.buffer.getLength(), newShape);
        return this.create(this.buffer.copy(), normalizedShape, TensorShape.getRowFirstStride(normalizedShape));
    }

    /**
     * Broadcasts this tensor into a new buffer of shape {@code toShape}. The shape is copied so later
     * changes to the caller's array cannot alter the result.
     */
    @Override
    public TENSOR broadcast(long ... toShape) {
        long[] outputShape = Arrays.copyOf(toShape, toShape.length);
        long outputLength = TensorShape.getLength(outputShape);
        long[] outputStride = TensorShape.getRowFirstStride(outputShape);
        B outputBuffer = this.getFactory().createNew(outputLength);
        JVMTensorBroadcast.broadcast(this.buffer, this.shape, this.stride, outputBuffer, outputStride);
        return this.create(outputBuffer, outputShape, outputStride);
    }

    /** Instance form of the static axis-wise {@code argCompare}. */
    public IntegerTensor argCompare(BiFunction<T, T, Boolean> compareOp, int axis) {
        return JVMTensor.argCompare(this.getFactory(), this.buffer, compareOp, this.shape, this.stride, axis);
    }

    /**
     * For every position of the other axes, finds the index along {@code axis} of the winning
     * element. {@code compareOp.apply(candidate, currentBest)} returning {@code true} replaces the
     * current best, so the first element wins ties when {@code compareOp} is strict.
     *
     * @param axis axis to reduce; negative values are resolved with
     *             {@link TensorShape#getAbsoluteDimension}
     * @return an integer tensor whose shape is {@code shape} with {@code axis} removed
     * @throws IllegalArgumentException if the axis is out of range or has length zero
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> IntegerTensor argCompare(JVMBuffer.ArrayWrapperFactory<T, B> factory, B buffer, BiFunction<T, T, Boolean> compareOp, long[] shape, long[] stride, int axis) {
        int absoluteAxis = TensorShape.getAbsoluteDimension(axis, shape.length);
        if (absoluteAxis < 0 || absoluteAxis >= shape.length) {
            throw new IllegalArgumentException("Cannot take arg max of axis " + axis + " on a " + shape.length + " rank tensor.");
        }
        if (shape[absoluteAxis] == 0L) {
            throw new IllegalArgumentException("Cannot take arg max of zero-length axis " + axis);
        }
        // After moving the axis to the front, consecutive blocks of `resultLength` elements are
        // successive positions along the axis.
        int[] rearrange = TensorShape.getPermutationForDimensionToDimensionZero(absoluteAxis, shape);
        B permutedBuffer = JVMTensor.permute(factory, buffer, shape, stride, rearrange).outputBuffer;
        int resultLength = Ints.checkedCast(buffer.getLength() / shape[absoluteAxis]);
        B bestValues = factory.createNew((long) resultLength);
        int[] bestIndex = new int[resultLength];
        Arrays.fill(bestIndex, -1);
        long permutedLength = permutedBuffer.getLength();
        for (long i = 0L; i < permutedLength; ++i) {
            int resultPosition = (int) (i % resultLength);
            T value = permutedBuffer.get(i);
            if (bestIndex[resultPosition] < 0 || compareOp.apply(value, bestValues.get(resultPosition))) {
                bestValues.set(value, resultPosition);
                bestIndex[resultPosition] = (int) (i / resultLength);
            }
        }
        return IntegerTensor.create(bestIndex, ArrayUtils.remove(shape, absoluteAxis));
    }

    /** Instance form of the static whole-buffer {@code argCompare}. */
    public int argCompare(BiFunction<T, T, Boolean> compareOp) {
        return JVMTensor.argCompare(this.buffer, compareOp);
    }

    /**
     * Returns the flat index of the winning element of {@code buffer}, or -1 if it is empty.
     * {@code compareOp.apply(candidate, currentBest)} returning {@code true} replaces the current
     * best.
     *
     * @throws IllegalArgumentException if the buffer is too long to be indexed by an {@code int}
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> int argCompare(B buffer, BiFunction<T, T, Boolean> compareOp) {
        int length = Ints.checkedCast(buffer.getLength());
        if (length == 0) {
            return -1;
        }
        T best = buffer.get(0L);
        int bestIndex = 0;
        for (int i = 1; i < length; ++i) {
            T value = buffer.get(i);
            if (compareOp.apply(value, best)) {
                best = value;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Evaluates {@code op} on every element.
     *
     * @return a boolean tensor with this tensor's shape
     */
    public BooleanTensor isApply(Function<T, Boolean> op) {
        int length = Ints.checkedCast(this.buffer.getLength());
        boolean[] newBuffer = new boolean[length];
        for (int i = 0; i < length; ++i) {
            newBuffer[i] = op.apply(this.buffer.get(i));
        }
        return BooleanTensor.create(newBuffer, Arrays.copyOf(this.shape, this.shape.length));
    }

    /** Two tensors are equal when they have the same class, the same shape and equal buffers. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        JVMTensor<?, ?, ?> that = (JVMTensor<?, ?, ?>) o;
        return Arrays.equals(this.shape, that.shape) && this.buffer.equals(that.buffer);
    }

    /** Shows the shape and the data; buffers longer than 20 elements are abbreviated. */
    @Override
    public String toString() {
        StringBuilder dataString = new StringBuilder();
        long length = this.buffer.getLength();
        if (length > 20L) {
            dataString.append(Arrays.toString(this.buffer.asArray(0L, 10L)));
            dataString.append("...");
            dataString.append(Arrays.toString(this.buffer.asArray(length - 10L, length)));
        } else {
            dataString.append(Arrays.toString(this.buffer.asArray()));
        }
        return "{\nshape = " + Arrays.toString(this.shape) + "\ndata = " + dataString.toString() + "\n}";
    }

    /** Hash of the shape and buffer, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = Arrays.hashCode(this.shape);
        result = 31 * result + this.buffer.hashCode();
        return result;
    }

    private TENSOR createFromResultWrapper(ResultWrapper<T, B> wrapper) {
        return this.create(wrapper.outputBuffer, wrapper.outputShape, wrapper.outputStride);
    }

    /** Creates a new tensor instance around the given buffer, shape and stride. */
    protected abstract TENSOR create(B var1, long[] var2, long[] var3);

    /** Replaces this tensor's contents with the given buffer, shape and stride. */
    protected abstract TENSOR set(B var1, long[] var2, long[] var3);

    /** @return the factory used to allocate buffers of type {@code B} */
    protected abstract JVMBuffer.ArrayWrapperFactory<T, B> getFactory();
}

