/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

public class NDArrayUtil {

    /**
     * Returns a new array holding {@code e^x} for every element {@code x} of {@code toExp}.
     * The input is duplicated first, so the caller's array is never modified.
     *
     * @param toExp the array to exponentiate
     * @return a new array with the same shape as {@code toExp}
     */
    public static INDArray exp(INDArray toExp) {
        return NDArrayUtil.expi(toExp.dup());
    }

    /**
     * Applies {@code e^x} element-wise to the flattened form of {@code toExp} and returns the
     * result reshaped to {@code toExp.shape()}.
     *
     * <p>NOTE(review): the writes go to {@code toExp.ravel()}. Whether that aliases
     * {@code toExp} (true in-place) depends on {@code ravel()} returning a view; confirm
     * before relying on the side effect. The return value is correct either way.
     *
     * @param toExp the array to exponentiate
     * @return the exponentiated values in {@code toExp}'s shape
     */
    public static INDArray expi(INDArray toExp) {
        INDArray flattened = toExp.ravel();
        for (int i = 0; i < flattened.length(); ++i) {
            flattened.put(i, Nd4j.scalar(Math.exp(scalarValue(flattened.getScalar(i)))));
        }
        return flattened.reshape(toExp.shape());
    }

    /**
     * Extracts a centered window of size {@code shape} from the first two axes of {@code arr}.
     * The window starts at {@code (arr.shape() - shape) / 2} on each axis.
     *
     * @param arr   the source array (only axes 0 and 1 are indexed)
     * @param shape the desired window size; entries 0 and 1 are used
     * @return the sub-array selected by the centered intervals
     */
    public static INDArray center(INDArray arr, int[] shape) {
        INDArray shapeMatrix = ArrayUtil.toNDArray(shape);
        INDArray currShape = ArrayUtil.toNDArray(arr.shape());
        INDArray startIndex = currShape.sub(shapeMatrix).div(2);
        INDArray endIndex = startIndex.add(shapeMatrix);
        // Both axes run from their start index to their matching end index. The first axis
        // previously used startIndex for both ends, which selected an empty/wrong range.
        return arr.get(
                NDArrayIndex.interval((int) startIndex.get(0), (int) endIndex.get(0)),
                NDArrayIndex.interval((int) startIndex.get(1), (int) endIndex.get(1)));
    }

    /**
     * Keeps only the first {@code n} entries of {@code nd} along {@code dimension}.
     *
     * <ul>
     *   <li>Vectors: returns a new length-{@code n} vector of the first {@code n} elements
     *       ({@code dimension} is ignored). If {@code n} exceeds the length, {@code nd} is
     *       returned unchanged.</li>
     *   <li>Otherwise, if {@code nd.size(dimension) <= n}, {@code nd} is returned
     *       unchanged.</li>
     *   <li>Matrices: leading rows ({@code dimension == 0}) or leading columns
     *       ({@code dimension == 1}) are copied into a new array.</li>
     *   <li>Higher rank, {@code dimension == 0}: the first {@code n} slices are stacked.</li>
     *   <li>Higher rank, other dimensions: the leading elements of each raveled slice are
     *       kept.</li>
     * </ul>
     *
     * @param nd        the array to truncate
     * @param n         the number of entries to keep along {@code dimension}
     * @param dimension the axis to truncate
     * @return the truncated array, or {@code nd} itself when no truncation is needed
     * @throws IllegalArgumentException if {@code nd} is a matrix and {@code dimension} is
     *                                  neither 0 nor 1
     */
    public static INDArray truncate(INDArray nd, int n, int dimension) {
        if (nd.isVector()) {
            // Match the "already small enough" behaviour of the non-vector path instead of
            // indexing past the end of the vector.
            if (n > nd.length()) {
                return nd;
            }
            INDArray truncated = Nd4j.create(new int[]{n});
            for (int i = 0; i < n; ++i) {
                truncated.put(i, nd.getScalar(i));
            }
            return truncated;
        }
        if (nd.size(dimension) <= n) {
            return nd;
        }
        int[] targetShape = ArrayUtil.copy(nd.shape());
        targetShape[dimension] = n;
        if (nd.isMatrix()) {
            return truncateMatrix(nd, dimension, targetShape);
        }
        if (dimension == 0) {
            ArrayList<INDArray> slices = new ArrayList<INDArray>(n);
            for (int i = 0; i < n; ++i) {
                slices.add(nd.slice(i));
            }
            return Nd4j.create(slices, targetShape);
        }
        // NOTE(review): this keeps the leading elements of each raveled slice. That matches
        // truncation along `dimension` only if that axis is the slowest-varying one inside a
        // slice in ravel order. TODO confirm for dimension > 1.
        int numElementsPerSlice = ArrayUtil.prod(ArrayUtil.removeIndex(targetShape, 0));
        ArrayList<Double> values = new ArrayList<Double>(ArrayUtil.prod(targetShape));
        for (int i = 0; i < nd.slices(); ++i) {
            INDArray slice = nd.slice(i).ravel();
            for (int j = 0; j < numElementsPerSlice; ++j) {
                values.add(scalarValue(slice.getScalar(j)));
            }
        }
        assert (values.size() == ArrayUtil.prod(targetShape)) : "Illegal shape for length " + values.size();
        return Nd4j.create(ArrayUtil.toArrayDouble(values), targetShape);
    }

