/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link SeparableConvolution2D} layers.
 *
 * <p>The flattened parameter view (and, identically, the flattened gradient view) is a row
 * vector laid out as {@code [bias | depth-wise weights | point-wise weights]}; the bias segment
 * is empty when the layer has no bias. Depth-wise weights are viewed in 'c' order with shape
 * {@code [depthMultiplier, nIn, kernel[0], kernel[1]]}, point-wise weights with shape
 * {@code [nOut, nIn * depthMultiplier, 1, 1]}.
 */
public class SeparableConvolutionParamInitializer
implements ParamInitializer {
    private static final SeparableConvolutionParamInitializer INSTANCE = new SeparableConvolutionParamInitializer();
    public static final String DEPTH_WISE_WEIGHT_KEY = "W";
    public static final String POINT_WISE_WEIGHT_KEY = "pW";
    public static final String BIAS_KEY = "b";

    /** Returns the shared singleton instance. */
    public static SeparableConvolutionParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public int numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Total parameter count: depth-wise weights + point-wise weights + (optional) bias.
     *
     * @param l layer configuration; must be a {@link SeparableConvolution2D}
     */
    @Override
    public int numParams(Layer l) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) l;
        return this.numDepthWiseParams(layerConf)
                + this.numPointWiseParams(layerConf)
                + this.numBiasParams(layerConf);
    }

    /** One bias per output channel when the layer has a bias, otherwise zero. */
    private int numBiasParams(SeparableConvolution2D layerConf) {
        return layerConf.hasBias() ? layerConf.getNOut() : 0;
    }

    /** Size of the depth-wise kernel tensor: nIn * depthMultiplier * kernel[0] * kernel[1]. */
    private int numDepthWiseParams(SeparableConvolution2D layerConf) {
        int[] kernel = layerConf.getKernelSize();
        return layerConf.getNIn() * layerConf.getDepthMultiplier() * kernel[0] * kernel[1];
    }

    /** Size of the 1x1 point-wise kernel tensor: nIn * depthMultiplier * nOut. */
    private int numPointWiseParams(SeparableConvolution2D layerConf) {
        return layerConf.getNIn() * layerConf.getDepthMultiplier() * layerConf.getNOut();
    }

    /**
     * Rejects kernels that are not two-dimensional.
     *
     * @throws IllegalArgumentException if the configured kernel size does not have length 2
     */
    private static void checkKernelSize(SeparableConvolution2D layerConf) {
        if (layerConf.getKernelSize().length != 2) {
            throw new IllegalArgumentException("Filter size must be == 2");
        }
    }

    /**
     * Returns the view of columns {@code [from, to)} of row 0 of a flattened parameter/gradient
     * array. The result shares storage with {@code flat}.
     */
    private static INDArray segment(INDArray flat, int from, int to) {
        return flat.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) layer;
        if (layerConf.hasBias()) {
            return Arrays.asList(DEPTH_WISE_WEIGHT_KEY, POINT_WISE_WEIGHT_KEY, BIAS_KEY);
        }
        return this.weightKeys(layer);
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return Arrays.asList(DEPTH_WISE_WEIGHT_KEY, POINT_WISE_WEIGHT_KEY);
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) layer;
        if (layerConf.hasBias()) {
            return Collections.singletonList(BIAS_KEY);
        }
        return Collections.emptyList();
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return DEPTH_WISE_WEIGHT_KEY.equals(key) || POINT_WISE_WEIGHT_KEY.equals(key);
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY.equals(key);
    }

    /**
     * Splits {@code paramsView} into depth-wise, point-wise and (optional) bias views, optionally
     * initializing them, and registers each key as a variable on {@code conf}.
     *
     * @param conf             configuration whose layer is a {@link SeparableConvolution2D}
     * @param paramsView       flattened parameter row vector (see class docs for layout)
     * @param initializeParams if true, fill the views with initial values; otherwise only reshape
     * @return map of parameter key to view, in the order depth-wise, point-wise, bias
     * @throws IllegalArgumentException if the kernel size is not two-dimensional
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer();
        checkKernelSize(layerConf);
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());

        int biasParams = this.numBiasParams(layerConf);
        int depthWiseEnd = biasParams + this.numDepthWiseParams(layerConf);
        INDArray depthWiseWeightView = segment(paramsView, biasParams, depthWiseEnd);
        INDArray pointWiseWeightView = segment(paramsView, depthWiseEnd, this.numParams(conf));

        params.put(DEPTH_WISE_WEIGHT_KEY, this.createDepthWiseWeightMatrix(conf, depthWiseWeightView, initializeParams));
        conf.addVariable(DEPTH_WISE_WEIGHT_KEY);
        params.put(POINT_WISE_WEIGHT_KEY, this.createPointWiseWeightMatrix(conf, pointWiseWeightView, initializeParams));
        conf.addVariable(POINT_WISE_WEIGHT_KEY);
        if (layerConf.hasBias()) {
            // Bias occupies the leading segment of the flattened view.
            INDArray biasView = segment(paramsView, 0, biasParams);
            params.put(BIAS_KEY, this.createBias(conf, biasView, initializeParams));
            conf.addVariable(BIAS_KEY);
        }
        return params;
    }

    /**
     * Splits a flattened gradient row vector into per-parameter views using the same layout as
     * {@link #init}. Weight gradient views are reshaped to the corresponding weight shapes.
     *
     * @throws IllegalArgumentException if the kernel size is not two-dimensional
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer();
        checkKernelSize(layerConf);
        int[] kernel = layerConf.getKernelSize();
        int nIn = layerConf.getNIn();
        int depthMultiplier = layerConf.getDepthMultiplier();
        int nOut = layerConf.getNOut();

        int biasParams = this.numBiasParams(layerConf);
        int depthWiseEnd = biasParams + this.numDepthWiseParams(layerConf);

        Map<String, INDArray> out = new LinkedHashMap<String, INDArray>();
        out.put(DEPTH_WISE_WEIGHT_KEY, segment(gradientView, biasParams, depthWiseEnd)
                .reshape('c', new int[]{depthMultiplier, nIn, kernel[0], kernel[1]}));
        out.put(POINT_WISE_WEIGHT_KEY, segment(gradientView, depthWiseEnd, this.numParams(conf))
                .reshape('c', new int[]{nOut, nIn * depthMultiplier, 1, 1}));
        if (layerConf.hasBias()) {
            // Same leading bias segment as in init(); biasParams == nOut here.
            out.put(BIAS_KEY, segment(gradientView, 0, biasParams));
        }
        return out;
    }

    /**
     * Returns the bias view, filled with the layer's configured bias init value when
     * {@code initializeParams} is true.
     */
    protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView, boolean initializeParams) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer();
        if (initializeParams) {
            biasView.assign((Number) layerConf.getBiasInit());
        }
        return biasView;
    }

    /**
     * Returns the depth-wise weight view with shape {@code [depthMultiplier, nIn, kernel[0], kernel[1]]},
     * either initialized via {@link WeightInitUtil#initWeights} or just reshaped.
     */
    protected INDArray createDepthWiseWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer();
        int depthMultiplier = layerConf.getDepthMultiplier();
        int[] kernel = layerConf.getKernelSize();
        int inputDepth = layerConf.getNIn();
        int[] weightsShape = new int[]{depthMultiplier, inputDepth, kernel[0], kernel[1]};
        if (!initializeParams) {
            return WeightInitUtil.reshapeWeights(weightsShape, weightView, 'c');
        }
        Distribution dist = Distributions.createDistribution(layerConf.getDist());
        int[] stride = layerConf.getStride();
        double fanIn = inputDepth * kernel[0] * kernel[1];
        // Fan-out is scaled down by the stride area, as in the original implementation.
        double fanOut = (double) (depthMultiplier * kernel[0] * kernel[1]) / ((double) stride[0] * (double) stride[1]);
        return WeightInitUtil.initWeights(fanIn, fanOut, weightsShape, layerConf.getWeightInit(), dist, 'c', weightView);
    }

    /**
     * Returns the point-wise (1x1) weight view with shape {@code [nOut, nIn * depthMultiplier, 1, 1]},
     * either initialized via {@link WeightInitUtil#initWeights} or just reshaped.
     */
    protected INDArray createPointWiseWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer();
        int depthMultiplier = layerConf.getDepthMultiplier();
        int[] weightsShape = new int[]{layerConf.getNOut(), depthMultiplier * layerConf.getNIn(), 1, 1};
        if (!initializeParams) {
            return WeightInitUtil.reshapeWeights(weightsShape, weightView, 'c');
        }
        Distribution dist = Distributions.createDistribution(layerConf.getDist());
        double fanIn = (double) (layerConf.getNIn() * depthMultiplier);
        // NOTE(review): fanOut is set equal to fanIn (preserved from the original). For a 1x1 kernel of
        // shape [nOut, nIn * depthMultiplier, 1, 1] a Glorot-style fan-out would usually be nOut;
        // confirm whether this is intentional.
        double fanOut = fanIn;
        return WeightInitUtil.initWeights(fanIn, fanOut, weightsShape, layerConf.getWeightInit(), dist, 'c', weightView);
    }
}

