/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.ndj4;

import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.jvm.Slicer;
import io.improbable.keanu.tensor.ndj4.INDArrayExtensions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public abstract class Nd4jTensor<T, TENSOR extends Tensor<T, TENSOR>>
implements Tensor<T, TENSOR> {
    protected INDArray tensor;

    public Nd4jTensor(INDArray tensor) {
        this.tensor = tensor;
    }

    @Override
    public int getRank() {
        return this.tensor.rank();
    }

    @Override
    public long[] getShape() {
        return this.tensor.shape();
    }

    @Override
    public long[] getStride() {
        return this.tensor.stride();
    }

    @Override
    public long getLength() {
        return this.tensor.length();
    }

    @Override
    public TENSOR get(BooleanTensor booleanIndex) {
        ArrayList<Long> indices = new ArrayList<Long>();
        Tensor.FlattenedView flattenedView = booleanIndex.getFlattenedView();
        for (long i = 0L; i < booleanIndex.getLength(); ++i) {
            if (!((Boolean)flattenedView.get(i)).booleanValue()) continue;
            indices.add(i);
        }
        if (indices.isEmpty()) {
            return this.create(Nd4j.empty((DataType)this.tensor.dataType()));
        }
        INDArray result = this.tensor.reshape(new long[]{this.tensor.length()}).get(new INDArrayIndex[]{NDArrayIndex.indices((long[])Longs.toArray(indices))});
        return this.create(result);
    }

    @Override
    public TENSOR reshape(long ... newShape) {
        return this.create(this.tensor.reshape(newShape));
    }

    @Override
    public TENSOR broadcast(long ... toShape) {
        return this.create(this.tensor.broadcast(toShape));
    }

    @Override
    public TENSOR permute(int ... rearrange) {
        return this.create(this.tensor.permute(rearrange));
    }

    @Override
    public TENSOR diag() {
        return this.create(Nd4j.diag((INDArray)this.tensor));
    }

    @Override
    public TENSOR duplicate() {
        return this.create(this.tensor.dup());
    }

    @Override
    public TENSOR transpose() {
        if (this.getRank() != 2) {
            throw new IllegalArgumentException("Cannot transpose rank " + this.getRank() + " tensor. Try permute instead.");
        }
        return this.create(this.tensor.transpose());
    }

    @Override
    public TENSOR slice(int dimension, long index) {
        return this.create(this.tensor.slice(index, dimension));
    }

    @Override
    public TENSOR slice(Slicer slicer) {
        List<Slicer.StartStopStep> slices = slicer.getSlices();
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[slices.size()];
        for (int i = 0; i < indArrayIndices.length; ++i) {
            Slicer.StartStopStep s = slices.get(i);
            long stop = s.getStop();
            if (stop == -2L) {
                indArrayIndices[i] = NDArrayIndex.point((long)s.getStart());
                continue;
            }
            if (stop == -1L) {
                stop = this.tensor.shape()[i];
            }
            indArrayIndices[i] = NDArrayIndex.interval((long)s.getStart(), (long)s.getStep(), (long)stop);
        }
        return this.create(this.tensor.get(indArrayIndices));
    }

    @Override
    public TENSOR take(long ... index) {
        return this.create(this.tensor.getScalar(index));
    }

    @Override
    public List<TENSOR> split(int dimension, long ... splitAtIndices) {
        List<INDArray> splitINDArrays = INDArrayExtensions.split(this.tensor, dimension, splitAtIndices);
        return splitINDArrays.stream().map(this::create).collect(Collectors.toList());
    }

    public int hashCode() {
        return this.tensor.hashCode();
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o instanceof Nd4jTensor) {
            return this.tensor.equals(((Nd4jTensor)o).getTensor());
        }
        if (o instanceof Tensor) {
            Tensor that = (Tensor)o;
            if (!Arrays.equals(that.getShape(), this.getShape())) {
                return false;
            }
            return Arrays.equals(that.asFlatArray(), this.asFlatArray());
        }
        return false;
    }

    public String toString() {
        return this.tensor.toString();
    }

    public INDArray getTensor() {
        return this.tensor;
    }

    protected abstract INDArray getTensor(Tensor<T, ?> var1);

    protected abstract TENSOR create(INDArray var1);

    protected abstract TENSOR set(INDArray var1);

    protected abstract TENSOR getThis();
}

