/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.Arrays;
import lombok.NonNull;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

public class ConvolutionUtils {
    /**
     * User-facing note explaining that convolution layers accept either NCHW (channels first) or
     * NHWC (channels last) data, and how to configure the format on layers, networks and image loaders.
     */
    public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first) or NHWC (channels last) format for input images and activations.\n"
            + "Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using .setInputType(InputType.convolutional(height, width, depth, CNN2DFormat.NCHW/NHWC)).\n"
            + "ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";
    /** Dilation of 1 in both spatial dimensions, i.e. "no dilation"; used by the legacy overloads. */
    private static final int[] ONES = new int[]{1, 1};

    /** Private constructor: prevents instantiation from outside this class. */
    private ConvolutionUtils() {
    }

    /**
     * Legacy overload: computes the 2D convolution output size without dilation.
     * Delegates to the dilation-aware overload, passing {@code {1, 1}} as the dilation.
     *
     * @param inputData       input activations
     * @param kernel          kernel size
     * @param strides         strides
     * @param padding         padding
     * @param convolutionMode convolution mode
     * @return output size as computed by the overload this delegates to
     * @deprecated see the overload that also takes the dilation and the {@link CNN2DFormat}
     */
    @Deprecated
    public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode) {
        return ConvolutionUtils.getOutputSize(inputData, kernel, strides, padding, convolutionMode, ONES);
    }

    /**
     * Computes the {@code [height, width]} output size of a 2D deconvolution (transposed convolution).
     * <p>
     * In {@link ConvolutionMode#Same} mode the output is {@code stride * input} per dimension; otherwise it is
     * {@code stride * (input - 1) + effectiveKernel - 2 * padding}.
     *
     * @param inputData       input activations; height/width are read from axes 2/3 for NCHW, 1/2 otherwise
     * @param kernel          kernel size (length 2 or 3, as required by {@code effectiveKernelSize})
     * @param strides         strides per dimension
     * @param padding         padding per dimension (not read in Same mode)
     * @param convolutionMode convolution mode
     * @param dilation        dilation per dimension
     * @param format          data format of {@code inputData}
     * @return {@code {outHeight, outWidth}}
     * @throws ND4JArraySizeException if the input height or width does not fit in an {@code int}
     */
    public static int[] getDeconvolutionOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) {
        final boolean channelsFirst = format == CNN2DFormat.NCHW;
        final int heightAxis = channelsFirst ? 2 : 1;
        final int widthAxis = channelsFirst ? 3 : 2;

        if (inputData.size(heightAxis) > Integer.MAX_VALUE || inputData.size(widthAxis) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        final int inHeight = (int)inputData.size(heightAxis);
        final int inWidth = (int)inputData.size(widthAxis);

        // Computed up front (even for Same mode) so that an invalid kernel length is always rejected.
        final int[] effKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);

        final boolean sameMode = convolutionMode == ConvolutionMode.Same;
        final int outHeight = sameMode
                ? strides[0] * inHeight
                : strides[0] * (inHeight - 1) + effKernel[0] - 2 * padding[0];
        final int outWidth = sameMode
                ? strides[1] * inWidth
                : strides[1] * (inWidth - 1) + effKernel[1] - 2 * padding[1];
        return new int[]{outHeight, outWidth};
    }

    /**
     * Computes the output size of a 3D deconvolution for the three non-batch, non-channel axes.
     * <p>
     * The three input sizes are read from axes 2, 3, 4 when the format is NCDHW and from axes 1, 2, 3
     * otherwise; the result is returned in that same axis order. In {@link ConvolutionMode#Same} mode each
     * output is {@code stride * input}; otherwise {@code stride * (input - 1) + effectiveKernel - 2 * padding}.
     *
     * @param inputData       rank-5 input activations
     * @param kernel          kernel size (length 3 expected, since index 2 is read)
     * @param strides         strides per dimension
     * @param padding         padding per dimension (not read in Same mode)
     * @param dilation        dilation per dimension
     * @param convolutionMode convolution mode
     * @param dataFormat      layout of {@code inputData}
     * @return the three output sizes, in the order the input axes were read
     */
    public static long[] getDeconvolution3DOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, int[] dilation, ConvolutionMode convolutionMode, Convolution3D.DataFormat dataFormat) {
        // First of the three consecutive axes to read: 2 for channels-first, 1 for channels-last.
        final int firstAxis = dataFormat == Convolution3D.DataFormat.NCDHW ? 2 : 1;
        final long in0 = inputData.size(firstAxis);
        final long in1 = inputData.size(firstAxis + 1);
        final long in2 = inputData.size(firstAxis + 2);

        final int[] effKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);

        if (convolutionMode == ConvolutionMode.Same) {
            return new long[]{
                    (long)strides[0] * in0,
                    (long)strides[1] * in1,
                    (long)strides[2] * in2};
        }

        final long out0 = (long)strides[0] * (in0 - 1L) + (long)effKernel[0] - (long)(2 * padding[0]);
        final long out1 = (long)strides[1] * (in1 - 1L) + (long)effKernel[1] - (long)(2 * padding[1]);
        final long out2 = (long)strides[2] * (in2 - 1L) + (long)effKernel[2] - (long)(2 * padding[2]);
        return new long[]{out0, out1, out2};
    }

    /**
     * Legacy overload: computes the 2D convolution output size assuming NCHW input.
     * Delegates to the format-aware overload with {@link CNN2DFormat#NCHW}.
     *
     * @param inputData       input activations
     * @param kernel          kernel size
     * @param strides         strides
     * @param padding         padding
     * @param convolutionMode convolution mode
     * @param dilation        dilation
     * @return output size as computed by the overload this delegates to
     * @deprecated see the overload that takes an explicit {@link CNN2DFormat}
     */
    @Deprecated
    public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation) {
        return ConvolutionUtils.getOutputSize(inputData, kernel, strides, padding, convolutionMode, dilation, CNN2DFormat.NCHW);
    }

    /**
     * Tells whether the given layer is an instance of one of the layer types this utility treats as
     * having a 2D convolutional data layout.
     *
     * @param layer layer configuration to test; {@code null} yields {@code false}
     * @return {@code true} if {@code layer} is an instance of any of the listed layer classes
     */
    public static boolean layerHasConvolutionLayout(Layer layer) {
        final Class<?>[] convolutionLayoutTypes = new Class<?>[]{
                ConvolutionLayer.class,
                SubsamplingLayer.class,
                SpaceToBatchLayer.class,
                Upsampling2D.class,
                SpaceToDepthLayer.class,
                ZeroPaddingLayer.class,
                SeparableConvolution2D.class,
                Deconvolution2D.class,
                Cropping2D.class,
                DepthwiseConvolution2D.class
        };
        for (Class<?> candidate : convolutionLayoutTypes) {
            if (candidate.isInstance(layer)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the {@link CNN2DFormat} configured on the given layer.
     * <p>
     * The layer type is tested in a fixed order and the first match wins; the order is kept as-is
     * because some of these types may be subclasses of others.
     *
     * @param layer layer configuration to inspect; must not be {@code null}
     * @return the data format reported by the layer's own getter
     * @throws IllegalArgumentException if the layer is not one of the supported types
     */
    public static CNN2DFormat getFormatForLayer(Layer layer) {
        if (layer instanceof Convolution1DLayer) {
            return ((Convolution1DLayer)layer).getCnn2dDataFormat();
        }
        if (layer instanceof ConvolutionLayer) {
            return ((ConvolutionLayer)layer).getCnn2dDataFormat();
        }
        if (layer instanceof SubsamplingLayer) {
            return ((SubsamplingLayer)layer).getCnn2dDataFormat();
        }
        if (layer instanceof SpaceToBatchLayer) {
            return ((SpaceToBatchLayer)layer).getFormat();
        }
        if (layer instanceof Upsampling2D) {
            return ((Upsampling2D)layer).getFormat();
        }
        if (layer instanceof SpaceToDepthLayer) {
            return ((SpaceToDepthLayer)layer).getDataFormat();
        }
        if (layer instanceof ZeroPaddingLayer) {
            return ((ZeroPaddingLayer)layer).getDataFormat();
        }
        if (layer instanceof SeparableConvolution2D) {
            return ((SeparableConvolution2D)layer).getCnn2dDataFormat();
        }
        if (layer instanceof Deconvolution2D) {
            return ((Deconvolution2D)layer).getCnn2dDataFormat();
        }
        if (layer instanceof DepthwiseConvolution2D) {
            return ((DepthwiseConvolution2D)layer).getCnn2dDataFormat();
        }
        if (layer instanceof Cropping2D) {
            return ((Cropping2D)layer).getDataFormat();
        }
        throw new IllegalArgumentException("Illegal type given " + layer.getClass().getName());
    }

    /**
     * Maps a DL4J {@link ConvolutionMode} to the corresponding ND4J {@link PaddingMode}.
     * {@code Same} maps to {@code SAME}, {@code Causal} to {@code CAUSAL}, and both {@code Strict} and
     * {@code Truncate} map to {@code VALID}.
     *
     * @param convolutionMode mode to translate; must not be {@code null}
     * @return the matching padding mode
     * @throws IllegalArgumentException for any mode not listed above
     */
    public static PaddingMode paddingModeForConvolutionMode(ConvolutionMode convolutionMode) {
        switch (convolutionMode) {
            case Same:
                return PaddingMode.SAME;
            case Causal:
                return PaddingMode.CAUSAL;
            case Strict:
            case Truncate:
                return PaddingMode.VALID;
            default:
                throw new IllegalArgumentException("Invalid input convolution mode: " + convolutionMode);
        }
    }

    /**
     * Computes the {@code [height, width]} output size of a 2D convolution and validates the configuration.
     * <p>
     * For {@link ConvolutionMode#Same} and {@link ConvolutionMode#Causal} the output is
     * {@code ceil(input / stride)}; otherwise it is {@code (input - effectiveKernel + 2 * padding) / stride + 1}
     * using integer division.
     *
     * @param inputData       input activations; height/width are read from axes 1/2 for NHWC, 2/3 otherwise
     * @param kernel          kernel size
     * @param strides         strides per dimension
     * @param padding         padding per dimension (not read for Same/Causal)
     * @param convolutionMode convolution mode
     * @param dilation        dilation per dimension
     * @param format          data format of {@code inputData}
     * @return {@code {outHeight, outWidth}}
     * @throws ND4JArraySizeException if the input height or width does not fit in an {@code int}
     */
    public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) {
        int hDim = 2;
        int wDim = 3;
        if (format == CNN2DFormat.NHWC) {
            hDim = 1;
            wDim = 2;
        }
        if (inputData.size(hDim) > Integer.MAX_VALUE || inputData.size(wDim) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int inH = (int)inputData.size(hDim);
        int inW = (int)inputData.size(wDim);
        int[] eKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);
        // effectiveKernelSize returns the very same array when no dilation applies, so a different
        // reference means the kernel was dilated (atrous). The previous "==" test was inverted and made
        // validation messages say "effective kernel" only when there was NO dilation.
        boolean atrous = eKernel != kernel;
        int[] inShape = new int[]{inH, inW};
        ConvolutionUtils.validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);
        if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
            int outH = (int)Math.ceil((double)inH / (double)strides[0]);
            int outW = (int)Math.ceil((double)inW / (double)strides[1]);
            return new int[]{outH, outW};
        }
        int hOut = (inH - eKernel[0] + 2 * padding[0]) / strides[0] + 1;
        int wOut = (inW - eKernel[1] + 2 * padding[1]) / strides[1] + 1;
        return new int[]{hOut, wOut};
    }

    /**
     * Validates that kernel, stride and padding are compatible with the input size for the given
     * convolution mode, throwing a descriptive exception when they are not.
     * <p>
     * For {@link ConvolutionMode#Truncate}, every effective kernel dimension must satisfy
     * {@code 0 < kernel <= input + 2 * padding}. For {@link ConvolutionMode#Strict},
     * {@code (input - kernel + 2 * padding)} must be exactly divisible by the stride in every dimension.
     * No checks are made for other modes, and {@code padding} is not read in that case.
     * A third dimension is checked only when {@code eKernel.length == 3}.
     *
     * @param inputData       used only to report the input shape in error messages
     * @param eKernel         effective (dilation-adjusted) kernel size, length 2 or 3
     * @param strides         strides per dimension
     * @param padding         padding per dimension
     * @param convolutionMode convolution mode that decides which checks run
     * @param dilation        dilation per dimension (used for error messages only)
     * @param inShape         input size per dimension: {@code {height, width}} plus a third entry for 3D
     * @param atrous          when {@code true}, Truncate-mode messages refer to the "effective" kernel
     * @throws DL4JInvalidInputException  if a Truncate-mode constraint is violated
     * @throws DL4JInvalidConfigException if a Strict-mode constraint is violated
     */
    public static void validateShapes(INDArray inputData, int[] eKernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation, int[] inShape, boolean atrous) {
        int inH = inShape[0];
        int inW = inShape[1];
        boolean truncate = convolutionMode == ConvolutionMode.Truncate;
        // Prefix inserted before "kernel ..." in Truncate-mode messages when dilation is in effect.
        String eff = atrous ? "effective " : "";

        if (truncate && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) {
            StringBuilder sb = new StringBuilder();
            sb.append("Invalid input data or configuration: ").append(eff)
                    .append("kernel height and input height must satisfy 0 < ").append(eff)
                    .append("kernel height <= input height + 2 * padding height. \nGot ").append(eff)
                    .append("kernel height = ").append(eKernel[0])
                    .append(", input height = ").append(inH)
                    .append(" and padding height = ").append(padding[0])
                    .append(" which do not satisfy 0 < ").append(eKernel[0])
                    .append(" <= ").append(inH + 2 * padding[0])
                    .append(ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation));
            throw new DL4JInvalidInputException(sb.toString());
        }
        if (truncate && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) {
            StringBuilder sb = new StringBuilder();
            // The input shape is reported once, by getCommonErrorMsg (it used to be printed twice here).
            sb.append("Invalid input data or configuration: ").append(eff)
                    .append("kernel width and input width must satisfy  0 < kernel width <= input width + 2 * padding width. ")
                    .append("\nGot ").append(eff)
                    .append("kernel width = ").append(eKernel[1])
                    .append(", input width = ").append(inW)
                    .append(" and padding width = ").append(padding[1])
                    .append(" which do not satisfy 0 < ").append(eKernel[1])
                    .append(" <= ").append(inW + 2 * padding[1])
                    .append(ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation));
            throw new DL4JInvalidInputException(sb.toString());
        }
        if (eKernel.length == 3 && truncate && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) {
            int inD = inShape[2];
            StringBuilder sb = new StringBuilder();
            sb.append("Invalid input data or configuration: ").append(eff)
                    .append("kernel channels and input channels must satisfy 0 < ").append(eff)
                    .append("kernel channels <= input channels + 2 * padding channels. \nGot ").append(eff)
                    .append("kernel channels = ").append(eKernel[2])
                    .append(", input channels = ").append(inD)
                    .append(" and padding channels = ").append(padding[2])
                    .append(" which do not satisfy 0 < ").append(eKernel[2])
                    .append(" <= ").append(inD + 2 * padding[2])
                    .append(ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation));
            throw new DL4JInvalidInputException(sb.toString());
        }

        if (convolutionMode == ConvolutionMode.Strict) {
            if ((inH - eKernel[0] + 2 * padding[0]) % strides[0] != 0) {
                throw new DL4JInvalidConfigException(ConvolutionUtils.strictModeErrorMsg("height", "input height", inH, eKernel[0], padding[0], strides[0],
                        ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)));
            }
            if ((inW - eKernel[1] + 2 * padding[1]) % strides[1] != 0) {
                throw new DL4JInvalidConfigException(ConvolutionUtils.strictModeErrorMsg("width", "input", inW, eKernel[1], padding[1], strides[1],
                        ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)));
            }
            if (eKernel.length == 3 && (inShape[2] - eKernel[2] + 2 * padding[2]) % strides[2] != 0) {
                // Report the third dimension's own size and stride (previously strides[1] and the width were printed).
                throw new DL4JInvalidConfigException(ConvolutionUtils.strictModeErrorMsg("channels", "input", inShape[2], eKernel[2], padding[2], strides[2],
                        ConvolutionUtils.getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)));
            }
        }
    }

    /**
     * Builds the Strict-mode error message for one dimension whose
     * {@code (input - kernel + 2 * padding)} is not divisible by the stride.
     *
     * @param dim        dimension name used in the message, e.g. {@code "height"}
     * @param inputTerm  wording used for the input inside the quoted formula
     * @param inSize     input size along the dimension
     * @param kernelSize effective kernel size along the dimension
     * @param pad        padding along the dimension
     * @param stride     stride along the dimension
     * @param commonMsg  trailing text describing the overall configuration
     * @return the complete error message
     */
    private static String strictModeErrorMsg(String dim, String inputTerm, int inSize, int kernelSize, int pad, int stride, String commonMsg) {
        double exact = (double)(inSize - kernelSize + 2 * pad) / (double)stride + 1.0;
        String exactStr = String.format("%.2f", exact);
        int truncated = (int)exact;
        int sameSize = (int)Math.ceil((double)inSize / (double)stride);
        return "Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input " + dim + ", using ConvolutionMode.Strict\n"
                + "ConvolutionMode.Strict requires: output " + dim + " = (" + inputTerm + " - kernelSize + 2*padding)/stride + 1 to be an integer. Got: ("
                + inSize + " - " + kernelSize + " + 2*" + pad + ")/" + stride + " + 1 = " + exactStr + "\n"
                + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n"
                + "To truncate/crop the input, such that output " + dim + " = floor(" + exactStr + ") = " + truncated + ", use ConvolutionType.Truncate.\n"
                + "Alternatively use ConvolutionType.Same, which will use padding to give an output " + dim + " of ceil(" + inSize + "/" + stride + ")=" + sameSize
                + commonMsg;
    }

    /**
     * Computes the effective kernel size once dilation is taken into account:
     * {@code kernel + (kernel - 1) * (dilation - 1)} per dimension.
     * <p>
     * When every relevant dilation value is 1 the <em>same</em> {@code kernel} array instance is returned
     * (no copy), so callers can use reference equality to detect "no dilation".
     *
     * @param kernel   kernel size, length 2 or 3
     * @param dilation dilation per dimension; only the first {@code kernel.length} entries are read
     * @return {@code kernel} itself if not dilated, otherwise a new array of effective sizes
     * @throws IllegalArgumentException if {@code kernel.length} is neither 2 nor 3
     */
    public static int[] effectiveKernelSize(int[] kernel, int[] dilation) {
        final int dims = kernel.length;
        if (dims != 2 && dims != 3) {
            throw new IllegalArgumentException("Kernel size has to be either two or three, got: " + kernel.length);
        }

        boolean undilated = true;
        for (int i = 0; i < dims && undilated; i++) {
            undilated = dilation[i] == 1;
        }
        if (undilated) {
            return kernel;
        }

        final int[] effective = new int[dims];
        for (int i = 0; i < dims; i++) {
            effective[i] = kernel[i] + (kernel[i] - 1) * (dilation[i] - 1);
        }
        return effective;
    }

    /**
     * Builds the trailing part shared by the validation error messages: input shape, kernel, optional
     * effective kernel, strides, padding and dilation.
     * <p>
     * The effective kernel is reported whenever any dilation entry differs from 1, so a 3-D kernel
     * dilated only in its third dimension is covered too (previously only the first two entries were checked).
     *
     * @param inputData input whose shape is printed
     * @param kernel    kernel size as supplied by the caller
     * @param strides   strides
     * @param padding   padding
     * @param dilation  dilation
     * @return message fragment starting with a newline
     */
    private static String getCommonErrorMsg(INDArray inputData, int[] kernel, int[] strides, int[] padding, int[] dilation) {
        StringBuilder msg = new StringBuilder();
        msg.append("\nInput size: [numExamples,inputDepth,inputHeight,inputWidth]=").append(Arrays.toString(inputData.shape()))
                .append(", inputKernel=").append(Arrays.toString(kernel));

        boolean dilated = false;
        for (int d : dilation) {
            if (d != 1) {
                dilated = true;
                break;
            }
        }
        if (dilated) {
            int[] effectiveKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);
            msg.append(", effectiveKernelGivenDilation=").append(Arrays.toString(effectiveKernel));
        }

        return msg.append(", strides=").append(Arrays.toString(strides))
                .append(", padding=").append(Arrays.toString(padding))
                .append(", dilation=").append(Arrays.toString(dilation))
                .toString();
    }

    /**
     * Computes the top/left padding for "same" mode, per dimension:
     * {@code ((out - 1) * stride + effectiveKernel - in) / 2} (integer division).
     *
     * @param outSize  output size per dimension
     * @param inSize   input size per dimension
     * @param kernel   kernel size (length 2 or 3)
     * @param strides  strides per dimension
     * @param dilation dilation per dimension
     * @return padding per dimension, same length as {@code kernel}
     * @throws IllegalStateException presumably raised by {@code Preconditions.checkState} when any
     *                               computed padding is negative (TODO confirm exception type)
     */
    public static int[] getSameModeTopLeftPadding(int[] outSize, int[] inSize, int[] kernel, int[] strides, int[] dilation) {
        final int[] effKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);
        final int dims = kernel.length;
        final int[] topLeft = new int[dims];
        boolean nonNegative = true;
        for (int d = 0; d < dims; d++) {
            final int totalPad = (outSize[d] - 1) * strides[d] + effKernel[d] - inSize[d];
            topLeft[d] = totalPad / 2;
            if (topLeft[d] < 0) {
                nonNegative = false;
            }
        }
        Preconditions.checkState(nonNegative, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", (Object)topLeft, (Object)inSize, (Object)outSize, (Object)kernel, (Object)strides, (Object)dilation);
        return topLeft;
    }

    /**
     * Computes the bottom/right padding for "same" mode for the first two dimensions:
     * {@code ((out - 1) * stride + effectiveKernel - in + 1) / 2} (integer division), i.e. the total
     * padding halved and rounded up.
     *
     * @param outSize  output size per dimension
     * @param inSize   input size per dimension
     * @param kernel   kernel size (length 2 or 3; only the first two dimensions are used)
     * @param strides  strides per dimension
     * @param dilation dilation per dimension
     * @return a length-2 array {@code {bottom, right}}
     * @throws IllegalStateException presumably raised by {@code Preconditions.checkState} when a
     *                               computed padding is negative (TODO confirm exception type)
     */
    public static int[] getSameModeBottomRightPadding(int[] outSize, int[] inSize, int[] kernel, int[] strides, int[] dilation) {
        final int[] effKernel = ConvolutionUtils.effectiveKernelSize(kernel, dilation);
        final int bottom = ((outSize[0] - 1) * strides[0] + effKernel[0] - inSize[0] + 1) / 2;
        final int right = ((outSize[1] - 1) * strides[1] + effKernel[1] - inSize[1] + 1) / 2;
        final int[] bottomRight = new int[]{bottom, right};
        final boolean nonNegative = bottom >= 0 && right >= 0;
        Preconditions.checkState(nonNegative, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", (Object)bottomRight, (Object)inSize, (Object)outSize, (Object)kernel, (Object)strides, (Object)dilation);
        return bottomRight;
    }

    /**
     * Applies {@code getHeightAndWidth(int[])} to the kernel size of the configuration's layer.
     *
     * @param conf configuration whose layer is cast to {@link ConvolutionLayer}
     * @return the last two kernel-size entries, in the order produced by {@code getHeightAndWidth(int[])}
     * @throws ClassCastException if the configured layer is not a {@link ConvolutionLayer}
     */
    public static int[] getHeightAndWidth(NeuralNetConfiguration conf) {
        return ConvolutionUtils.getHeightAndWidth(((ConvolutionLayer)conf.getLayer()).getKernelSize());
    }

    /**
     * Returns the {@code nOut} value of the configuration's convolution layer.
     *
     * @param conf configuration whose layer is cast to {@link ConvolutionLayer}
     * @return the layer's {@code getNOut()} value
     * @throws ClassCastException if the configured layer is not a {@link ConvolutionLayer}
     */
    public static long numFeatureMap(NeuralNetConfiguration conf) {
        return ((ConvolutionLayer)conf.getLayer()).getNOut();
    }

    /**
     * Extracts the last two entries of a shape array.
     * <p>
     * Note the order of the result: index 0 holds the <em>last</em> entry of {@code shape} and index 1 the
     * second-to-last one.
     *
     * @param shape shape array with at least two entries
     * @return {@code {shape[n - 1], shape[n - 2]}} where {@code n = shape.length}
     * @throws IllegalArgumentException if {@code shape} has fewer than two entries
     */
    public static int[] getHeightAndWidth(int[] shape) {
        final int n = shape.length;
        if (n >= 2) {
            final int last = shape[n - 1];
            final int secondToLast = shape[n - 2];
            return new int[]{last, secondToLast};
        }
        throw new IllegalArgumentException("No width and height able to be found: array must be at least length 2");
    }

    /**
     * Returns the channel count for a shape array: the entry at index 1 when the shape has at least
     * four entries, otherwise 1.
     *
     * @param shape shape array
     * @return {@code shape[1]} if {@code shape.length >= 4}, else {@code 1}
     */
    public static int numChannels(int[] shape) {
        return shape.length >= 4 ? shape[1] : 1;
    }

    /**
     * Rejects explicit padding when the convolution mode is {@link ConvolutionMode#Same}.
     * For any other mode this method does nothing and {@code padding} is not read.
     *
     * @param mode    convolution mode
     * @param padding padding values to check
     * @throws IllegalArgumentException if {@code mode} is Same and any padding value is non-zero
     */
    public static void validateConvolutionModePadding(ConvolutionMode mode, int[] padding) {
        if (mode != ConvolutionMode.Same) {
            return;
        }
        for (int value : padding) {
            if (value != 0) {
                throw new IllegalArgumentException("Padding cannot be used when using the `same' convolution mode");
            }
        }
    }

    /**
     * Validates the kernel, stride and padding arrays of a 2D CNN layer.
     * <p>
     * Each array must be non-null with exactly two entries; kernel and stride values must be strictly
     * positive and padding values must be non-negative. All three lengths are checked before any values.
     *
     * @param kernelSize kernel size
     * @param stride     stride
     * @param padding    padding
     * @throws IllegalStateException on the first violated constraint
     */
    public static void validateCnnKernelStridePadding(int[] kernelSize, int[] stride, int[] padding) {
        // Arrays.toString yields the text "null" for a null array, matching plain string concatenation of null.
        final boolean kernelShapeOk = kernelSize != null && kernelSize.length == 2;
        if (!kernelShapeOk) {
            throw new IllegalStateException("Invalid kernel size: expected int[] of length 2, got " + Arrays.toString(kernelSize));
        }
        final boolean strideShapeOk = stride != null && stride.length == 2;
        if (!strideShapeOk) {
            throw new IllegalStateException("Invalid stride configuration: expected int[] of length 2, got " + Arrays.toString(stride));
        }
        final boolean paddingShapeOk = padding != null && padding.length == 2;
        if (!paddingShapeOk) {
            throw new IllegalStateException("Invalid padding configuration: expected int[] of length 2, got " + Arrays.toString(padding));
        }

        if (!(kernelSize[0] > 0 && kernelSize[1] > 0)) {
            throw new IllegalStateException("Invalid kernel size: values must be positive (> 0) for all dimensions. Got: " + Arrays.toString(kernelSize));
        }
        if (!(stride[0] > 0 && stride[1] > 0)) {
            throw new IllegalStateException("Invalid stride configuration: values must be positive (> 0) for all dimensions. Got: " + Arrays.toString(stride));
        }
        if (!(padding[0] >= 0 && padding[1] >= 0)) {
            throw new IllegalStateException("Invalid padding configuration: values must be >= 0 for all dimensions. Got: " + Arrays.toString(padding));
        }
    }

    /**
     * Reshapes a rank-4 array to rank 2, assuming NCHW layout.
     * Delegates to the format-aware overload with {@link CNN2DFormat#NCHW}.
     *
     * @param in           rank-4 input
     * @param workspaceMgr workspace manager used for any copy and for the returned array
     * @param type         array type passed to the workspace manager
     * @return the rank-2 array produced by the format-aware overload
     */
    public static INDArray reshape4dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        return ConvolutionUtils.reshape4dTo2d(in, CNN2DFormat.NCHW, workspaceMgr, type);
    }

    /**
     * Reshapes a rank-4 array to rank 2 with the channel axis as the second dimension.
     * <p>
     * For NCHW input {@code [n, c, h, w]} the array is first permuted to {@code [n, h, w, c]}; any other
     * format is taken as already channels-last. The array is duplicated in 'c' order when it is not
     * already in a suitable layout, then reshaped to {@code [product of the other three axes, channels]}.
     *
     * @param in           rank-4 input
     * @param format       layout of {@code in}
     * @param workspaceMgr workspace manager used for any copy and for the returned array
     * @param type         array type passed to the workspace manager
     * @return rank-2 array, passed through {@code workspaceMgr.leverageTo}
     * @throws IllegalArgumentException if {@code in} is not rank 4
     */
    public static INDArray reshape4dTo2d(INDArray in, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (in.rank() != 4) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 4, got rank " + in.rank() + " with shape " + Arrays.toString(in.shape()));
        }
        // Shape of the original (un-permuted) input; the target dimensions below are derived from it.
        final long[] dims = in.shape();
        final boolean channelsFirst = format == CNN2DFormat.NCHW;

        INDArray channelsLast = channelsFirst ? in.permute(new int[]{0, 2, 3, 1}) : in;
        final boolean needsCopy = channelsLast.ordering() != 'c' || !Shape.strideDescendingCAscendingF(channelsLast);
        if (needsCopy) {
            channelsLast = workspaceMgr.dup(type, channelsLast, 'c');
        }

        final long rows = channelsFirst ? dims[0] * dims[2] * dims[3] : dims[0] * dims[1] * dims[2];
        final long cols = channelsFirst ? dims[1] : dims[3];
        return workspaceMgr.leverageTo(type, channelsLast.reshape('c', new long[]{rows, cols}));
    }

    /**
     * Reshapes a rank-5 array to rank 2 with the channel axis as the second dimension.
     * <p>
     * Unless the format is NDHWC, the input is permuted with {@code (0, 2, 3, 4, 1)} so that channels come
     * last. It is duplicated in 'c' order when it does not already have default 'c' strides, then reshaped
     * to {@code [product of the first four axes, last axis]}.
     *
     * @param format       layout of {@code in}; must not be {@code null}
     * @param in           rank-5 input
     * @param workspaceMgr workspace manager used for any copy and for the returned array
     * @param type         array type passed to the workspace manager
     * @return rank-2 array, passed through {@code workspaceMgr.leverageTo}
     * @throws NullPointerException if {@code format} is {@code null}
     */
    public static INDArray reshape5dTo2d(@NonNull Convolution3D.DataFormat format, INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (format == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        Preconditions.checkState(in.rank() == 5, "Invalid input: expect NDArray with rank 5, got rank %ndRank with shape %ndShape", (Object)in, (Object)in);

        INDArray channelsLast = in;
        if (format != Convolution3D.DataFormat.NDHWC) {
            channelsLast = channelsLast.permute(new int[]{0, 2, 3, 4, 1});
        }
        final boolean needsCopy = channelsLast.ordering() != 'c' || !Shape.hasDefaultStridesForShape(channelsLast);
        if (needsCopy) {
            channelsLast = workspaceMgr.dup(type, channelsLast, 'c');
        }

        final long rows = channelsLast.size(0) * channelsLast.size(1) * channelsLast.size(2) * channelsLast.size(3);
        final long cols = channelsLast.size(4);
        return workspaceMgr.leverageTo(type, channelsLast.reshape('c', new long[]{rows, cols}));
    }

    /**
     * Converts a rank-5 mask for a 3D CNN loss layer into rank-2 form via {@code reshape5dTo2d}.
     * <p>
     * The mask is reshaped directly when it has the same shape as the labels, or when it matches the
     * labels on every axis except the channel axis (axis 4 for NDHWC, axis 1 for NCDHW). Otherwise it is
     * first broadcast into an array shaped like the labels (keeping the mask's own channel size) and the
     * broadcast result is reshaped.
     *
     * @param format       layout of mask and labels; must not be {@code null}
     * @param mask         rank-5 mask, or {@code null}
     * @param label        labels array the mask must be broadcastable to
     * @param workspaceMgr workspace manager used for allocations
     * @param type         array type passed to the workspace manager
     * @return rank-2 mask, or {@code null} if {@code mask} is {@code null}
     * @throws NullPointerException if {@code format} is {@code null}
     */
    public static INDArray reshapeCnn3dMask(@NonNull Convolution3D.DataFormat format, INDArray mask, INDArray label, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (format == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        if (mask == null) {
            return null;
        }
        Preconditions.checkState(mask.rank() == 5, "Expected rank 5 mask for Cnn3DLossLayer in a shape broadcastable to labels shape: got mask shape %ndShape with label shape %ndShape", (Object)mask, (Object)label);

        // Fix: the channels-first case must test NCDHW. It previously tested NDHWC a second time, so an
        // NDHWC mask that differed from the labels on axis 1 was reshaped without being broadcast.
        if (mask.equalShapes(label)
                || format == Convolution3D.DataFormat.NDHWC && mask.size(0) == label.size(0) && mask.size(1) == label.size(1) && mask.size(2) == label.size(2) && mask.size(3) == label.size(3)
                || format == Convolution3D.DataFormat.NCDHW && mask.size(0) == label.size(0) && mask.size(2) == label.size(2) && mask.size(3) == label.size(3) && mask.size(4) == label.size(4)) {
            return ConvolutionUtils.reshape5dTo2d(format, mask, workspaceMgr, type);
        }

        // Broadcast the mask to the labels' shape, keeping the mask's own size on the channel axis.
        long[] lShape = (long[])label.shape().clone();
        int channelIdx = format == Convolution3D.DataFormat.NCDHW ? 1 : 4;
        lShape[channelIdx] = mask.size(channelIdx);
        INDArray bMask = workspaceMgr.createUninitialized(type, mask.dataType(), lShape, 'c');
        Nd4j.exec((CustomOp)new Assign(new INDArray[]{bMask, mask}, new INDArray[]{bMask}));
        return ConvolutionUtils.reshape5dTo2d(format, bMask, workspaceMgr, type);
    }

    /**
     * Reshapes a rank-2 array back to rank 4 (the inverse of the 4d-to-2d reshape).
     * <p>
     * The input is duplicated in 'c' order if it does not already have default 'c' strides. For NCHW,
     * {@code toShape} is read as {@code [n, c, h, w]}: the data is reshaped to {@code [n, h, w, c]} and then
     * permuted with {@code (0, 3, 1, 2)}. For any other format the data is reshaped directly to {@code toShape}.
     *
     * @param in2d         rank-2 input
     * @param toShape      target shape with exactly four entries
     * @param format       target layout
     * @param workspaceMgr workspace manager used for any copy and for the returned array
     * @param type         array type passed to the workspace manager
     * @return rank-4 array, passed through {@code workspaceMgr.leverageTo}
     * @throws IllegalArgumentException if {@code in2d} is not rank 2 or {@code toShape} is not length 4
     */
    public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (in2d.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        if (toShape.length != 4) {
            throw new IllegalArgumentException("Invalid input: expect toShape with 4 elements: got " + Arrays.toString(toShape));
        }

        INDArray source = in2d;
        final boolean needsCopy = source.ordering() != 'c' || !Shape.hasDefaultStridesForShape(source);
        if (needsCopy) {
            source = workspaceMgr.dup(type, source, 'c');
        }

        if (format != CNN2DFormat.NCHW) {
            return workspaceMgr.leverageTo(type, source.reshape('c', toShape));
        }
        final long[] channelsLastShape = new long[]{toShape[0], toShape[2], toShape[3], toShape[1]};
        final INDArray channelsLast = source.reshape('c', channelsLastShape);
        return workspaceMgr.leverageTo(type, channelsLast.permute(new int[]{0, 3, 1, 2}));
    }

    /**
     * Reshapes a rank-2 array back to rank 5 (the inverse of the 5d-to-2d reshape).
     * <p>
     * The input is duplicated in 'c' order if it does not already have default 'c' strides, reshaped to
     * {@code [n, d, h, w, ch]} and, unless the format is NDHWC, permuted with {@code (0, 4, 1, 2, 3)}.
     *
     * @param format       target layout
     * @param in2d         rank-2 input
     * @param n            size of axis 0 of the channels-last shape
     * @param d            size of axis 1 of the channels-last shape
     * @param h            size of axis 2 of the channels-last shape
     * @param w            size of axis 3 of the channels-last shape
     * @param ch           size of axis 4 (channels) of the channels-last shape
     * @param workspaceMgr workspace manager used for any copy and for the returned array
     * @param type         array type passed to the workspace manager
     * @return rank-5 array, passed through {@code workspaceMgr.leverageTo}
     * @throws IllegalArgumentException if {@code in2d} is not rank 2
     */
    public static INDArray reshape2dTo5d(Convolution3D.DataFormat format, INDArray in2d, long n, long d, long h, long w, long ch, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (in2d.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }

        INDArray source = in2d;
        final boolean needsCopy = source.ordering() != 'c' || !Shape.hasDefaultStridesForShape(source);
        if (needsCopy) {
            source = workspaceMgr.dup(type, source, 'c');
        }

        final INDArray channelsLast = source.reshape('c', new long[]{n, d, h, w, ch});
        final INDArray result = format == Convolution3D.DataFormat.NDHWC
                ? channelsLast
                : channelsLast.permute(new int[]{0, 4, 1, 2, 3});
        return workspaceMgr.leverageTo(type, result);
    }

    /**
     * Legacy overload of the mask reshape that assumes NCHW layout.
     * <p>
     * This used to pass {@code null} as the format, which made every rank-2 mask fail with a
     * {@link NullPointerException} in {@code adapt2dMask}. It now defaults to {@link CNN2DFormat#NCHW},
     * matching the other deprecated format-less overloads in this class.
     *
     * @param mask         mask array, may be {@code null}
     * @param output       layer output the mask applies to
     * @param workspaceMgr workspace manager used for allocations
     * @param type         array type passed to the workspace manager
     * @return the reshaped mask, or {@code null} if {@code mask} is {@code null}
     * @deprecated see the overload that takes an explicit {@link CNN2DFormat}
     */
    @Deprecated
    public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        return ConvolutionUtils.reshapeMaskIfRequired(mask, output, CNN2DFormat.NCHW, workspaceMgr, type);
    }

    /**
     * Converts a mask array into the rank-2 form, dispatching on the mask's rank:
     * rank 2 goes to {@code adapt2dMask}, rank 3 to {@code reshape3dMask}, and anything else to the
     * format-less {@code reshape4dTo2d} (which rejects arrays that are not rank 4).
     *
     * @param mask         mask array, may be {@code null}
     * @param output       layer output; only used for rank-2 masks
     * @param format       data format; only used for rank-2 masks
     * @param workspaceMgr workspace manager used for allocations
     * @param type         array type passed to the workspace manager
     * @return the reshaped mask, or {@code null} if {@code mask} is {@code null}
     */
    public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (mask == null) {
            return null;
        }
        switch (mask.rank()) {
            case 2:
                return ConvolutionUtils.adapt2dMask(mask, output, format, workspaceMgr, type);
            case 3:
                return ConvolutionUtils.reshape3dMask(mask, workspaceMgr, type);
            default:
                // NOTE(review): 'format' is not forwarded here, so rank-4 masks are always treated as NCHW - confirm this is intended.
                return ConvolutionUtils.reshape4dTo2d(mask, workspaceMgr, type);
        }
    }

    /**
     * Broadcasts a rank-2 mask against the layer output and flattens the result to a single column.
     * <p>
     * A temporary array with one singleton axis is created from entries of {@code output.shape()},
     * the mask is copied into it with a broadcast copy op, and the result is reshaped to
     * {@code [s[0] * s[2] * s[3], 1]} where {@code s = output.shape()}.
     *
     * @param mask         rank-2 mask
     * @param output       layer output whose shape determines the broadcast target
     * @param format       data format; must not be {@code null}
     * @param workspaceMgr workspace manager used for allocations
     * @param type         array type passed to the workspace manager
     * @return column-vector mask, passed through {@code workspaceMgr.leverageTo}
     * @throws NullPointerException if {@code format} is {@code null}
     */
    public static INDArray adapt2dMask(INDArray mask, INDArray output, @NonNull CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        if (format == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        final long[] s = output.shape();
        final long[] flatShape = new long[]{s[0] * s[2] * s[3], 1L};

        if (format == CNN2DFormat.NCHW) {
            final INDArray broadcast = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1L, s[2], s[3]}, 'c');
            Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastCopyOp(broadcast, mask, broadcast, new int[]{0, 1}));
            final INDArray channelsLast = broadcast.permute(new int[]{0, 2, 3, 1}).dup('c');
            return workspaceMgr.leverageTo(type, channelsLast.reshape('c', flatShape));
        }

        // NOTE(review): this branch indexes output.shape() with 0, 2, 3 exactly like the NCHW branch;
        // for an NHWC output those would be n, w, c rather than n, h, w - confirm against callers.
        final INDArray broadcast = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], s[2], s[3], 1L}, 'c');
        Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastCopyOp(broadcast, mask, broadcast, new int[]{0, 3}));
        return workspaceMgr.leverageTo(type, broadcast.reshape('c', flatShape));
    }

    /**
     * Flattens a mask into a single column of shape {@code [mask.length(), 1]}.
     * The mask is duplicated in 'c' order first if it does not already have default 'c' strides.
     *
     * @param mask         mask array
     * @param workspaceMgr workspace manager used for the copy, if one is needed
     * @param type         array type passed to the workspace manager
     * @return the column-vector view of the (possibly copied) mask
     */
    public static INDArray reshape3dMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        INDArray contiguous = mask;
        final boolean needsCopy = contiguous.ordering() != 'c' || !Shape.hasDefaultStridesForShape(contiguous);
        if (needsCopy) {
            contiguous = workspaceMgr.dup(type, contiguous, 'c');
        }
        final long[] columnShape = new long[]{contiguous.length(), 1L};
        return contiguous.reshape('c', columnShape);
    }

    /**
     * Reshapes a rank-4 mask to rank 2 by delegating to the format-less {@code reshape4dTo2d},
     * which treats the input as NCHW.
     *
     * @param mask         rank-4 mask
     * @param workspaceMgr workspace manager used for allocations
     * @param arrayType    array type passed to the workspace manager
     * @return the rank-2 mask
     */
    public static INDArray reshape4dMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return ConvolutionUtils.reshape4dTo2d(mask, workspaceMgr, arrayType);
    }

    /**
     * Extracts height, width and depth (channels) from a convolutional input type.
     *
     * @param inputType an {@code InputTypeConvolutional} or {@code InputTypeConvolutionalFlat}
     * @return {@code {height, width, depth}}
     * @throws ND4JArraySizeException if any of the three values exceeds {@link Integer#MAX_VALUE}
     * @throws IllegalStateException  if {@code inputType} is of any other type
     */
    public static int[] getHWDFromInputType(InputType inputType) {
        final long height;
        final long width;
        final long depth;
        if (inputType instanceof InputType.InputTypeConvolutional) {
            final InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional)inputType;
            height = conv.getHeight();
            width = conv.getWidth();
            depth = conv.getChannels();
        } else if (inputType instanceof InputType.InputTypeConvolutionalFlat) {
            final InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat)inputType;
            height = flat.getHeight();
            width = flat.getWidth();
            depth = flat.getDepth();
        } else {
            throw new IllegalStateException("Invalid input type: expected InputTypeConvolutional or InputTypeConvolutionalFlat. Got: " + inputType);
        }

        // Guard the narrowing casts below.
        if (height > Integer.MAX_VALUE || width > Integer.MAX_VALUE || depth > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return new int[]{(int)height, (int)width, (int)depth};
    }

    /**
     * Reduces a rank-2 mask for a 1D convolution/pooling layer by max pooling it along its
     * second dimension.
     *
     * <p>The mask is viewed as {@code [size(0), 1, size(1), 1]} (NCHW), pooled with a
     * {@code [kernel, 1]} window, and the result is reshaped back to {@code [size(0), outH]}.
     * When the mode is {@code Same} or {@code Causal} and {@code stride == 1}, the mask is
     * returned unchanged.
     *
     * @param in       rank-2 mask array; presumably {@code [minibatch, sequenceLength]} - TODO confirm
     * @param kernel   kernel size along the pooled dimension
     * @param stride   stride along the pooled dimension
     * @param padding  padding along the pooled dimension; ignored for Same/Causal modes
     * @param dilation dilation along the pooled dimension
     * @param cm       convolution mode
     * @return the input itself (Same/Causal with stride 1) or a new {@code [size(0), outH]} mask
     *         with the same data type as the input
     * @throws IllegalStateException if {@code in} is not rank 2
     */
    public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm) {
        Preconditions.checkState(in.rank() == 2, "Rank must be 2 for cnn1d mask array - shape ", (Object)in.shape());
        boolean sameOrCausal = cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal;
        if (sameOrCausal && stride == 1) {
            return in;
        }
        if (!Shape.hasDefaultStridesForShape(in)) {
            in = in.dup();
        }
        INDArray reshaped4d = in.reshape(new long[]{in.size(0), 1L, in.size(1), 1L});
        int[] k = new int[]{kernel, 1};
        int[] s = new int[]{stride, 1};
        int[] d = new int[]{dilation, 1};
        // Explicit padding is only supplied outside Same/Causal modes.
        int[] pad = sameOrCausal ? null : new int[]{padding, 0};
        int[] outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, pad, cm, d, CNN2DFormat.NCHW);
        int outH = outSize[0];
        // Allocate the output with the mask's own data type (as cnn2dMaskReduction does) rather
        // than the global default type, so the pooling op sees matching input/output types.
        INDArray output = Nd4j.createUninitialized(in.dataType(), new long[]{in.size(0), 1L, outH, 1L}, 'c');
        Pooling2DConfig poolingConfig = Pooling2DConfig.builder()
                .kH((long)k[0]).kW((long)k[1])
                .sH((long)s[0]).sW((long)s[1])
                .pH(pad == null ? 0L : (long)pad[0]).pW(pad == null ? 0L : (long)pad[1])
                .dH((long)d[0]).dW((long)d[1])
                .isSameMode(sameOrCausal)
                .isNHWC(false)
                .build();
        MaxPooling2D op = new MaxPooling2D(reshaped4d, output, poolingConfig);
        Nd4j.getExecutioner().exec((CustomOp)op);
        return output.reshape('c', new long[]{in.size(0), outH});
    }

    /**
     * Reduces a rank-4 mask for a 2D convolution/pooling layer by max pooling it over its
     * spatial dimensions (dimensions 2 and 3).
     *
     * <p>The input mask is returned unchanged when:
     * <ul>
     *   <li>the mode is {@code Same} and both strides are 1;</li>
     *   <li>both spatial dimensions of the mask have size 1; or</li>
     *   <li>the computed output height and width equal the mask's current height and width.</li>
     * </ul>
     * If only one spatial dimension has size 1, pooling is applied along the other dimension
     * only, with kernel/stride/dilation 1 and padding 0 on the singleton dimension.
     *
     * @param inMask          rank-4 mask array
     * @param kernel          kernel size, {@code [height, width]}
     * @param stride          stride, {@code [height, width]}
     * @param padding         padding, {@code [height, width]}
     * @param dilation        dilation, {@code [height, width]}
     * @param convolutionMode convolution mode
     * @return the input mask itself in the cases listed above, otherwise a new mask of shape
     *         {@code [size(0), size(1), outH, outW]} with the same data type as the input
     * @throws IllegalStateException if {@code inMask} is not rank 4
     */
    public static INDArray cnn2dMaskReduction(INDArray inMask, int[] kernel, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode) {
        if (inMask.rank() != 4) {
            throw new IllegalStateException("Expected rank 4 mask array for 2D CNN layers. Mask arrays for 2D CNN layers must have shape [batchSize,channels,X,Y] where X = (1 or activationsHeight) and Y = (1 or activationsWidth): Got rank " + inMask.rank() + " array with shape " + Arrays.toString(inMask.shape()));
        }
        if (convolutionMode == ConvolutionMode.Same && stride[0] == 1 && stride[1] == 1) {
            return inMask;
        }
        if (inMask.size(2) == 1L && inMask.size(3) == 1L) {
            return inMask;
        }
        int[] k;
        int[] s;
        int[] p;
        int[] d;
        if (inMask.size(3) == 1L) {
            // Singleton width: pool along height only.
            k = new int[]{kernel[0], 1};
            s = new int[]{stride[0], 1};
            p = new int[]{padding[0], 0};
            d = new int[]{dilation[0], 1};
        } else if (inMask.size(2) == 1L) {
            // Singleton height: pool along width only.
            k = new int[]{1, kernel[1]};
            s = new int[]{1, stride[1]};
            p = new int[]{0, padding[1]};
            d = new int[]{1, dilation[1]};
        } else {
            k = kernel;
            s = stride;
            p = padding;
            d = dilation;
        }
        int[] outSize = ConvolutionUtils.getOutputSize(inMask, k, s, p, convolutionMode, d);
        // outSize holds {outH, outW}, which correspond to mask dimensions 2 and 3. (The previous
        // check compared them against dimensions 0 and 1, i.e. batch size and channels.)
        boolean spatialSizeUnchanged = (long)outSize[0] == inMask.size(2) && (long)outSize[1] == inMask.size(3);
        if (spatialSizeUnchanged) {
            return inMask;
        }
        long[] outArraySize = new long[]{inMask.size(0), inMask.size(1), outSize[0], outSize[1]};
        INDArray outMask = Nd4j.createUninitialized(inMask.dataType(), outArraySize);
        Pooling2DConfig poolingConfig = Pooling2DConfig.builder()
                .kH((long)k[0]).kW((long)k[1])
                .sH((long)s[0]).sW((long)s[1])
                .pH((long)p[0]).pW((long)p[1])
                .dH((long)d[0]).dW((long)d[1])
                .isSameMode(convolutionMode == ConvolutionMode.Same)
                .isNHWC(false)
                .build();
        MaxPooling2D op = new MaxPooling2D(inMask, outMask, poolingConfig);
        Nd4j.exec((CustomOp)op);
        return outMask;
    }
}

