/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link SimpleRnn} layers.
 *
 * <p>All parameters live in one flattened view, split into consecutive segments along its
 * second dimension, in this order:
 * <ol>
 *   <li>{@code "W"}  - input weights, {@code nIn * nOut} values (viewed as {@code [nIn, nOut]}, 'f' order)</li>
 *   <li>{@code "RW"} - recurrent weights, {@code nOut * nOut} values (viewed as {@code [nOut, nOut]}, 'f' order)</li>
 *   <li>{@code "b"}  - bias, {@code nOut} values</li>
 *   <li>{@code "g"}  - gain, {@code 2 * nOut} values; present only when the layer reports layer normalization</li>
 * </ol>
 * A shared instance is available through {@link #getInstance()}.
 */
public class SimpleRnnParamInitializer implements ParamInitializer {

    /** Parameter key of the input weight matrix. */
    public static final String WEIGHT_KEY = "W";
    /** Parameter key of the recurrent weight matrix. */
    public static final String RECURRENT_WEIGHT_KEY = "RW";
    /** Parameter key of the bias vector. */
    public static final String BIAS_KEY = "b";
    /** Parameter key of the layer-normalization gain (only used when layer norm is enabled). */
    public static final String GAIN_KEY = "g";

    /** Weight keys that are always present, independent of layer normalization. */
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList(WEIGHT_KEY, RECURRENT_WEIGHT_KEY));
    /** The single bias key; this immutable list is returned as-is by {@link #biasKeys(Layer)}. */
    private static final List<String> BIAS_KEYS = Collections.singletonList(BIAS_KEY);

    private static final SimpleRnnParamInitializer INSTANCE = new SimpleRnnParamInitializer();

    /**
     * Returns the shared initializer instance.
     *
     * @return the singleton {@code SimpleRnnParamInitializer}
     */
    public static SimpleRnnParamInitializer getInstance() {
        return INSTANCE;
    }

    /** Delegates to {@link #numParams(Layer)} using the configuration's layer. */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Counts the parameters of a {@link SimpleRnn} layer.
     *
     * @param layer the layer configuration; must be a {@link SimpleRnn} (cast unconditionally)
     * @return {@code nIn*nOut + nOut*nOut + nOut}, plus {@code 2*nOut} when layer norm is enabled
     */
    @Override
    public long numParams(Layer layer) {
        SimpleRnn rnn = (SimpleRnn) layer;
        long in = rnn.getNIn();
        long out = rnn.getNOut();
        long gainCount = hasLayerNorm(layer) ? 2L * out : 0L;
        return in * out      // input weights
                + out * out  // recurrent weights
                + out        // bias
                + gainCount; // layer-norm gain, if any
    }

    /**
     * Returns a new mutable list containing the weight keys followed by the bias keys.
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        List<String> weights = weightKeys(layer);
        List<String> biases = biasKeys(layer);
        List<String> all = new ArrayList<>(weights.size() + biases.size());
        all.addAll(weights);
        all.addAll(biases);
        return all;
    }

    /**
     * Returns a new mutable list: {@code "W"}, {@code "RW"}, and {@code "g"} when layer norm is enabled.
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        List<String> result = new ArrayList<>(WEIGHT_KEYS.size() + 1);
        result.addAll(WEIGHT_KEYS);
        if (hasLayerNorm(layer)) {
            result.add(GAIN_KEY);
        }
        return result;
    }

    /** Returns the shared immutable list containing only {@code "b"}. */
    @Override
    public List<String> biasKeys(Layer layer) {
        return BIAS_KEYS;
    }

    /**
     * Reports whether {@code key} names a weight parameter. The gain key counts as a weight
     * regardless of whether this particular layer uses layer normalization.
     */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return GAIN_KEY.equals(key) || WEIGHT_KEYS.contains(key);
    }

    /** Reports whether {@code key} is the bias key. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEYS.contains(key);
    }

    /**
     * Splits {@code paramsView} into per-parameter views, optionally initializing them, and
     * registers every parameter key on {@code conf}.
     *
     * <p>When {@code initializeParams} is true, the W/RW views are passed flat (not reshaped) to
     * the configured weight-init functions and their return values replace the map entries; W is
     * initialized before RW. The recurrent init falls back to the input init when none is set.
     * Bias (and gain, if present) are filled with the configured constants.
     *
     * @param conf             configuration whose layer must be a {@link SimpleRnn}
     * @param paramsView       flattened parameter view (layout described on the class)
     * @param initializeParams whether to write initial values into the view
     * @return insertion-ordered map of key to view: W, RW, b, then g when layer norm is enabled
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        SimpleRnn rnn = (SimpleRnn) conf.getLayer();
        long in = rnn.getNIn();
        long out = rnn.getNOut();
        boolean layerNorm = hasLayerNorm(rnn);

        // Reshape only when loading existing values; the init functions receive the flat views.
        Map<String, INDArray> params = getSubsets(paramsView, in, out, !initializeParams, layerNorm);

        if (initializeParams) {
            IWeightInit inputInit = rnn.getWeightInitFn();
            INDArray inputWeights =
                    inputInit.init(in, out, new long[]{in, out}, 'f', params.get(WEIGHT_KEY));
            params.put(WEIGHT_KEY, inputWeights);

            IWeightInit recurrentInit = rnn.getWeightInitFnRecurrent();
            if (recurrentInit == null) {
                recurrentInit = rnn.getWeightInitFn();
            }
            INDArray recurrentWeights =
                    recurrentInit.init(out, out, new long[]{out, out}, 'f', params.get(RECURRENT_WEIGHT_KEY));
            params.put(RECURRENT_WEIGHT_KEY, recurrentWeights);

            params.get(BIAS_KEY).assign((Number) rnn.getBiasInit());
            if (layerNorm) {
                params.get(GAIN_KEY).assign((Number) rnn.getGainInit());
            }
        }

        // getSubsets inserts keys in layout order (W, RW, b[, g]); re-putting an existing key
        // does not change LinkedHashMap order, so this registers variables in that same order.
        for (String key : params.keySet()) {
            conf.addVariable(key);
        }
        return params;
    }

    /**
     * Splits a flattened gradient view using the same layout as the parameters, with W and RW
     * reshaped to their matrix shapes.
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        SimpleRnn rnn = (SimpleRnn) conf.getLayer();
        return getSubsets(gradientView, rnn.getNIn(), rnn.getNOut(), true, hasLayerNorm(rnn));
    }

    /**
     * Carves {@code in} into W, RW, b and (optionally) g sub-views.
     *
     * @param in           flattened view to split
     * @param nIn          layer input size
     * @param nOut         layer output size
     * @param reshape      when true, W becomes {@code [nIn, nOut]} and RW {@code [nOut, nOut]} ('f' order)
     * @param hasLayerNorm when true, a {@code 2 * nOut} gain segment follows the bias
     * @return insertion-ordered map W, RW, b[, g]
     */
    private static Map<String, INDArray> getSubsets(INDArray in, long nIn, long nOut, boolean reshape, boolean hasLayerNorm) {
        // Exclusive end offsets of each consecutive segment.
        long weightEnd = nIn * nOut;
        long recurrentEnd = weightEnd + nOut * nOut;
        long biasEnd = recurrentEnd + nOut;

        INDArray w = segment(in, 0L, weightEnd);
        INDArray rw = segment(in, weightEnd, recurrentEnd);
        if (reshape) {
            w = w.reshape('f', new long[]{nIn, nOut});
            rw = rw.reshape('f', new long[]{nOut, nOut});
        }

        Map<String, INDArray> subsets = new LinkedHashMap<>();
        subsets.put(WEIGHT_KEY, w);
        subsets.put(RECURRENT_WEIGHT_KEY, rw);
        subsets.put(BIAS_KEY, segment(in, recurrentEnd, biasEnd));
        if (hasLayerNorm) {
            subsets.put(GAIN_KEY, segment(in, biasEnd, biasEnd + 2L * nOut));
        }
        return subsets;
    }

    /**
     * Returns the sub-view of {@code view} at first-dimension index 0 and second-dimension
     * range {@code [from, to)}. Presumably {@code view} is a single-row vector; TODO confirm.
     */
    private static INDArray segment(INDArray view, long from, long to) {
        return view.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to));
    }

    /**
     * Reports whether {@code layer} is a {@link SimpleRnn} with layer normalization enabled;
     * any other layer type yields {@code false}.
     */
    protected boolean hasLayerNorm(Layer layer) {
        return layer instanceof SimpleRnn && ((SimpleRnn) layer).hasLayerNorm();
    }
}

