/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link Deconvolution2D} layers.
 *
 * <p>Overrides the weight creation and gradient-view logic of
 * {@link ConvolutionParamInitializer} so that the weight tensor is laid out as
 * {@code [nIn, nOut, kernelHeight, kernelWidth]} in 'c' order.
 */
public class DeconvolutionParamInitializer
extends ConvolutionParamInitializer {
    private static final DeconvolutionParamInitializer INSTANCE = new DeconvolutionParamInitializer();

    /**
     * Returns the shared initializer instance.
     *
     * @return the single instance held by this class
     */
    public static DeconvolutionParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Produces the weight array for a {@link Deconvolution2D} layer on top of the
     * supplied view.
     *
     * <p>When {@code initializeParams} is {@code true}, the layer's weight-init
     * function is invoked with fan-in {@code nIn * kH * kW} and fan-out
     * {@code (nOut * kH * kW) / (strideH * strideW)}. Otherwise the view is only
     * reshaped via {@link WeightInitUtil#reshapeWeights}.
     *
     * @param conf             configuration whose layer is cast to {@link Deconvolution2D}
     * @param weightView       view handed to the weight-init function or reshaped
     * @param initializeParams whether the weight-init function should be applied
     * @return the weight array with shape {@code [nIn, nOut, kH, kW]}
     */
    @Override
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        Deconvolution2D deconv = (Deconvolution2D) conf.getLayer();
        int[] kernelSize = deconv.getKernelSize();

        if (!initializeParams) {
            // No initialization requested: only expose the view with the weight shape.
            long[] shape = deconvWeightShape(deconv.getNIn(), deconv.getNOut(), kernelSize);
            return WeightInitUtil.reshapeWeights(shape, weightView, 'c');
        }

        int[] strides = deconv.getStride();
        long channelsIn = deconv.getNIn();
        long channelsOut = deconv.getNOut();

        // Integer (long) products are formed first and only then widened to double.
        long kernelArea = (long) kernelSize[0] * kernelSize[1];
        double fanIn = channelsIn * kernelArea;
        double strideArea = (double) strides[0] * (double) strides[1];
        double fanOut = (double) (channelsOut * kernelArea) / strideArea;

        long[] shape = deconvWeightShape(channelsIn, channelsOut, kernelSize);
        return deconv.getWeightInitFn().init(fanIn, fanOut, shape, 'c', weightView);
    }

    /**
     * Splits a flattened gradient view into per-parameter views.
     *
     * <p>With a bias, the columns up to {@code nOut} of the first row are stored
     * under {@code "b"} and the columns from {@code nOut} to {@code numParams(conf)}
     * are reshaped and stored under {@code "W"}. Without a bias, the whole view is
     * reshaped and stored under {@code "W"}.
     *
     * @param conf         configuration whose layer is cast to {@link Deconvolution2D}
     * @param gradientView flattened gradient view to slice
     * @return insertion-ordered map from parameter key to gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        Deconvolution2D deconv = (Deconvolution2D) conf.getLayer();
        int[] kernelSize = deconv.getKernelSize();
        long channelsIn = deconv.getNIn();
        long channelsOut = deconv.getNOut();

        Map<String, INDArray> gradients = new LinkedHashMap<>();

        if (!deconv.hasBias()) {
            gradients.put("W", gradientView.reshape('c', deconvWeightShape(channelsIn, channelsOut, kernelSize)));
            return gradients;
        }

        INDArrayIndex[] biasSlice = {firstRowIndex(), NDArrayIndex.interval(0L, channelsOut)};
        INDArray biasGradient = gradientView.get(biasSlice);

        INDArrayIndex[] weightSlice = {firstRowIndex(), NDArrayIndex.interval(channelsOut, (long) this.numParams(conf))};
        INDArray weightGradient = gradientView.get(weightSlice)
                .reshape('c', deconvWeightShape(channelsIn, channelsOut, kernelSize));

        gradients.put("b", biasGradient);
        gradients.put("W", weightGradient);
        return gradients;
    }

    /** Builds the {@code [nIn, nOut, kH, kW]} shape array used for deconvolution weights. */
    private static long[] deconvWeightShape(long channelsIn, long channelsOut, int[] kernelSize) {
        return new long[]{channelsIn, channelsOut, kernelSize[0], kernelSize[1]};
    }

    /** Creates a new inclusive {@code [0, 0]} interval index; a fresh object is returned on every call. */
    private static INDArrayIndex firstRowIndex() {
        return NDArrayIndex.interval(0L, 0L, true);
    }
}

