/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.generic;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.generic.GenericBuffer;
import io.improbable.keanu.tensor.jvm.JVMTensor;
import io.improbable.keanu.tensor.jvm.JVMTensorBroadcast;
import io.improbable.keanu.tensor.jvm.ResultWrapper;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.BiFunction;

/**
 * A tensor whose elements are arbitrary objects of type {@code T}, stored in a
 * {@link GenericBuffer.PrimitiveGenericWrapper} buffer.
 *
 * @param <T> the element type
 */
public class GenericTensor<T>
extends JVMTensor<T, GenericTensor<T>, GenericBuffer.PrimitiveGenericWrapper<T>>
implements Tensor<T, GenericTensor<T>> {
    // Shared factory used to wrap T[] arrays into buffers for every GenericTensor.
    private static final GenericBuffer.GenericArrayWrapperFactory factory = new GenericBuffer.GenericArrayWrapperFactory();

    /**
     * Creates a tensor of the given shape in which every element is {@code data}.
     * The same reference is stored in every position; no copies of {@code data} are made.
     *
     * @param data  the value to place in every position (may be {@code null})
     * @param shape the shape of the tensor
     * @return a new filled tensor
     */
    public static <T> GenericTensor<T> createFilled(T data, long[] shape) {
        return new GenericTensor<T>(GenericTensor.fillArray(shape, data), shape);
    }

    /**
     * Creates a rank-1 tensor whose length is the number of supplied elements.
     *
     * @param data the elements, in order
     * @return a new tensor of shape {@code [data.length]}
     */
    public static <T> GenericTensor<T> create(T ... data) {
        return GenericTensor.create(data, new long[]{data.length});
    }

    /**
     * Creates a tensor from a flat array of elements and a shape.
     *
     * @param data  the elements in flat order
     * @param shape the shape of the tensor
     * @return a new tensor
     * @throws IllegalArgumentException if the total length implied by {@code shape}
     *                                  differs from {@code data.length}
     */
    public static <T> GenericTensor<T> create(T[] data, long[] shape) {
        if (TensorShape.getLength(shape) != (long) data.length) {
            throw new IllegalArgumentException("Shape size does not match data length");
        }
        return new GenericTensor<T>(data, shape);
    }

    /**
     * Creates a rank-0 (empty shape and stride) tensor holding a single value.
     *
     * @param data the scalar value
     * @return a new scalar tensor
     */
    public static <T> GenericTensor<T> scalar(T data) {
        return new GenericTensor<T>(data);
    }

    private GenericTensor(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        super(buffer, shape, stride);
    }

    // Uses the row-first stride derived from the shape.
    private GenericTensor(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape) {
        super(buffer, shape, TensorShape.getRowFirstStride(shape));
    }

    // Wraps a raw element array via the shared factory.
    private GenericTensor(T[] buffer, long[] shape, long[] stride) {
        super(factory.create(buffer), shape, stride);
    }

    private GenericTensor(T[] buffer, long[] shape) {
        this(buffer, shape, TensorShape.getRowFirstStride(shape));
    }

    // Scalar: single-value wrapper with empty shape and stride.
    private GenericTensor(T scalar) {
        super(new GenericBuffer.GenericWrapper<T>(scalar), new long[0], new long[0]);
    }

    /**
     * Allocates an array sized by {@code shape} with every slot set to {@code value}.
     * Generic arrays cannot be instantiated directly, so this creates an {@code Object[]} and
     * casts it. The runtime component type is therefore {@code Object}.
     */
    @SuppressWarnings("unchecked")
    private static <T> T[] fillArray(long[] shape, T value) {
        T[] data = (T[]) new Object[TensorShape.getLengthAsInt(shape)];
        Arrays.fill(data, value);
        return data;
    }

    /**
     * Returns a new tensor with a copied buffer and copied shape/stride arrays.
     */
    @Override
    public GenericTensor<T> duplicate() {
        return new GenericTensor<T>(
            this.buffer.copy(),
            Arrays.copyOf(this.shape, this.shape.length),
            Arrays.copyOf(this.stride, this.stride.length)
        );
    }

    /**
     * Returns a flat, index-addressable view over the array returned by the buffer's
     * {@code asArray()}.
     * NOTE(review): whether writes through the view reach this tensor depends on
     * {@code asArray()} returning the backing array rather than a copy — confirm in GenericBuffer.
     */
    @Override
    public Tensor.FlattenedView<T> getFlattenedView() {
        return new BaseSimpleFlattenedView<T>(this.buffer.asArray());
    }

    /**
     * Compares every element to {@code value} using {@link Objects#equals}, so {@code null}
     * elements and a {@code null} value are handled without throwing.
     *
     * @param value the value to compare against
     * @return a boolean tensor of this tensor's shape
     */
    @Override
    public BooleanTensor elementwiseEquals(T value) {
        boolean[] result = new boolean[Ints.checkedCast(this.buffer.getLength())];
        for (int i = 0; i < result.length; i++) {
            result[i] = Objects.equals(this.buffer.get(i), value);
        }
        return BooleanTensor.create(result, this.shape);
    }

    /**
     * Returns the elements of a copy of the buffer as a flat array.
     */
    @Override
    public T[] asFlatArray() {
        return this.buffer.copy().asArray();
    }

    @Override
    protected GenericTensor<T> create(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        return new GenericTensor<T>(buffer, shape, stride);
    }

    // Mutates this tensor in place and returns it.
    @Override
    protected GenericTensor<T> set(GenericBuffer.PrimitiveGenericWrapper<T> buffer, long[] shape, long[] stride) {
        this.buffer = buffer;
        this.shape = shape;
        this.stride = stride;
        return this;
    }

    // The shared factory is declared raw (a static field cannot use T), hence the unchecked conversion.
    @SuppressWarnings("unchecked")
    @Override
    protected JVMBuffer.ArrayWrapperFactory<T, GenericBuffer.PrimitiveGenericWrapper<T>> getFactory() {
        return factory;
    }

    /**
     * Returns a scalar tensor holding the element at {@code index}.
     */
    @Override
    public GenericTensor<T> take(long ... index) {
        return GenericTensor.scalar(this.getValue(index));
    }

    /**
     * Returns {@code tensor} itself if it is already a GenericTensor. Otherwise returns a new
     * GenericTensor wrapping its flat array with the same shape and stride.
     */
    @SuppressWarnings("unchecked")
    private static <T> GenericTensor<T> getRawBufferIfJVMTensor(Tensor<T, ?> tensor) {
        if (tensor instanceof GenericTensor) {
            return (GenericTensor<T>) tensor;
        }
        return new GenericTensor<T>(tensor.asFlatArray(), tensor.getShape(), tensor.getStride());
    }

    /**
     * Applies {@code op} element-wise to this tensor and {@code right}, broadcasting via
     * {@link JVMTensorBroadcast#broadcastIfNeeded} when needed.
     *
     * @param right the right-hand operand
     * @param op    the element-wise operation
     * @param <R>   the result element type
     * @return a new tensor holding the results
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <R> GenericTensor<R> apply(Tensor<T, ?> right, BiFunction<T, T, R> op) {
        GenericTensor<T> rightTensor = GenericTensor.getRawBufferIfJVMTensor(right);
        // NOTE(review): ResultWrapper's type parameters are not visible here, so it is kept raw
        // and the output buffer is cast to the R-typed wrapper.
        ResultWrapper result = JVMTensorBroadcast.broadcastIfNeeded(
            factory,
            this.buffer, this.shape, this.stride, this.buffer.getLength(),
            rightTensor.buffer, rightTensor.shape, rightTensor.stride, rightTensor.buffer.getLength(),
            op, false
        );
        return new GenericTensor<R>(
            (GenericBuffer.PrimitiveGenericWrapper<R>) result.outputBuffer,
            result.outputShape,
            result.outputStride
        );
    }

    /**
     * Flattened view over a plain array. Indices must fit in an {@code int}.
     */
    private static class BaseSimpleFlattenedView<T>
    implements Tensor.FlattenedView<T> {
        private final T[] data;

        public BaseSimpleFlattenedView(T[] data) {
            this.data = data;
        }

        // Narrows a long index to int, rejecting values above Integer.MAX_VALUE.
        private static int toIntIndex(long index) {
            if (index > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Only integer based indexing supported for generic tensors");
            }
            return (int) index;
        }

        @Override
        public long size() {
            return this.data.length;
        }

        @Override
        public T get(long index) {
            return this.data[toIntIndex(index)];
        }

        /**
         * Returns the single element when the view holds exactly one, regardless of
         * {@code index}. Otherwise behaves like {@link #get(long)}.
         */
        @Override
        public T getOrScalar(long index) {
            if (this.data.length == 1) {
                return this.get(0L);
            }
            return this.get(index);
        }

        @Override
        public void set(long index, T value) {
            this.data[toIntIndex(index)] = value;
        }
    }
}

