/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.ops.transforms;

import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.api.ops.impl.transforms.Ceil;
import org.nd4j.linalg.api.ops.impl.transforms.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.Floor;
import org.nd4j.linalg.api.ops.impl.transforms.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.Identity;
import org.nd4j.linalg.api.ops.impl.transforms.Log;
import org.nd4j.linalg.api.ops.impl.transforms.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.Round;
import org.nd4j.linalg.api.ops.impl.transforms.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.Stabilize;
import org.nd4j.linalg.api.ops.impl.transforms.Tanh;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static helpers that build ND4J ops (element-wise transforms, scalar ops,
 * simple pooling / sampling routines) and run them through the executioner.
 *
 * <p>Most one-argument methods delegate to a two-argument overload whose
 * boolean flag decides whether the op is handed a duplicate of the input.
 */
public class Transforms {
    /**
     * Max pooling over a 4-dimensional array.
     *
     * <p>The input is viewed as {@code (size(0) * size(1), 1, rows, cols)} and every
     * element {@code (k, l)} of each image is folded into the output cell
     * {@code (k / ds[0], l / ds[1])}. The returned array has the shape of that
     * reshaped view; cells that no window maps to keep their initial value of zero.
     *
     * @param input        4-dimensional array to pool
     * @param ds           two pooling factors (rows, columns)
     * @param ignoreBorder when true, only a reduced range of rows/columns is visited
     * @return the pooled values, written into an array shaped like the reshaped input
     */
    public static INDArray maxPool(INDArray input, int[] ds, boolean ignoreBorder) {
        assert (input.length() >= 2) : "Max pooling requires an ndarray of >= length 2";
        assert (ds.length == 2) : "Down sampling must be of length 2 (the factors used for each image size";
        assert (input.shape().length == 4) : "Only supports 4 dimensional tensors";
        int batchSize = input.size(0) * input.size(1);
        int rows = input.size(2);
        int cols = input.size(3);
        INDArray signalNDArray = input.reshape(batchSize, 1, rows, cols);
        INDArray zz = Nd4j.create(signalNDArray.shape());
        // NOTE(review): dividing by the squared factor looks unintended (a border-free
        // range would normally be (rows / ds[0]) * ds[0]); kept as-is pending confirmation.
        int rowIter = ignoreBorder ? (int)((double)rows / Math.pow(ds[0], 2.0)) : rows;
        int colIter = ignoreBorder ? (int)((double)cols / Math.pow(ds[1], 2.0)) : cols;
        rowIter = Math.max(1, rowIter);
        colIter = Math.max(1, colIter);
        for (int i = 0; i < signalNDArray.size(0); ++i) {
            for (int j = 0; j < signalNDArray.size(1); ++j) {
                for (int k = 0; k < rowIter; ++k) {
                    int zk = k / ds[0];
                    boolean firstRowOfWindow = k % ds[0] == 0;
                    for (int l = 0; l < colIter; ++l) {
                        int zl = l / ds[1];
                        // Read from the reshaped view: (i, j) index that view, not the
                        // original (batch, channel) layout of input.
                        double num = signalNDArray.getDouble(i, j, k, l);
                        // The first element of a window seeds the maximum; comparing it
                        // against the zero-initialised output would clamp negative
                        // windows to 0.
                        boolean firstInWindow = firstRowOfWindow && l % ds[1] == 0;
                        if (!firstInWindow) {
                            num = Math.max(num, zz.getDouble(i, j, zk, zl));
                        }
                        zz.putScalar(new int[]{i, j, zk, zl}, num);
                    }
                }
            }
        }
        return zz.reshape(signalNDArray.shape());
    }

