/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.ops.transforms;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.ArrayOps;
import org.nd4j.linalg.ops.BaseElementWiseOp;
import org.nd4j.linalg.ops.ElementWiseOp;
import org.nd4j.linalg.ops.transforms.Abs;
import org.nd4j.linalg.ops.transforms.EqualTo;
import org.nd4j.linalg.ops.transforms.Exp;
import org.nd4j.linalg.ops.transforms.Floor;
import org.nd4j.linalg.ops.transforms.GreaterThan;
import org.nd4j.linalg.ops.transforms.HardTanh;
import org.nd4j.linalg.ops.transforms.Identity;
import org.nd4j.linalg.ops.transforms.LessThan;
import org.nd4j.linalg.ops.transforms.Log;
import org.nd4j.linalg.ops.transforms.Max;
import org.nd4j.linalg.ops.transforms.Negative;
import org.nd4j.linalg.ops.transforms.Pow;
import org.nd4j.linalg.ops.transforms.Round;
import org.nd4j.linalg.ops.transforms.Sigmoid;
import org.nd4j.linalg.ops.transforms.Sign;
import org.nd4j.linalg.ops.transforms.Sqrt;
import org.nd4j.linalg.ops.transforms.Stabilize;
import org.nd4j.linalg.ops.transforms.Tanh;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers that apply element-wise transforms, plus a few pooling, sampling and
 * normalization utilities, to {@link INDArray} and {@link IComplexNDArray} instances.
 *
 * <p>Every element-wise helper delegates to {@code exec}, which runs the op on a
 * {@code dup()} of its argument and returns that duplicate.
 */
public class Transforms {
    /**
     * Applies the {@link Negative} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray neg(INDArray ndArray) {
        return Transforms.exec(ndArray, Negative.class, null);
    }

    /**
     * Complex overload of {@link #neg(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray neg(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Negative.class, null);
    }

    /**
     * Max-pools a 4-d array over its last two dimensions.
     *
     * <p>The input is viewed as {@code (size(0) * size(1), 1, rows, cols)}. Element
     * {@code (i, j, k, l)} of that view contributes to the output cell
     * {@code (i, j, k / ds[0], l / ds[1])}, which ends up holding the maximum of its
     * {@code ds[0] x ds[1]} window. The output has the same shape as the view. Cells not
     * reached by any window keep the initial fill value {@code Float.MIN_VALUE}.
     *
     * @param input        4-d array to pool
     * @param ds           pooling factors for the row and column dimensions (length 2)
     * @param ignoreBorder if {@code true}, trailing rows/columns that do not fill a complete
     *                     window are skipped; otherwise partial windows are pooled as well
     * @return an array shaped like the reshaped input, holding the window maxima in its
     *         leading rows/columns
     */
    public static INDArray maxPool(INDArray input, int[] ds, boolean ignoreBorder) {
        assert (input.length() >= 2) : "Max pooling requires an ndarray of >= length 2";
        assert (ds.length == 2) : "Down sampling must be of length 2 (the factors used for each image size";
        assert (input.shape().length == 4) : "Only supports 4 dimensional tensors";
        int batchSize = input.size(0) * input.size(1);
        int rows = input.size(2);
        int cols = input.size(3);
        INDArray signalNDArray = input.reshape(batchSize, 1, rows, cols);
        // Fill for cells that no pooling window reaches. It is NOT used as the max seed:
        // Float.MIN_VALUE is the smallest positive float, so seeding with it would clamp
        // all-negative windows to ~0.
        INDArray zz = Nd4j.create(signalNDArray.shape()).assign(Float.valueOf(Float.MIN_VALUE));
        // With ignoreBorder, only visit rows/columns that belong to a complete window.
        int rowIter = ignoreBorder ? (rows / ds[0]) * ds[0] : rows;
        int colIter = ignoreBorder ? (cols / ds[1]) * ds[1] : cols;
        for (int i = 0; i < signalNDArray.size(0); ++i) {
            for (int j = 0; j < signalNDArray.size(1); ++j) {
                for (int k = 0; k < rowIter; ++k) {
                    int zk = k / ds[0];
                    for (int l = 0; l < colIter; ++l) {
                        int zl = l / ds[1];
                        int[] cell = new int[]{i, j, zk, zl};
                        // Index the reshaped view: i and j range over its dimensions,
                        // not over the original input's.
                        float num = signalNDArray.get(new int[]{i, j, k, l});
                        // Row-major traversal visits (zk*ds[0], zl*ds[1]) first in every
                        // window, so that element seeds the window's running maximum.
                        boolean firstInWindow = k % ds[0] == 0 && l % ds[1] == 0;
                        float pooled = num;
                        if (!firstInWindow) {
                            float zzGet = zz.get(cell);
                            pooled = Math.max(num, zzGet);
                        }
                        zz.putScalar(cell, (Number)Float.valueOf(pooled));
                    }
                }
            }
        }
        return zz.reshape(signalNDArray.shape());
    }

