/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Map;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link GravesBidirectionalLSTM} layers.
 *
 * <p>Builds two independent parameter sets, one for the forwards direction and one for the
 * backwards direction. For a layer with {@code nIn} inputs and {@code nOut} units, each
 * direction gets:
 * <ul>
 *   <li>an input weight matrix of shape {@code [nIn, 4*nOut]};</li>
 *   <li>a recurrent weight matrix of shape {@code [nOut, 4*nOut + 3]} (the 3 extra columns are
 *       presumably peephole weights &mdash; confirm against the layer implementation);</li>
 *   <li>a bias row vector of shape {@code [1, 4*nOut]}, zero everywhere except columns
 *       {@code [nOut, 2*nOut)}, which hold the configured forget gate bias.</li>
 * </ul>
 * Weight matrices are created via {@link WeightInitUtil#initWeights} using the layer's weight
 * init scheme and distribution.
 */
public class GravesBidirectionalLSTMParamInitializer
implements ParamInitializer {
    /** Parameter key for the forwards-direction recurrent weights. */
    public static final String RECURRENT_WEIGHT_KEY_FORWARDS = "RWF";
    /** Parameter key for the forwards-direction bias row. */
    public static final String BIAS_KEY_FORWARDS = "bF";
    /** Parameter key for the forwards-direction input weights. */
    public static final String INPUT_WEIGHT_KEY_FORWARDS = "WF";
    /** Parameter key for the backwards-direction recurrent weights. */
    public static final String RECURRENT_WEIGHT_KEY_BACKWARDS = "RWB";
    /** Parameter key for the backwards-direction bias row. */
    public static final String BIAS_KEY_BACKWARDS = "bB";
    /** Parameter key for the backwards-direction input weights. */
    public static final String INPUT_WEIGHT_KEY_BACKWARDS = "WB";

    /**
     * Registers the six parameter names on {@code conf} and fills {@code params} with freshly
     * initialized arrays for both directions.
     *
     * @param params map that receives the parameter arrays, keyed by the constants above
     * @param conf   configuration whose layer must be a {@link GravesBidirectionalLSTM}
     *               (the cast throws {@link ClassCastException} otherwise)
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf) {
        GravesBidirectionalLSTM lstmConf = (GravesBidirectionalLSTM) conf.getLayer();
        double forgetBias = lstmConf.getForgetGateBiasInit();
        Distribution distribution = Distributions.createDistribution(lstmConf.getDist());
        int units = lstmConf.getNOut();
        int inputs = lstmConf.getNIn();

        // Variable registration order is kept stable; it presumably defines the flattened
        // parameter layout elsewhere in the framework -- verify before reordering.
        String[] registrationOrder = {
            INPUT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_FORWARDS, BIAS_KEY_FORWARDS,
            INPUT_WEIGHT_KEY_BACKWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS, BIAS_KEY_BACKWARDS
        };
        for (String variable : registrationOrder) {
            conf.addVariable(variable);
        }

        // Weights are drawn in a fixed sequence (forwards, then backwards) so any random
        // initialization consumes the generator in the same order every time.
        params.put(INPUT_WEIGHT_KEY_FORWARDS, weights(lstmConf, inputs, 4 * units, distribution));
        params.put(RECURRENT_WEIGHT_KEY_FORWARDS, weights(lstmConf, units, 4 * units + 3, distribution));
        params.put(INPUT_WEIGHT_KEY_BACKWARDS, weights(lstmConf, inputs, 4 * units, distribution));
        params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, weights(lstmConf, units, 4 * units + 3, distribution));

        INDArray forwardsBias = forgetGateBiasRow(units, forgetBias);
        INDArray backwardsBias = forgetGateBiasRow(units, forgetBias);
        params.put(BIAS_KEY_FORWARDS, forwardsBias);
        params.put(BIAS_KEY_BACKWARDS, backwardsBias);

        String[] persistOrder = {
            INPUT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_FORWARDS,
            INPUT_WEIGHT_KEY_BACKWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS,
            BIAS_KEY_FORWARDS, BIAS_KEY_BACKWARDS
        };
        for (String key : persistOrder) {
            params.get(key).data().persist();
        }
    }

    /**
     * Same as {@link #init(Map, NeuralNetConfiguration)}; {@code extraConf} is ignored.
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf, Configuration extraConf) {
        init(params, conf);
    }

    /**
     * Creates a {@code rows x cols} weight matrix using the layer's weight init scheme and the
     * supplied distribution.
     */
    private static INDArray weights(GravesBidirectionalLSTM lstmConf, int rows, int cols, Distribution distribution) {
        return WeightInitUtil.initWeights(rows, cols, lstmConf.getWeightInit(), distribution);
    }

    /**
     * Builds a {@code [1, 4*n]} row of zeros whose columns {@code [n, 2*n)} are set to
     * {@code forgetGateBias}.
     */
    private static INDArray forgetGateBiasRow(int n, double forgetGateBias) {
        INDArray row = Nd4j.zeros(1, 4 * n);
        INDArrayIndex[] forgetSlice = {
            new NDArrayIndex(new int[]{0}),
            NDArrayIndex.interval(n, 2 * n)
        };
        INDArray fill = Nd4j.ones(1, n).muli((Number) forgetGateBias);
        row.put(forgetSlice, fill);
        return row;
    }
}

