/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link ParamInitializer} for {@link SimpleRnn} layers.
 *
 * <p>The flattened parameter view (and the matching gradient view) is split into three
 * contiguous segments, taken from row 0 of the view in this order:
 * <ol>
 *   <li>{@code W}: input weights, {@code nIn * nOut} values, viewed as {@code [nIn, nOut]} in 'f' order;</li>
 *   <li>{@code RW}: recurrent weights, {@code nOut * nOut} values, viewed as {@code [nOut, nOut]} in 'f' order;</li>
 *   <li>{@code b}: bias, {@code nOut} values, left as a flat view.</li>
 * </ol>
 */
public class SimpleRnnParamInitializer implements ParamInitializer {
    private static final SimpleRnnParamInitializer INSTANCE = new SimpleRnnParamInitializer();

    public static final String WEIGHT_KEY = "W";
    public static final String RECURRENT_WEIGHT_KEY = "RW";
    public static final String BIAS_KEY = "b";

    // Built from the key constants so the lists can never drift from them.
    private static final List<String> PARAM_KEYS =
            Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY));
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, RECURRENT_WEIGHT_KEY));
    private static final List<String> BIAS_KEYS = Collections.singletonList(BIAS_KEY);

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the singleton {@code SimpleRnnParamInitializer}
     */
    public static SimpleRnnParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Counts the parameters of a {@link SimpleRnn} layer.
     *
     * @param layer layer configuration; must be a {@link SimpleRnn}
     * @return {@code nIn*nOut + nOut*nOut + nOut}
     */
    @Override
    public long numParams(Layer layer) {
        SimpleRnn c = (SimpleRnn) layer;
        return totalParams(c.getNIn(), c.getNOut());
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        return PARAM_KEYS;
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        return BIAS_KEYS;
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEY.equals(key) || RECURRENT_WEIGHT_KEY.equals(key);
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEY.equals(key);
    }

    /**
     * Splits {@code paramsView} into the W / RW / b views and, if requested, initializes them.
     *
     * <p>Recurrent weights use the layer's recurrent weight init when one is set, otherwise the
     * input weight init. A recurrent distribution is only used when a recurrent weight init is
     * also set; in every other case the input-weight distribution is reused.
     *
     * @param conf             configuration whose layer must be a {@link SimpleRnn}; the three
     *                         parameter keys are registered on it as variables
     * @param paramsView       flattened parameter view of at least {@link #numParams(Layer)} values
     * @param initializeParams whether to write initial values into the views
     * @return insertion-ordered map of W, RW and b views
     * @throws IllegalArgumentException if {@code paramsView} is too short for the layer sizes
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        SimpleRnn c = (SimpleRnn) conf.getLayer();
        long nIn = c.getNIn();
        long nOut = c.getNOut();

        Map<String, INDArray> m;
        if (initializeParams) {
            Distribution dist = Distributions.createDistribution(c.getDist());

            // Views are left flat here; initWeights receives the target shape and its result
            // replaces the flat view in the map (presumably a reshaped view of the same buffer).
            m = getSubsets(paramsView, nIn, nOut, false);
            INDArray w = WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut},
                    c.getWeightInit(), dist, 'f', m.get(WEIGHT_KEY));
            m.put(WEIGHT_KEY, w);

            WeightInit rwInit;
            Distribution rwDist = dist;
            if (c.getWeightInitRecurrent() != null) {
                rwInit = c.getWeightInitRecurrent();
                if (c.getDistRecurrent() != null) {
                    rwDist = Distributions.createDistribution(c.getDistRecurrent());
                }
            } else {
                rwInit = c.getWeightInit();
            }
            INDArray rw = WeightInitUtil.initWeights(nOut, nOut, new long[]{nOut, nOut},
                    rwInit, rwDist, 'f', m.get(RECURRENT_WEIGHT_KEY));
            m.put(RECURRENT_WEIGHT_KEY, rw);

            m.get(BIAS_KEY).assign(c.getBiasInit());
        } else {
            m = getSubsets(paramsView, nIn, nOut, true);
        }

        conf.addVariable(WEIGHT_KEY);
        conf.addVariable(RECURRENT_WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);
        return m;
    }

    /**
     * Splits a flattened gradient view using the same layout as the parameters.
     *
     * @param conf         configuration whose layer must be a {@link SimpleRnn}
     * @param gradientView flattened gradient view of at least {@link #numParams(Layer)} values
     * @return insertion-ordered map of W, RW (both reshaped) and b gradient views
     * @throws IllegalArgumentException if {@code gradientView} is too short for the layer sizes
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        SimpleRnn c = (SimpleRnn) conf.getLayer();
        return getSubsets(gradientView, c.getNIn(), c.getNOut(), true);
    }

    /** Total number of values in the flattened layout: W, then RW, then b. */
    private static long totalParams(long nIn, long nOut) {
        return nIn * nOut + nOut * nOut + nOut;
    }

    /**
     * Slices {@code in} (row 0) into the W, RW and b segments.
     *
     * @param reshape if true, W and RW are reshaped in 'f' order to {@code [nIn, nOut]} and
     *                {@code [nOut, nOut]}; b is always left flat
     */
    private static Map<String, INDArray> getSubsets(INDArray in, long nIn, long nOut, boolean reshape) {
        long wEnd = nIn * nOut;
        long rwEnd = wEnd + nOut * nOut;
        long bEnd = rwEnd + nOut;
        // bEnd must match totalParams(nIn, nOut); fail fast with sizes rather than deep in ND4J indexing.
        if (in.length() < bEnd) {
            throw new IllegalArgumentException("SimpleRnn parameter view too short: expected at least "
                    + totalParams(nIn, nOut) + " values (nIn=" + nIn + ", nOut=" + nOut
                    + "), got " + in.length());
        }

        INDArray w = in.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, wEnd));
        INDArray rw = in.get(NDArrayIndex.point(0L), NDArrayIndex.interval(wEnd, rwEnd));
        INDArray b = in.get(NDArrayIndex.point(0L), NDArrayIndex.interval(rwEnd, bEnd));
        if (reshape) {
            w = w.reshape('f', new long[]{nIn, nOut});
            rw = rw.reshape('f', new long[]{nOut, nOut});
        }

        Map<String, INDArray> m = new LinkedHashMap<String, INDArray>();
        m.put(WEIGHT_KEY, w);
        m.put(RECURRENT_WEIGHT_KEY, rw);
        m.put(BIAS_KEY, b);
        return m;
    }
}