    /**
     * Convolves {@code d1} (VALID mode) with an averaging kernel of shape {@code stride}
     * (all entries {@code 1 / prod(stride)}). Returns the leading
     * {@code stride[0] x stride[1]} slice of the result.
     *
     * <p>NOTE(review): the slice takes a {@code stride}-sized corner rather than every
     * {@code stride}-th element. Confirm this is the intended down-sampling.
     *
     * @param d1     array to down-sample
     * @param stride kernel shape; the first two entries also size the returned slice
     * @return the sliced convolution result
     */
    public static INDArray downSample(INDArray d1, int[] stride) {
        INDArray d = Nd4j.ones(stride);
        d.divi(ArrayUtil.prod(stride));
        INDArray ret = Convolution.convn(d1, d, Convolution.Type.VALID);
        ret = ret.get(NDArrayIndex.interval(0, stride[0]), NDArrayIndex.interval(0, stride[1]));
        return ret;
    }

    /**
     * Block-sums a 3-d array over its last two dimensions in {@code stride[0] x stride[1]}
     * steps and writes the block values into an array shaped like {@code toPool}.
     *
     * @param toPool 3-d array to pool
     * @param stride block strides for the row and column dimensions
     * @return the array of block values
     */
    public static INDArray pool(INDArray toPool, int[] stride) {
        int nDims = toPool.shape().length;
        assert (nDims == 3) : "NDArray must have 3 dimensions";
        int nRows = toPool.shape()[nDims - 2];
        int nCols = toPool.shape()[nDims - 1];
        int yStride = stride[0];
        int xStride = stride[1];
        INDArray blocks = Nd4j.create(toPool.shape());
        // Integer division: a trailing partial block is not visited.
        int rowBlocks = nRows / yStride;
        int colBlocks = nCols / xStride;
        for (int iR = 0; iR < rowBlocks; ++iR) {
            // NOTE(review): the row interval starts and ends at the same index, while the
            // column interval below spans two. Confirm the intended block extents.
            NDArrayIndex rows = NDArrayIndex.interval(iR * yStride, iR * yStride, true);
            for (int jC = 0; jC < colBlocks; ++jC) {
                NDArrayIndex cols = NDArrayIndex.interval(jC * xStride, jC * xStride + 1, true);
                INDArray blockVal = toPool.get(rows, cols).sum(toPool.shape().length - 1).sum(toPool.shape().length - 1);
                // NOTE(review): the repmat(...) result is discarded. It was presumably meant
                // to expand blockVal before the put; confirm and fix.
                blocks.put(new NDArrayIndex[]{rows, cols}, blockVal.permute(1, 2, 0)).repmat(new int[]{rows.length(), cols.length()});
            }
        }
        return blocks;
    }