    /**
     * Down-samples {@code d1} by convolving it (VALID mode) with an averaging kernel
     * of shape {@code stride} and then picking strided entries from the result.
     *
     * <p>If the kernel and the input differ in rank, the lower-rank one is reshaped
     * with leading singleton dimensions so both ranks match.
     *
     * @param d1     array to down-sample
     * @param stride kernel shape, also used as the per-dimension step of the final indexing
     * @return the strided selection of the convolution result
     */
    public static INDArray downSample(INDArray d1, int[] stride) {
        INDArray d = Nd4j.ones(stride);
        d.divi(ArrayUtil.prod(stride));
        int inputRank = d1.shape().length;
        if (stride.length != inputRank) {
            if (stride.length > inputRank) {
                // Left-pad the INPUT's shape with 1s. The previous code copied the
                // kernel's shape here, reshaping d1 to the kernel's dimensions.
                int[] inputShape = d1.shape();
                int[] newShape = new int[stride.length];
                Arrays.fill(newShape, 1);
                int delta = newShape.length - inputShape.length;
                for (int i = newShape.length - 1; i >= delta; --i) {
                    newShape[i] = inputShape[i - delta];
                }
                d1 = d1.reshape(newShape);
            } else {
                // Left-pad the kernel's shape with 1s up to the input's rank.
                int[] kernelShape = d.shape();
                int[] newStride = new int[inputRank];
                Arrays.fill(newStride, 1);
                int delta = Math.abs(kernelShape.length - newStride.length);
                for (int i = newStride.length - 1; i >= delta; --i) {
                    newStride[i] = kernelShape[i - delta];
                }
                d = d.reshape(newStride);
            }
        }
        INDArray ret = Convolution.convn(d1, d, Convolution.Type.VALID);
        // NOTE(review): the interval bounds use d1's sizes although they index the
        // convolution result; confirm against NDArrayIndex.interval semantics.
        INDArrayIndex[] indices = new INDArrayIndex[d1.shape().length];
        for (int i = 0; i < indices.length; ++i) {
            indices[i] = i < stride.length
                    ? NDArrayIndex.interval(0, stride[i], d1.size(i), true)
                    : NDArrayIndex.interval(0, d1.size(i), true);
        }
        return ret.get(indices);
    }

    /**
     * Block pooling where each block is reduced with {@code sum} followed by {@code mean}
     * along the last dimension.
     *
     * @param toPool array with at least 3 dimensions
     * @param stride block strides; {@code stride[0]} for rows, {@code stride[1]} for columns
     * @return a new array with the shape of {@code toPool} holding the pooled blocks
     */
    public static INDArray avgPooling(INDArray toPool, int[] stride) {
        return Transforms.poolBlocks(toPool, stride, true);
    }

    /**
     * Block pooling where each block is reduced with {@code sum} applied twice along
     * the last dimension.
     *
     * @param toPool array with at least 3 dimensions
     * @param stride block strides; {@code stride[0]} for rows, {@code stride[1]} for columns
     * @return a new array with the shape of {@code toPool} holding the pooled blocks
     */
    public static INDArray sumPooling(INDArray toPool, int[] stride) {
        return Transforms.poolBlocks(toPool, stride, false);
    }

    /**
     * Shared body of {@link #avgPooling} and {@link #sumPooling}; the two only differ
     * in the second reduction applied to each block.
     *
     * @param average true to finish each block with {@code mean}, false to finish with {@code sum}
     */
    private static INDArray poolBlocks(INDArray toPool, int[] stride, boolean average) {
        int nDims = toPool.shape().length;
        assert (nDims >= 3) : "NDArray must have 3 dimensions";
        int nRows = toPool.shape()[nDims - 2];
        int nCols = toPool.shape()[nDims - 1];
        int yStride = stride[0];
        int xStride = stride[1];
        INDArray blocks = Nd4j.create(toPool.shape());
        int lastDim = nDims - 1;
        // The block counts are integer divisions: the former Math.ceil(int / int)
        // wrapper never rounded up, so trailing partial blocks are not visited.
        for (int iR = 0; iR < nRows / yStride; ++iR) {
            // NOTE(review): start == end here, and the column interval below spans
            // a fixed +1; neither depends on the stride width. Confirm intent.
            INDArrayIndex rows = NDArrayIndex.interval(iR * yStride, iR * yStride, true);
            for (int jC = 0; jC < nCols / xStride; ++jC) {
                INDArrayIndex cols = NDArrayIndex.interval(jC * xStride, jC * xStride + 1, true);
                INDArray summed = toPool.get(rows, cols).sum(lastDim);
                INDArray blockVal = average ? summed.mean(lastDim) : summed.sum(lastDim);
                // NOTE(review): the repmat result is discarded, as in the original code.
                blocks.put(new INDArrayIndex[]{rows, cols}, blockVal.permute(1, 2, 0)).repmat(rows.length(), cols.length());
            }
        }
        return blocks;
    }

