/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;

/**
 * Parameter initializer for a layer configured as a {@link RecursiveAutoEncoder}.
 *
 * <p>{@link #init(Map, NeuralNetConfiguration)} fills the supplied parameter map with four
 * arrays, all produced by {@code WeightInitUtil.initWeights} using the layer's configured
 * weight-init scheme and distribution, where {@code nIn} is the layer's {@code getNIn()}:
 * <ul>
 *   <li>{@link #ENCODER_WEIGHT_KEY}: shape {@code [nIn, 2 * nIn]}</li>
 *   <li>{@link #DECODER_WEIGHT_KEY}: shape {@code [2 * nIn, nIn]}</li>
 *   <li>{@link #HIDDEN_BIAS_KEY}: shape {@code [1, 2 * nIn]}</li>
 *   <li>{@link #VISIBLE_BIAS_KEY}: shape {@code [1, nIn]}</li>
 * </ul>
 */
public class RecursiveParamInitializer extends DefaultParamInitializer {
    /** Map key under which the encoder weight array is stored. */
    public static final String ENCODER_WEIGHT_KEY = "W";
    /** Map key under which the decoder weight array is stored. */
    public static final String DECODER_WEIGHT_KEY = "U";
    /** Map key under which the hidden bias array is stored. */
    public static final String HIDDEN_BIAS_KEY = "b";
    /** Map key under which the visible bias array is stored. */
    public static final String VISIBLE_BIAS_KEY = "vb";

    /**
     * Creates the four parameter arrays for the layer and registers their keys on the
     * configuration.
     *
     * @param params map that receives the newly created arrays, keyed by the constants
     *               declared on this class
     * @param conf   configuration whose layer must be a {@link RecursiveAutoEncoder}; each
     *               parameter key is passed to its {@code addVariable} method
     * @throws ClassCastException if {@code conf.getLayer()} is not a
     *                            {@link RecursiveAutoEncoder}
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf) {
        final RecursiveAutoEncoder recursiveLayer = (RecursiveAutoEncoder) conf.getLayer();
        final Distribution weightDist = Distributions.createDistribution(recursiveLayer.getDist());
        final int inputSize = recursiveLayer.getNIn();
        final int doubledSize = inputSize * 2;

        // Arrays are created one at a time in this fixed order (W, U, b, vb) so that any
        // draws made by the initializer happen in the same sequence on every call.
        params.put(ENCODER_WEIGHT_KEY, newParam(recursiveLayer, weightDist, inputSize, doubledSize));
        params.put(DECODER_WEIGHT_KEY, newParam(recursiveLayer, weightDist, doubledSize, inputSize));
        params.put(HIDDEN_BIAS_KEY, newParam(recursiveLayer, weightDist, 1, doubledSize));
        params.put(VISIBLE_BIAS_KEY, newParam(recursiveLayer, weightDist, 1, inputSize));

        // Keys are registered only after every array has been stored.
        final String[] registrationOrder = {
            ENCODER_WEIGHT_KEY, DECODER_WEIGHT_KEY, HIDDEN_BIAS_KEY, VISIBLE_BIAS_KEY
        };
        for (String key : registrationOrder) {
            conf.addVariable(key);
        }
    }

    /** Builds one {@code [rows, cols]} array with the layer's weight-init scheme and {@code dist}. */
    private static INDArray newParam(RecursiveAutoEncoder layer, Distribution dist, int rows, int cols) {
        return WeightInitUtil.initWeights(new int[]{rows, cols}, layer.getWeightInit(), dist);
    }
}