    /**
     * For each dimension {@code i} of {@code d}:
     * <ul>
     *   <li>builds a column of length {@code d.size(i) * scale[i]} with ones at every
     *       multiple of {@code scale[i]};</li>
     *   <li>stores {@code cumsum(...).sum(...)} of that column (both over
     *       {@code Integer.MAX_VALUE}) in row {@code i} of the result.</li>
     * </ul>
     *
     * <p>NOTE(review): despite its name this returns a {@code (rank x 1)} array, not an
     * up-sampled copy of {@code d}. Confirm the intended behaviour.
     *
     * @param d     array whose per-dimension sizes are used
     * @param scale per-dimension scale factors (truncated to {@code int})
     * @return a {@code (d.shape().length x 1)} array
     */
    public static INDArray upSample(INDArray d, INDArray scale) {
        INDArray idx = Nd4j.create(d.shape().length, 1);
        for (int i = 0; i < d.shape().length; ++i) {
            int factor = (int)scale.get(i);
            INDArray tmp = Nd4j.zeros(d.size(i) * factor, 1);
            int[] indices = ArrayUtil.range(0, factor * d.size(i), factor);
            tmp.putScalar(indices, (Number)Float.valueOf(1.0f));
            idx.put(i, tmp.cumsum(Integer.MAX_VALUE).sum(Integer.MAX_VALUE));
        }
        return idx;
    }

    /**
     * Cosine similarity: scales duplicates of both arrays to unit length and returns their
     * BLAS dot product. The arguments themselves are not passed to {@link #unitVec}; only
     * their {@code dup()}s are.
     *
     * @param d1 first vector
     * @param d2 second vector
     * @return the dot product of the normalized duplicates
     */
    public static double cosineSim(INDArray d1, INDArray d2) {
        d1 = Transforms.unitVec(d1.dup());
        d2 = Transforms.unitVec(d2.dup());
        double ret = Nd4j.getBlasWrapper().dot(d1, d2);
        return ret;
    }

    /**
     * Centers and scales {@code toNormalize} along dimension 0 (per-column mean and std).
     * {@code Nd4j.EPS_THRESHOLD} is added to the standard deviations to avoid dividing by
     * zero.
     *
     * <p>Modifies {@code toNormalize} via the in-place {@code subiRowVector} /
     * {@code diviRowVector} and returns the same instance.
     *
     * @param toNormalize matrix to normalize in place
     * @return {@code toNormalize}
     */
    public static INDArray normalizeZeroMeanAndUnitVariance(INDArray toNormalize) {
        INDArray columnMeans = toNormalize.mean(0);
        INDArray columnStds = toNormalize.std(0);
        toNormalize.subiRowVector(columnMeans);
        columnStds.addi(Float.valueOf(Nd4j.EPS_THRESHOLD));
        toNormalize.diviRowVector(columnStds);
        return toNormalize;
    }

    /**
     * Scales {@code toScale} by {@code 1 / norm2} using BLAS {@code scal}. A zero-norm array
     * is returned unchanged.
     *
     * <p>NOTE(review): {@code scal} may operate in place; {@link #cosineSim} passes
     * duplicates for that reason. Confirm before relying on the argument being untouched.
     *
     * @param toScale array to normalize
     * @return the scaled array, or {@code toScale} if its norm is not positive
     */
    public static INDArray unitVec(INDArray toScale) {
        float length = ((Float)toScale.norm2(Integer.MAX_VALUE).element()).floatValue();
        if (length > 0.0f) {
            return Nd4j.getBlasWrapper().scal(1.0f / length, toScale);
        }
        return toScale;
    }

