/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for element-wise feed-forward layers, which hold one weight value and one
 * bias value per input unit.
 *
 * <p>Layout of the flattened parameter (and gradient) view, taken from row 0 of the view:
 * <pre>
 *   [ W_0 ... W_{nIn-1} | b_0 ... b_{nIn-1} ]
 * </pre>
 * so the total parameter count is {@code 2 * nIn}.
 */
public class ElementWiseParamInitializer
extends DefaultParamInitializer {
    private static final ElementWiseParamInitializer INSTANCE = new ElementWiseParamInitializer();

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the singleton {@code ElementWiseParamInitializer}
     */
    public static ElementWiseParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the number of parameters of the layer: {@code nIn} weights plus {@code nIn} biases.
     *
     * @param layer the layer configuration; must be a {@link FeedForwardLayer}
     * @return {@code 2 * nIn}
     * @throws ClassCastException if {@code layer} is not a {@link FeedForwardLayer}
     */
    @Override
    public long numParams(Layer layer) {
        FeedForwardLayer layerConf = (FeedForwardLayer) layer;
        return 2L * layerConf.getNIn();
    }

    /**
     * Splits {@code paramsView} into weight and bias views and creates (or wraps) the parameters.
     *
     * <p>The weight view covers columns {@code [0, nIn)} and the bias view covers
     * {@code [nIn, 2*nIn)} of row 0. The keys {@code "W"} and {@code "b"} are also registered as
     * variables on {@code conf}.
     *
     * @param conf the network configuration whose layer must be a {@link FeedForwardLayer}
     * @param paramsView flattened parameter view; its length must equal {@link #numParams(Layer)}
     * @param initializeParams whether to initialize the parameter values or use the view as-is
     * @return an insertion-ordered, synchronized map with entries {@code "W"} then {@code "b"}
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}
     * @throws IllegalStateException if {@code paramsView} has the wrong length
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        if (!(conf.getLayer() instanceof FeedForwardLayer)) {
            throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName());
        }
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<>());
        long length = this.numParams(conf);
        if (paramsView.length() != length) {
            throw new IllegalStateException("Expected params view of length " + length + ", got length " + paramsView.length());
        }
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        long nIn = layerConf.getNIn();

        INDArray weightView = rowSlice(paramsView, 0L, nIn);
        INDArray biasView = rowSlice(paramsView, nIn, 2L * nIn);

        params.put("W", this.createWeightMatrix(conf, weightView, initializeParams));
        params.put("b", this.createBias(conf, biasView, initializeParams));
        conf.addVariable("W");
        conf.addVariable("b");
        return params;
    }

    /**
     * Splits a flattened gradient view into weight and bias gradient views, using the same layout
     * as {@link #init(NeuralNetConfiguration, INDArray, boolean)}.
     *
     * @param conf the network configuration whose layer is a {@link FeedForwardLayer}
     * @param gradientView flattened gradient view
     * @return an insertion-ordered map with entries {@code "W"} then {@code "b"}
     * @throws ClassCastException if the layer is not a {@link FeedForwardLayer}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        long nIn = layerConf.getNIn();

        INDArray weightGradientView = rowSlice(gradientView, 0L, nIn);
        // The bias block has nIn entries (see init() and numParams()); slicing by nOut here would
        // mis-align with the parameter view whenever nIn != nOut.
        INDArray biasGradientView = rowSlice(gradientView, nIn, 2L * nIn);

        Map<String, INDArray> out = new LinkedHashMap<>();
        out.put("W", weightGradientView);
        out.put("b", biasGradientView);
        return out;
    }

    /**
     * Creates the element-wise weight vector with shape {@code [1, nIn]}.
     *
     * @param nIn number of inputs; also the number of weights
     * @param nOut number of outputs; forwarded to {@code weightInit} only
     * @param weightInit weight initialization scheme used when {@code initializeParameters} is true
     * @param weightParamView view that backs the weights
     * @param initializeParameters if true, initialize via {@code weightInit}; otherwise return the
     *     view unchanged
     * @return the initialized weights, or {@code weightParamView} itself when not initializing
     */
    @Override
    protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit, INDArray weightParamView, boolean initializeParameters) {
        if (!initializeParameters) {
            return weightParamView;
        }
        long[] shape = {1L, nIn};
        return weightInit.init(nIn, nOut, shape, 'f', weightParamView);
    }

    /** Returns columns {@code [from, to)} of row 0 of {@code view}. */
    private static INDArray rowSlice(INDArray view, long from, long to) {
        return view.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to));
    }
}

