/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasBatchNormalization
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasBatchNormalization.class);
    public static final int LAYER_BATCHNORM_MODE_1 = 1;
    public static final int LAYER_BATCHNORM_MODE_2 = 2;
    public static final String LAYER_FIELD_GAMMA_REGULARIZER = "gamma_regularizer";
    public static final String LAYER_FIELD_BETA_REGULARIZER = "beta_regularizer";
    public static final String LAYER_FIELD_MODE = "mode";
    public static final String LAYER_FIELD_AXIS = "axis";
    public static final String LAYER_FIELD_MOMENTUM = "momentum";
    public static final String LAYER_FIELD_EPSILON = "epsilon";
    public static final int NUM_TRAINABLE_PARAMS = 4;
    public static final String PARAM_NAME_GAMMA = "gamma";
    public static final String PARAM_NAME_BETA = "beta";
    public static final String PARAM_NAME_RUNNING_MEAN = "running_mean";
    public static final String PARAM_NAME_RUNNING_STD = "running_std";

    public KerasBatchNormalization(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    public KerasBatchNormalization(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        this.getGammaRegularizerFromConfig(layerConfig, enforceTrainingConfig);
        this.getBetaRegularizerFromConfig(layerConfig, enforceTrainingConfig);
        int batchNormMode = this.getBatchNormMode(layerConfig, enforceTrainingConfig);
        int batchNormAxis = this.getBatchNormAxis(layerConfig, enforceTrainingConfig);
        this.layer = ((BatchNormalization.Builder)((BatchNormalization.Builder)((BatchNormalization.Builder)new BatchNormalization.Builder().name(this.layerName)).dropOut(this.dropout)).minibatch(true).lockGammaBeta(false).eps(this.getEpsFromConfig(layerConfig)).momentum(this.getMomentumFromConfig(layerConfig))).build();
    }

    public BatchNormalization getBatchNormalizationLayer() {
        return (BatchNormalization)this.layer;
    }

    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer accepts only one input (received " + inputType.length + ")");
        }
        return this.getBatchNormalizationLayer().getOutputType(-1, inputType[0]);
    }

    @Override
    public int getNumParams() {
        return 4;
    }

    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap();
        if (!weights.containsKey(PARAM_NAME_BETA)) {
            throw new InvalidKerasConfigurationException("Parameter beta does not exist in weights");
        }
        this.weights.put(PARAM_NAME_BETA, weights.get(PARAM_NAME_BETA));
        if (!weights.containsKey(PARAM_NAME_GAMMA)) {
            throw new InvalidKerasConfigurationException("Parameter gamma does not exist in weights");
        }
        this.weights.put(PARAM_NAME_GAMMA, weights.get(PARAM_NAME_GAMMA));
        if (!weights.containsKey(PARAM_NAME_RUNNING_MEAN)) {
            throw new InvalidKerasConfigurationException("Parameter running_mean does not exist in weights");
        }
        this.weights.put("mean", weights.get(PARAM_NAME_RUNNING_MEAN));
        if (!weights.containsKey(PARAM_NAME_RUNNING_STD)) {
            throw new InvalidKerasConfigurationException("Parameter running_std does not exist in weights");
        }
        this.weights.put("var", weights.get(PARAM_NAME_RUNNING_STD));
        if (weights.size() > 4) {
            Set<String> paramNames = weights.keySet();
            paramNames.remove(PARAM_NAME_BETA);
            paramNames.remove(PARAM_NAME_GAMMA);
            paramNames.remove(PARAM_NAME_RUNNING_MEAN);
            paramNames.remove(PARAM_NAME_RUNNING_STD);
            String unknownParamNames = paramNames.toString();
            log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
        }
    }

    protected double getEpsFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_EPSILON)) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing epsilon field");
        }
        return (Double)innerConfig.get(LAYER_FIELD_EPSILON);
    }

    protected double getMomentumFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_MOMENTUM)) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing momentum field");
        }
        return (Double)innerConfig.get(LAYER_FIELD_MOMENTUM);
    }

    protected void getGammaRegularizerFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        if (innerConfig.get(LAYER_FIELD_GAMMA_REGULARIZER) != null) {
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization gamma parameter not supported");
            }
            log.warn("Regularization for BatchNormalization gamma parameter not supported...ignoring.");
        }
    }

    protected void getBetaRegularizerFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        if (innerConfig.get(LAYER_FIELD_BETA_REGULARIZER) != null) {
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization beta parameter not supported");
            }
            log.warn("Regularization for BatchNormalization beta parameter not supported...ignoring.");
        }
    }

    protected int getBatchNormMode(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_MODE)) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing mode field");
        }
        int batchNormMode = (Integer)innerConfig.get(LAYER_FIELD_MODE);
        switch (batchNormMode) {
            case 1: {
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization mode 1 (sample-wise) not supported");
            }
            case 2: {
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization (per-batch statistics during testing) 2 not supported");
            }
        }
        return batchNormMode;
    }

    protected int getBatchNormAxis(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasBatchNormalization.getInnerLayerConfigFromConfig(layerConfig);
        return (Integer)innerConfig.get(LAYER_FIELD_AXIS);
    }
}

