/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link ParamInitializer} for {@link BatchNormalization} layers.
 *
 * <p>The flattened parameter (and gradient) view is sliced along its second
 * dimension into consecutive segments of {@code nOut} elements each:
 * <pre>
 *   lockGammaBeta == false : [ gamma | beta | mean | var-or-log-std ]
 *   lockGammaBeta == true  : [ mean | var-or-log-std ]
 * </pre>
 * The last segment is stored under {@link #GLOBAL_LOG_STD} when the layer
 * reports {@code isUseLogStd()}, and under {@link #GLOBAL_VAR} otherwise.
 *
 * <p>NOTE(review): the views passed in are indexed as row 0 plus a column
 * interval, so they are presumably row vectors of shape [1, numParams] -
 * confirm against callers.
 */
public class BatchNormalizationParamInitializer
implements ParamInitializer {
    private static final BatchNormalizationParamInitializer INSTANCE = new BatchNormalizationParamInitializer();

    /** Parameter key for the gamma segment. */
    public static final String GAMMA = "gamma";
    /** Parameter key for the beta segment. */
    public static final String BETA = "beta";
    /** Parameter key for the global mean segment. */
    public static final String GLOBAL_MEAN = "mean";
    /** Parameter key for the last segment when the layer does not use log-std. */
    public static final String GLOBAL_VAR = "var";
    /** Parameter key for the last segment when the layer uses log-std. */
    public static final String GLOBAL_LOG_STD = "log10stdev";

    /**
     * @return the shared instance of this initializer
     */
    public static BatchNormalizationParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Number of parameters for the layer held by the given configuration.
     *
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @return see {@link #numParams(Layer)}
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Number of parameters for the given layer.
     *
     * @param l layer configuration; cast to {@link BatchNormalization}
     * @return {@code 2 * nOut} when gamma/beta are locked (only the two
     *         statistics segments exist), {@code 4 * nOut} otherwise
     */
    @Override
    public long numParams(Layer l) {
        BatchNormalization layer = (BatchNormalization)l;
        long segments = layer.isLockGammaBeta() ? 2L : 4L;
        return segments * layer.getNOut();
    }

    /**
     * Parameter keys for the given layer.
     *
     * <p>NOTE(review): gamma and beta are always listed here, whereas
     * {@link #init} only creates them when {@code isLockGammaBeta()} is false.
     * Left unchanged because callers may rely on the current keys - verify.
     *
     * @param layer layer configuration; cast to {@link BatchNormalization}
     * @return gamma, beta, mean, and either the log-std or the variance key
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, varianceKey((BatchNormalization)layer));
    }

    /**
     * @return always an empty list; this initializer reports no weight keys
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        return Collections.emptyList();
    }

    /**
     * @return always an empty list; this initializer reports no bias keys
     */
    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return false;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Splits the flattened parameter view into per-parameter views, optionally
     * writing initial values into them, and registers each key on the
     * configuration via {@code conf.addVariable}.
     *
     * @param conf             configuration whose layer is a {@link BatchNormalization}
     * @param paramView        flattened parameter view to slice
     * @param initializeParams when {@code true}, gamma and beta are set from the
     *                         layer's {@code getGamma()} / {@code getBeta()},
     *                         the mean is set to 0, and the last segment is set
     *                         to 0 (log-std) or 1 (variance)
     * @return synchronized, insertion-ordered map from parameter key to its view
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        BatchNormalization layer = (BatchNormalization)conf.getLayer();
        long nOut = layer.getNOut();

        if (!layer.isLockGammaBeta()) {
            INDArray gammaView = slice(paramView, 0L, nOut);
            INDArray betaView = slice(paramView, nOut, 2L * nOut);
            if (initializeParams) {
                gammaView.assign(layer.getGamma());
            }
            params.put(GAMMA, gammaView);
            conf.addVariable(GAMMA);
            if (initializeParams) {
                betaView.assign(layer.getBeta());
            }
            params.put(BETA, betaView);
            conf.addVariable(BETA);
        }

        boolean useLogStd = layer.isUseLogStd();
        long meanOffset = statsOffset(layer, nOut);
        INDArray globalMeanView = slice(paramView, meanOffset, meanOffset + nOut);
        INDArray globalVarView = slice(paramView, meanOffset + nOut, meanOffset + 2L * nOut);
        if (initializeParams) {
            globalMeanView.assign(0);
            // Log-std form starts at 0, variance form starts at 1.
            globalVarView.assign(useLogStd ? 0 : 1);
        }

        params.put(GLOBAL_MEAN, globalMeanView);
        conf.addVariable(GLOBAL_MEAN);
        String varKey = useLogStd ? GLOBAL_LOG_STD : GLOBAL_VAR;
        params.put(varKey, globalVarView);
        conf.addVariable(varKey);
        return params;
    }

    /**
     * Splits the flattened gradient view using the same segment layout as
     * {@link #init}. Nothing is written to the view or to the configuration.
     *
     * @param conf         configuration whose layer is a {@link BatchNormalization}
     * @param gradientView flattened gradient view to slice
     * @return insertion-ordered map from parameter key to its gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        BatchNormalization layer = (BatchNormalization)conf.getLayer();
        long nOut = layer.getNOut();
        Map<String, INDArray> out = new LinkedHashMap<String, INDArray>();

        if (!layer.isLockGammaBeta()) {
            INDArray gammaView = slice(gradientView, 0L, nOut);
            INDArray betaView = slice(gradientView, nOut, 2L * nOut);
            out.put(GAMMA, gammaView);
            out.put(BETA, betaView);
        }

        long meanOffset = statsOffset(layer, nOut);
        out.put(GLOBAL_MEAN, slice(gradientView, meanOffset, meanOffset + nOut));
        out.put(varianceKey(layer), slice(gradientView, meanOffset + nOut, meanOffset + 2L * nOut));
        return out;
    }

    /** Key under which the last segment is stored for the given layer. */
    private static String varianceKey(BatchNormalization layer) {
        return layer.isUseLogStd() ? GLOBAL_LOG_STD : GLOBAL_VAR;
    }

    /** Column offset of the mean segment: 0 when gamma/beta are locked, otherwise 2 * nOut. */
    private static long statsOffset(BatchNormalization layer, long nOut) {
        return layer.isLockGammaBeta() ? 0L : 2L * nOut;
    }

    /** Returns the sub-view of row 0 of {@code view} covering columns {@code [from, to)}. */
    private static INDArray slice(INDArray view, long from, long to) {
        return view.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to)});
    }
}

