/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for pretrainable feed-forward layers ({@link BasePretrainNetwork}).
 *
 * <p>Builds on {@link DefaultParamInitializer} and adds one extra parameter, the visible bias,
 * registered under {@link #VISIBLE_BIAS_KEY}. In the flattened parameter and gradient views the
 * visible bias is the {@code nIn} entries of row 0 that follow the first {@code nIn * nOut + nOut}
 * entries. That first region is handled by the superclass (presumably weights followed by the
 * hidden bias).
 */
public class PretrainParamInitializer
extends DefaultParamInitializer {
    private static final PretrainParamInitializer INSTANCE = new PretrainParamInitializer();

    /** Key under which the visible bias is stored in the parameter and gradient maps. */
    public static final String VISIBLE_BIAS_KEY = "vb";

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the singleton {@code PretrainParamInitializer}
     */
    public static PretrainParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the number of parameters of the layer: the superclass count plus {@code nIn} for the
     * visible bias.
     *
     * @param conf configuration whose layer must be a {@link BasePretrainNetwork}
     * @return total number of parameters
     * @throws ClassCastException if the layer is not a {@link BasePretrainNetwork}
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        BasePretrainNetwork layerConf = (BasePretrainNetwork) conf.getLayer();
        return super.numParams(conf) + layerConf.getNIn();
    }

    /**
     * Initializes the superclass parameters, then adds the visible bias as a view into
     * {@code paramsView} and registers it as a variable on {@code conf}.
     *
     * @param conf             configuration whose layer must be a {@link BasePretrainNetwork}
     * @param paramsView       flattened parameter array; the visible bias is a view into row 0
     * @param initializeParams whether to write initial values into the views
     * @return the parameter map from the superclass, with {@link #VISIBLE_BIAS_KEY} added
     * @throws ClassCastException if the layer is not a {@link BasePretrainNetwork}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map<String, INDArray> params = super.init(conf, paramsView, initializeParams);
        BasePretrainNetwork layerConf = (BasePretrainNetwork) conf.getLayer();
        INDArray visibleBiasView = visibleBiasSlice(layerConf.getNIn(), layerConf.getNOut(), paramsView);
        params.put(VISIBLE_BIAS_KEY, this.createVisibleBias(conf, visibleBiasView, initializeParams));
        conf.addVariable(VISIBLE_BIAS_KEY);
        return params;
    }

    /**
     * Optionally fills the visible-bias view with the layer's configured visible-bias init value.
     *
     * @param conf                 configuration whose layer must be a {@link BasePretrainNetwork}
     * @param visibleBiasView      view into the flattened parameters holding the visible bias
     * @param initializeParameters if {@code true}, every element of the view is set to
     *                             {@code getVisibleBiasInit()}; otherwise the view is left untouched
     * @return {@code visibleBiasView} itself (not a copy)
     * @throws ClassCastException if the layer is not a {@link BasePretrainNetwork}
     */
    protected INDArray createVisibleBias(NeuralNetConfiguration conf, INDArray visibleBiasView, boolean initializeParameters) {
        // Cast outside the branch so a mis-typed layer fails consistently, as before.
        BasePretrainNetwork layerConf = (BasePretrainNetwork) conf.getLayer();
        if (initializeParameters) {
            // Scalar assign fills the view in place: no temporary array and no shape reconciliation.
            visibleBiasView.assign(layerConf.getVisibleBiasInit());
        }
        return visibleBiasView;
    }

    /**
     * Splits the flattened gradient into per-parameter views, adding the visible-bias gradient.
     *
     * @param conf         configuration whose layer must be a {@link FeedForwardLayer}
     * @param gradientView flattened gradient array laid out like the parameter array
     * @return the superclass gradient map, with {@link #VISIBLE_BIAS_KEY} added
     * @throws ClassCastException if the layer is not a {@link FeedForwardLayer}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        Map<String, INDArray> out = super.getGradientsFromFlattened(conf, gradientView);
        FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
        out.put(VISIBLE_BIAS_KEY, visibleBiasSlice(layerConf.getNIn(), layerConf.getNOut(), gradientView));
        return out;
    }

    /**
     * Returns the visible-bias region of a flattened parameter or gradient array.
     *
     * <p>The region is row 0, columns {@code [nIn*nOut + nOut, nIn*nOut + nOut + nIn)}. It is shared
     * by {@code init} and {@code getGradientsFromFlattened} so both always agree on the layout.
     *
     * @param nIn      number of layer inputs (also the visible-bias length)
     * @param nOut     number of layer outputs
     * @param flatView flattened parameter or gradient array
     * @return a view (not a copy) over the visible-bias entries
     */
    private static INDArray visibleBiasSlice(long nIn, long nOut, INDArray flatView) {
        long start = nIn * nOut + nOut;
        return flatView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(start, start + nIn));
    }
}