    /**
     * Up-samples {@code d} by replicating each element along every dimension.
     *
     * <p>Output position {@code p} along dimension {@code i} is filled from input
     * position {@code p / (int) scale.getDouble(i)}. The previous implementation
     * allocated the output but never wrote to it, so it always returned zeros.
     *
     * @param d     array to up-sample
     * @param scale one scale factor per dimension of {@code d}; each is truncated to an int
     * @return a new array whose shape is {@code d.shape()} multiplied element-wise by {@code scale}
     * @throws IllegalArgumentException if a truncated scale factor is smaller than 1
     */
    public static INDArray upSample(INDArray d, INDArray scale) {
        int rank = d.shape().length;
        int[] factors = new int[rank];
        for (int i = 0; i < rank; ++i) {
            factors[i] = (int)scale.getDouble(i);
            if (factors[i] < 1) {
                throw new IllegalArgumentException("Scale factor must be >= 1 but was " + factors[i] + " for dimension " + i);
            }
        }
        INDArray ret = Nd4j.create(ArrayUtil.toInts(ArrayUtil.toNDArray(d.shape()).muli(scale)));
        int[] outShape = ret.shape();
        int[] outPos = new int[outShape.length];
        int[] srcPos = new int[outShape.length];
        for (int n = 0; n < ret.length(); ++n) {
            for (int dim = 0; dim < outPos.length; ++dim) {
                // Clamp so a fractional scale (shape uses the raw value, the
                // replication factor the truncated one) cannot index past d.
                srcPos[dim] = Math.min(outPos[dim] / factors[dim], d.size(dim) - 1);
            }
            ret.putScalar(outPos, d.getDouble(srcPos));
            // Advance the multi-dimensional position, last dimension fastest.
            for (int dim = outPos.length - 1; dim >= 0; --dim) {
                if (++outPos[dim] < outShape[dim]) {
                    break;
                }
                outPos[dim] = 0;
            }
        }
        return ret;
    }