    /**
     * Applies the {@link EqualTo} op, with no extra arguments, to a copy of
     * {@code ndArray}.
     *
     * <p>NOTE(review): no comparison operand is supplied. Confirm what {@link EqualTo}
     * compares against.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray eq(INDArray ndArray) {
        return Transforms.exec(ndArray, EqualTo.class, null);
    }

    /**
     * Complex overload of {@link #eq(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray eq(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, EqualTo.class, null);
    }

    /**
     * Applies the {@link Floor} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray floor(INDArray ndArray) {
        return Transforms.exec(ndArray, Floor.class, null);
    }

    /**
     * Applies the {@link Sign} element-wise op to a copy of the complex array
     * {@code toSign}.
     *
     * <p>NOTE(review): unlike the other complex overloads this is declared to return
     * {@link INDArray}, although the value produced is an {@link IComplexNDArray}.
     *
     * @param toSign source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray sign(IComplexNDArray toSign) {
        return Transforms.exec(toSign, Sign.class, null);
    }

    /**
     * Applies the {@link Sign} element-wise op to a copy of {@code toSign}.
     *
     * @param toSign source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray sign(INDArray toSign) {
        return Transforms.exec(toSign, Sign.class, null);
    }

    /**
     * Complex overload of {@link #floor(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray floor(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Floor.class, null);
    }

    /**
     * Applies the {@link GreaterThan} op, with no extra arguments, to a copy of
     * {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray gt(INDArray ndArray) {
        return Transforms.exec(ndArray, GreaterThan.class, null);
    }

    /**
     * Complex overload of {@link #gt(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray gt(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, GreaterThan.class, null);
    }

    /**
     * Applies the {@link LessThan} op, with no extra arguments, to a copy of
     * {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray lt(INDArray ndArray) {
        return Transforms.exec(ndArray, LessThan.class, null);
    }

    /**
     * Complex overload of {@link #lt(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray lt(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, LessThan.class, null);
    }

    /**
     * Applies the {@link Stabilize} op to a copy of {@code ndArray}, passing {@code k} as
     * its single extra argument.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param k       parameter forwarded to {@link Stabilize}
     * @return the transformed duplicate
     */
    public static INDArray stabilize(INDArray ndArray, float k) {
        return Transforms.exec(ndArray, Stabilize.class, new Object[]{Float.valueOf(k)});
    }

    /**
     * Complex overload of {@link #stabilize(INDArray, float)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param k       parameter forwarded to {@link Stabilize}
     * @return the transformed duplicate
     */
    public static IComplexNDArray stabilize(IComplexNDArray ndArray, float k) {
        return Transforms.exec(ndArray, Stabilize.class, new Object[]{Float.valueOf(k)});
    }

    /**
     * Applies the {@link Abs} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray abs(INDArray ndArray) {
        return Transforms.exec(ndArray, Abs.class, null);
    }

    /**
     * Complex overload of {@link #abs(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray abs(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Abs.class, null);
    }

    /**
     * Applies the {@link Exp} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray exp(INDArray ndArray) {
        return Transforms.exec(ndArray, Exp.class, null);
    }

    /**
     * Complex overload of {@link #exp(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray exp(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Exp.class, null);
    }

    /**
     * Applies the {@link HardTanh} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray hardTanh(INDArray ndArray) {
        return Transforms.exec(ndArray, HardTanh.class, null);
    }

    /**
     * Complex overload of {@link #hardTanh(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray hardTanh(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, HardTanh.class, null);
    }

    /**
     * Applies the {@link Identity} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray identity(INDArray ndArray) {
        return Transforms.exec(ndArray, Identity.class, null);
    }

    /**
     * Complex overload of {@link #identity(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray identity(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Identity.class, null);
    }

    /**
     * Applies the {@link Max} op, with no extra arguments, to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray max(INDArray ndArray) {
        return Transforms.exec(ndArray, Max.class, null);
    }

    /**
     * Applies the {@link Max} op to a copy of {@code ndArray}, passing {@code max} as its
     * single extra argument.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param max     value forwarded to {@link Max}
     * @return the transformed duplicate
     */
    public static INDArray max(INDArray ndArray, float max) {
        return Transforms.exec(ndArray, Max.class, new Object[]{Float.valueOf(max)});
    }

    /**
     * Complex overload of {@link #max(INDArray, float)}. It forwards {@code max} the same
     * way; previously the argument was silently dropped.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param max     value forwarded to {@link Max}
     * @return the transformed duplicate
     */
    public static IComplexNDArray max(IComplexNDArray ndArray, float max) {
        return Transforms.exec(ndArray, Max.class, new Object[]{Float.valueOf(max)});
    }