    /**
     * Matrix case of {@link #truncate}. Walks rows ({@code dimension == 0}) or columns
     * ({@code dimension == 1}) in order and collects element values until
     * {@code prod(targetShape)} values have been gathered.
     */
    private static INDArray truncateMatrix(INDArray nd, int dimension, int[] targetShape) {
        if (dimension != 0 && dimension != 1) {
            throw new IllegalArgumentException("Illegal dimension for matrix " + dimension);
        }
        int numRequired = ArrayUtil.prod(targetShape);
        boolean byRows = dimension == 0;
        int vectorCount = byRows ? nd.rows() : nd.columns();
        ArrayList<Double> values = new ArrayList<Double>(numRequired);
        for (int i = 0; i < vectorCount && values.size() < numRequired; ++i) {
            INDArray vector = byRows ? nd.getRow(i) : nd.getColumn(i);
            for (int j = 0; j < vector.length() && values.size() < numRequired; ++j) {
                values.add(scalarValue(vector.getScalar(j)));
            }
        }
        return Nd4j.create(ArrayUtil.toArrayDouble(values), targetShape);
    }

    /**
     * Returns a zero-filled array of {@code targetShape} whose leading storage is copied from
     * {@code nd}. Returns {@code nd} unchanged if it already has {@code targetShape} or holds
     * at least as many elements as {@code targetShape} requires.
     *
     * <p>NOTE(review): the copy uses the raw {@code data()} buffers. If {@code nd} is a view
     * over a larger buffer, this copies the whole backing buffer rather than just the view's
     * elements. Confirm callers only pass contiguous, non-view arrays.
     *
     * @param nd          the array to pad
     * @param targetShape the shape of the padded result
     * @return the padded array, or {@code nd} when no padding is needed
     */
    public static INDArray padWithZeros(INDArray nd, int[] targetShape) {
        if (Arrays.equals(nd.shape(), targetShape)) {
            return nd;
        }
        if (ArrayUtil.prod(nd.shape()) >= ArrayUtil.prod(targetShape)) {
            return nd;
        }
        INDArray ret = Nd4j.create(targetShape);
        System.arraycopy(nd.data(), 0, ret.data(), 0, nd.data().length);
        return ret;
    }

    /** Returns whether {@code op} is one of the {@code ROW_*} matrix operations. */
    private static boolean isRowOp(MatrixOp op) {
        return op == MatrixOp.ROW_MIN || op == MatrixOp.ROW_MAX || op == MatrixOp.ROW_SUM || op == MatrixOp.ROW_MEAN;
    }

    /** Returns whether {@code op} is one of the {@code COLUMN_*} matrix operations. */
    private static boolean isColumnOp(MatrixOp op) {
        return op == MatrixOp.COLUMN_MIN || op == MatrixOp.COLUMN_MAX || op == MatrixOp.COLUMN_SUM || op == MatrixOp.COLUMN_MEAN;
    }