    /**
     * Executes a {@link CosineSimilarity} accumulation over {@code d1} and {@code d2}
     * (using {@code d1.length()} as the element count) and returns its result.
     */
    public static double cosineSim(INDArray d1, INDArray d2) {
        return Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(d1, d2, d1.length())).currentResult().doubleValue();
    }

    /**
     * Subtracts the dimension-0 mean from {@code toNormalize} and divides by the
     * dimension-0 standard deviation plus {@code Nd4j.EPS_THRESHOLD}, using the
     * in-place row-vector operations.
     *
     * @return the same {@code toNormalize} instance that was passed in
     */
    public static INDArray normalizeZeroMeanAndUnitVariance(INDArray toNormalize) {
        INDArray columnMeans = toNormalize.mean(0);
        INDArray columnStds = toNormalize.std(0);
        toNormalize.subiRowVector(columnMeans);
        // Offset the deviations before dividing so a zero deviation is not used as divisor.
        columnStds.addi(Nd4j.EPS_THRESHOLD);
        toNormalize.diviRowVector(columnStds);
        return toNormalize;
    }

    /**
     * Scales {@code toScale} by the reciprocal of its 2-norm via the BLAS wrapper.
     * An array whose norm is not greater than zero is returned untouched.
     */
    public static INDArray unitVec(INDArray toScale) {
        double length = toScale.norm2Number().doubleValue();
        if (length > 0.0) {
            // Pick the scal overload matching the buffer's data type.
            if (toScale.data().dataType() == DataBuffer.Type.FLOAT) {
                return Nd4j.getBlasWrapper().scal(1.0f / (float)length, toScale);
            }
            return Nd4j.getBlasWrapper().scal(1.0 / length, toScale);
        }
        return toScale;
    }

    /** Applies {@link Negative}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray neg(INDArray ndArray) {
        return Transforms.neg(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Floor}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray floor(INDArray ndArray) {
        return Transforms.floor(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Ceil}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray ceiling(INDArray ndArray) {
        return Transforms.ceiling(ndArray, Nd4j.copyOnOps);
    }

    /**
     * Applies {@link Ceil}.
     *
     * @param copyOnOps when true the op's second argument is a duplicate of {@code ndArray},
     *                  otherwise {@code ndArray} itself is passed for both arguments
     */
    public static INDArray ceiling(INDArray ndArray, boolean copyOnOps) {
        return Transforms.exec(copyOnOps ? new Ceil(ndArray, ndArray.dup()) : new Ceil(ndArray, ndArray));
    }

    /** Applies {@link Sign}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray sign(INDArray toSign) {
        return Transforms.sign(toSign, Nd4j.copyOnOps);
    }

    /** Applies {@link Stabilize} with parameter {@code k}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray stabilize(INDArray ndArray, double k) {
        return Transforms.stabilize(ndArray, k, Nd4j.copyOnOps);
    }

    /**
     * Applies {@link Abs}. Unlike the sibling one-argument methods this always
     * requests a duplicate ({@code true}) instead of consulting {@code Nd4j.copyOnOps}.
     */
    public static INDArray abs(INDArray ndArray) {
        return Transforms.abs(ndArray, true);
    }

    /** Applies {@link Exp}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray exp(INDArray ndArray) {
        return Transforms.exp(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link HardTanh}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray hardTanh(INDArray ndArray) {
        return Transforms.hardTanh(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Identity}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray identity(INDArray ndArray) {
        return Transforms.identity(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Pow} with the given power; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray pow(INDArray ndArray, Number power) {
        return Transforms.pow(ndArray, power, Nd4j.copyOnOps);
    }

    /** Applies {@link Round}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray round(INDArray ndArray) {
        return Transforms.round(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Sigmoid}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray sigmoid(INDArray ndArray) {
        return Transforms.sigmoid(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Sqrt}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray sqrt(INDArray ndArray) {
        return Transforms.sqrt(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Tanh}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray tanh(INDArray ndArray) {
        return Transforms.tanh(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Log}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray log(INDArray ndArray) {
        return Transforms.log(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Eps}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray eps(INDArray ndArray) {
        return Transforms.eps(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link GreaterThanOrEqual} to the pair; {@code Nd4j.copyOnOps} decides whether {@code first} is duplicated. */
    public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray) {
        return Transforms.greaterThanOrEqual(first, ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link LessThanOrEqual} to the pair; {@code Nd4j.copyOnOps} decides whether {@code first} is duplicated. */
    public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray) {
        return Transforms.lessThanOrEqual(first, ndArray, Nd4j.copyOnOps);
    }

    /**
     * Applies {@link LessThanOrEqual} to {@code first} and {@code ndArray}.
     *
     * @param dup when true the op is built on a duplicate of {@code first}
     */
    public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new LessThanOrEqual(first.dup(), ndArray) : new LessThanOrEqual(first, ndArray));
    }

    /**
     * Applies {@link GreaterThanOrEqual} to {@code first} and {@code ndArray}.
     *
     * @param dup when true the op is built on a duplicate of {@code first}
     */
    public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new GreaterThanOrEqual(first.dup(), ndArray) : new GreaterThanOrEqual(first, ndArray));
    }

    /**
     * Applies {@link Eps}.
     *
     * @param dup when true the op is built on a duplicate of {@code ndArray}
     */
    public static INDArray eps(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Eps(ndArray.dup()) : new Eps(ndArray));
    }

    /**
     * Applies {@link Floor}.
     *
     * @param dup when true the op is built on a duplicate of {@code ndArray}
     */
    public static INDArray floor(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Floor(ndArray.dup()) : new Floor(ndArray));
    }

    /**
     * Applies {@link Sign}.
     *
     * @param dup when true a duplicate of {@code toSign} is passed as the op's second argument
     */
    public static INDArray sign(INDArray toSign, boolean dup) {
        return Transforms.exec(dup ? new Sign(toSign, toSign.dup()) : new Sign(toSign));
    }

    /**
     * Applies {@link ScalarMax} with scalar {@code k}.
     *
     * @param dup when true the op is built on a duplicate of {@code ndArray}
     */
    public static INDArray max(INDArray ndArray, double k, boolean dup) {
        return Transforms.exec(dup ? new ScalarMax(ndArray.dup(), k) : new ScalarMax(ndArray, k));
    }

    /** Applies {@link ScalarMax} with scalar {@code k}; {@code Nd4j.copyOnOps} decides whether a duplicate is used. */
    public static INDArray max(INDArray ndArray, double k) {
        return Transforms.max(ndArray, k, Nd4j.copyOnOps);
    }

    /**
     * Applies {@link Stabilize} with parameter {@code k}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray stabilize(INDArray ndArray, double k, boolean dup) {
        return Transforms.exec(dup ? new Stabilize(ndArray, ndArray.dup(), k) : new Stabilize(ndArray, k));
    }

    /**
     * Applies {@link Abs}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray abs(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Abs(ndArray, ndArray.dup()) : new Abs(ndArray));
    }

    /**
     * Applies {@link Exp}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray exp(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Exp(ndArray, ndArray.dup()) : new Exp(ndArray));
    }

    /**
     * Applies {@link HardTanh}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray hardTanh(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new HardTanh(ndArray, ndArray.dup()) : new HardTanh(ndArray));
    }

    /**
     * Applies {@link Identity}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray identity(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Identity(ndArray, ndArray.dup()) : new Identity(ndArray));
    }

    /**
     * Applies {@link Pow} using {@code power.doubleValue()}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray pow(INDArray ndArray, Number power, boolean dup) {
        return Transforms.exec(dup ? new Pow(ndArray, ndArray.dup(), power.doubleValue()) : new Pow(ndArray, power.doubleValue()));
    }

    /**
     * Applies {@link Round}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray round(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Round(ndArray, ndArray.dup()) : new Round(ndArray));
    }

    /**
     * Applies {@link Sigmoid}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray sigmoid(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Sigmoid(ndArray, ndArray.dup()) : new Sigmoid(ndArray));
    }

    /**
     * Applies {@link Sqrt}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray sqrt(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Sqrt(ndArray, ndArray.dup()) : new Sqrt(ndArray));
    }

    /**
     * Applies {@link Tanh}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray tanh(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Tanh(ndArray, ndArray.dup()) : new Tanh(ndArray));
    }

    /**
     * Applies {@link Log}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray log(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Log(ndArray, ndArray.dup()) : new Log(ndArray));
    }

    /**
     * Applies {@link Negative}.
     *
     * @param dup when true a duplicate of {@code ndArray} is passed as the op's second argument
     */
    public static INDArray neg(INDArray ndArray, boolean dup) {
        return Transforms.exec(dup ? new Negative(ndArray, ndArray.dup()) : new Negative(ndArray));
    }

    /**
     * Runs a scalar op through the executioner and returns the op's {@code z()} array.
     *
     * @throws IllegalStateException if the op's {@code x()} array reports it was cleaned up
     */
    private static INDArray exec(ScalarOp op) {
        if (op.x().isCleanedUp()) {
            throw new IllegalStateException("NDArray already freed");
        }
        return Nd4j.getExecutioner().exec(op).z();
    }

    /**
     * Runs a transform op through the executioner's {@code execAndReturn}.
     *
     * @throws IllegalStateException if the op's {@code x()} array reports it was cleaned up
     */
    private static INDArray exec(TransformOp op) {
        if (op.x().isCleanedUp()) {
            throw new IllegalStateException("NDArray already freed");
        }
        return Nd4j.getExecutioner().execAndReturn(op);
    }
}