    /**
     * Complex overload of {@link #max(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray max(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Max.class, null);
    }

    /**
     * Applies the {@link Pow} op to a copy of {@code ndArray}, passing {@code power} as its
     * single extra argument.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param power   exponent forwarded to {@link Pow}
     * @return the transformed duplicate
     */
    public static INDArray pow(INDArray ndArray, Number power) {
        return Transforms.exec(ndArray, Pow.class, new Object[]{power});
    }

    /**
     * Complex overload of {@link #pow(INDArray, Number)} taking a complex exponent.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @param power   exponent forwarded to {@link Pow}
     * @return the transformed duplicate
     */
    public static IComplexNDArray pow(IComplexNDArray ndArray, IComplexNumber power) {
        return Transforms.exec(ndArray, Pow.class, new Object[]{power});
    }

    /**
     * Applies the {@link Round} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray round(INDArray ndArray) {
        return Transforms.exec(ndArray, Round.class, null);
    }

    /**
     * Complex overload of {@link #round(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray round(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Round.class, null);
    }

    /**
     * Applies the {@link Sigmoid} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray sigmoid(INDArray ndArray) {
        return Transforms.exec(ndArray, Sigmoid.class, null);
    }

    /**
     * Complex overload of {@link #sigmoid(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray sigmoid(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Sigmoid.class, null);
    }

    /**
     * Applies the {@link Sqrt} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray sqrt(INDArray ndArray) {
        return Transforms.exec(ndArray, Sqrt.class, null);
    }

    /**
     * Complex overload of {@link #sqrt(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray sqrt(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Sqrt.class, null);
    }

    /**
     * Applies the {@link Tanh} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray tanh(INDArray ndArray) {
        return Transforms.exec(ndArray, Tanh.class, null);
    }

    /**
     * Complex overload of {@link #tanh(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray tanh(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Tanh.class, null);
    }

    /**
     * Applies the {@link Log} element-wise op to a copy of {@code ndArray}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static INDArray log(INDArray ndArray) {
        return Transforms.exec(ndArray, Log.class, null);
    }

    /**
     * Complex overload of {@link #log(INDArray)}.
     *
     * @param ndArray source array (a duplicate is transformed)
     * @return the transformed duplicate
     */
    public static IComplexNDArray log(IComplexNDArray ndArray) {
        return Transforms.exec(ndArray, Log.class, null);
    }

    /**
     * Builds an {@link ElementWiseOp} for {@code clazz} over a {@code dup()} of
     * {@code indArray}, executes it and returns the op's target array.
     *
     * @param indArray  source array; only its duplicate is handed to the op
     * @param clazz     element-wise op to instantiate
     * @param extraArgs op-specific arguments, or {@code null} for none
     * @return the array the op ran on (the duplicate)
     */
    private static INDArray exec(INDArray indArray, Class<? extends BaseElementWiseOp> clazz, Object[] extraArgs) {
        ElementWiseOp ops = new ArrayOps().from(indArray.dup()).op(clazz).extraArgs(extraArgs).build();
        ops.exec();
        return ops.from();
    }

    /**
     * Complex counterpart of {@link #exec(INDArray, Class, Object[])}. It casts the op's
     * target back to {@link IComplexNDArray}.
     *
     * @param indArray  source array; only its duplicate is handed to the op
     * @param clazz     element-wise op to instantiate
     * @param extraArgs op-specific arguments, or {@code null} for none
     * @return the array the op ran on (the duplicate)
     */
    private static IComplexNDArray exec(IComplexNDArray indArray, Class<? extends BaseElementWiseOp> clazz, Object[] extraArgs) {
        ElementWiseOp ops = new ArrayOps().from(indArray.dup()).op(clazz).extraArgs(extraArgs).build();
        ops.exec();
        IComplexNDArray n = (IComplexNDArray)ops.from();
        return n;
    }
}