    /**
     * Reduces every element of {@code arr} to a single scalar according to {@code op}.
     * The array is first reshaped to a {@code 1 x length} row.
     *
     * <ul>
     *   <li>{@code NORM_1} / {@code NORM_2}: BLAS {@code asum} / {@code nrm2}.</li>
     *   <li>{@code NORM_MAX}: the element at BLAS {@code iamax}.</li>
     *   <li>{@code MAX} / {@code MIN}: the largest / smallest element (first occurrence).</li>
     *   <li>{@code ARG_MAX} / {@code ARG_MIN}: the linear index of that element, as a
     *       scalar.</li>
     *   <li>{@code SUM} / {@code MEAN} / {@code PROD}: sum, arithmetic mean, product.</li>
     * </ul>
     * For an empty array, MAX/MIN/ARG_MAX/ARG_MIN return a zero scalar.
     *
     * @param op  the reduction to perform
     * @param arr the array to reduce
     * @return a scalar array holding the result
     * @throws NullPointerException if {@code op} is {@code null}
     */
    public static INDArray doSliceWise(ScalarOp op, INDArray arr) {
        INDArray row = arr.reshape(new int[]{1, arr.length()});
        switch (op) {
            case NORM_1:
                return Nd4j.scalar(Nd4j.getBlasWrapper().asum(row));
            case NORM_2:
                return Nd4j.scalar(Nd4j.getBlasWrapper().nrm2(row));
            case NORM_MAX:
                return row.getScalar(Nd4j.getBlasWrapper().iamax(row));
            case MAX:
            case MIN:
            case ARG_MAX:
            case ARG_MIN:
                return extremum(op, row);
            default:
                return accumulate(op, row);
        }
    }

    /**
     * Handles MAX/MIN/ARG_MAX/ARG_MIN for {@link #doSliceWise}. Each element is compared
     * against the best value seen so far (the old code compared every element with itself,
     * so it always returned 0).
     */
    private static INDArray extremum(ScalarOp op, INDArray row) {
        if (row.length() == 0) {
            return Nd4j.scalar(0.0f);
        }
        boolean wantMax = op == ScalarOp.MAX || op == ScalarOp.ARG_MAX;
        int bestIndex = 0;
        double bestValue = scalarValue(row.getScalar(0));
        for (int i = 1; i < row.length(); ++i) {
            double candidate = scalarValue(row.getScalar(i));
            if (wantMax ? candidate > bestValue : candidate < bestValue) {
                bestIndex = i;
                bestValue = candidate;
            }
        }
        if (op == ScalarOp.ARG_MAX || op == ScalarOp.ARG_MIN) {
            return Nd4j.scalar(bestIndex);
        }
        return row.getScalar(bestIndex);
    }

    /**
     * Handles SUM/MEAN/PROD for {@link #doSliceWise}. The product starts from 1; the old
     * code started it from 0, so every product came out as 0.
     */
    private static INDArray accumulate(ScalarOp op, INDArray row) {
        boolean product = op == ScalarOp.PROD;
        INDArray acc = Nd4j.scalar(product ? 1.0f : 0.0f);
        for (int i = 0; i < row.length(); ++i) {
            if (product) {
                acc.muli(row.getScalar(i));
            } else {
                acc.addi(row.getScalar(i));
            }
        }
        if (op == ScalarOp.MEAN) {
            acc.divi(Nd4j.scalar(row.length()));
        }
        return acc;
    }

    /**
     * Reads the value of a scalar array as a {@code double}. Going through {@link Number}
     * works for both float- and double-backed scalars, whereas a {@code (Double)} cast
     * throws {@link ClassCastException} on float data.
     */
    private static double scalarValue(INDArray scalar) {
        return ((Number) scalar.element()).doubleValue();
    }

    /** Row- and column-wise matrix reductions. */
    public static enum MatrixOp {
        COLUMN_MIN,
        COLUMN_MAX,
        COLUMN_SUM,
        COLUMN_MEAN,
        ROW_MIN,
        ROW_MAX,
        ROW_SUM,
        ROW_MEAN;

    }

    /** Reductions that can be applied along a dimension. */
    public static enum DimensionOp {
        SUM,
        MEAN,
        PROD,
        MAX,
        MIN,
        ARG_MIN,
        NORM_2,
        NORM_1,
        NORM_MAX,
        FFT;

    }

    /** Whole-array reductions accepted by {@link NDArrayUtil#doSliceWise}. */
    public static enum ScalarOp {
        SUM,
        MEAN,
        PROD,
        MAX,
        MIN,
        ARG_MAX,
        ARG_MIN,
        NORM_2,
        NORM_1,
        NORM_MAX;

    }
}

