/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wraps a single Keras layer configuration (a parsed JSON map) and converts it into the
 * corresponding DL4J {@link Layer}, when one exists.
 *
 * <p>The outer Keras map holds {@code class_name} and {@code config}. It is flattened on
 * construction: every outer field other than {@code config} is copied into the inner
 * {@code config} map, which is then used for every later lookup. Note that this mutates the
 * inner map supplied by the caller.
 *
 * <p>The {@code train} flag controls strictness. When {@code true}, unsupported Keras features
 * raise {@link UnsupportedKerasConfigurationException}. When {@code false}, most of them are
 * logged and ignored or replaced with a default.
 */
public class KerasLayer {
    public static final String LAYER_FIELD_CLASS_NAME = "class_name";
    public static final String LAYER_CLASS_NAME_INPUT = "InputLayer";
    public static final String LAYER_CLASS_NAME_ACTIVATION = "Activation";
    public static final String LAYER_CLASS_NAME_DROPOUT = "Dropout";
    public static final String LAYER_CLASS_NAME_DENSE = "Dense";
    public static final String LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE = "TimeDistributedDense";
    public static final String LAYER_CLASS_NAME_LSTM = "LSTM";
    public static final String LAYER_CLASS_NAME_CONVOLUTION_2D = "Convolution2D";
    public static final String LAYER_CLASS_NAME_MAX_POOLING_2D = "MaxPooling2D";
    public static final String LAYER_CLASS_NAME_AVERAGE_POOLING_2D = "AveragePooling2D";
    public static final String LAYER_CLASS_NAME_FLATTEN = "Flatten";
    public static final String LAYER_CLASS_NAME_RESHAPE = "Reshape";
    public static final String LAYER_CLASS_NAME_REPEATVECTOR = "RepeatVector";
    public static final String LAYER_CLASS_NAME_MERGE = "Merge";
    public static final String LAYER_CLASS_NAME_BATCHNORMALIZATION = "BatchNormalization";
    public static final String LAYER_FIELD_CONFIG = "config";
    public static final String LAYER_FIELD_NAME = "name";
    public static final String LAYER_FIELD_DROPOUT = "dropout";
    public static final String LAYER_FIELD_OUTPUT_DIM = "output_dim";
    public static final String LAYER_FIELD_SUBSAMPLE = "subsample";
    public static final String LAYER_FIELD_NB_ROW = "nb_row";
    public static final String LAYER_FIELD_NB_COL = "nb_col";
    public static final String LAYER_FIELD_NB_FILTER = "nb_filter";
    public static final String LAYER_FIELD_STRIDES = "strides";
    public static final String LAYER_FIELD_POOL_SIZE = "pool_size";
    public static final String LAYER_FIELD_DROPOUT_U = "dropout_U";
    public static final String LAYER_FIELD_DROPOUT_W = "dropout_W";
    public static final String LAYER_FIELD_BATCH_INPUT_SHAPE = "batch_input_shape";
    public static final String LAYER_FIELD_INBOUND_NODES = "inbound_nodes";
    public static final String LAYER_FIELD_BORDER_MODE = "border_mode";
    public static final String LAYER_FIELD_GAMMA_REGULARIZER = "gamma_regularizer";
    public static final String LAYER_FIELD_BETA_REGULARIZER = "beta_regularizer";
    public static final String LAYER_FIELD_MODE = "mode";
    public static final String LAYER_FIELD_AXIS = "axis";
    public static final String LAYER_FIELD_EPSILON = "epsilon";
    public static final String LAYER_FIELD_MOMENTUM = "momentum";
    public static final String LAYER_BORDER_MODE_SAME = "same";
    public static final String LAYER_BORDER_MODE_VALID = "valid";
    public static final String LAYER_BORDER_MODE_FULL = "full";
    public static final int LAYER_BATCHNORM_MODE_1 = 1;
    public static final int LAYER_BATCHNORM_MODE_2 = 2;
    public static final String LAYER_FIELD_W_REGULARIZER = "W_regularizer";
    public static final String LAYER_FIELD_B_REGULARIZER = "b_regularizer";
    public static final String REGULARIZATION_TYPE_L1 = "l1";
    public static final String REGULARIZATION_TYPE_L2 = "l2";
    public static final String LAYER_FIELD_INIT = "init";
    public static final String LAYER_FIELD_INNER_INIT = "inner_init";
    public static final String INIT_UNIFORM = "uniform";
    public static final String INIT_ZERO = "zero";
    public static final String INIT_GLOROT_NORMAL = "glorot_normal";
    public static final String INIT_GLOROT_UNIFORM = "glorot_uniform";
    public static final String INIT_HE_NORMAL = "he_normal";
    public static final String INIT_HE_UNIFORM = "he_uniform";
    public static final String INIT_LECUN_UNIFORM = "lecun_uniform";
    public static final String INIT_NORMAL = "normal";
    public static final String INIT_ORTHOGONAL = "orthogonal";
    public static final String INIT_IDENTITY = "identity";
    public static final String LAYER_FIELD_ACTIVATION = "activation";
    public static final String LAYER_FIELD_INNER_ACTIVATION = "inner_activation";
    public static final String KERAS_ACTIVATION_LINEAR = "linear";
    public static final String DL4J_ACTIVATION_IDENTITY = "identity";
    public static final String KERAS_ACTIVATION_HARD_SIGMOID = "hard_sigmoid";
    public static final String DL4J_ACTIVATION_HARDSIGMOID = "hardsigmoid";
    public static final String LAYER_FIELD_FORGET_BIAS_INIT = "forget_bias_init";
    public static final String LSTM_FORGET_BIAS_INIT_ZERO = "zero";
    public static final String LSTM_FORGET_BIAS_INIT_ONE = "one";
    public static final String LAYER_FIELD_DIM_ORDERING = "dim_ordering";
    public static final String DIM_ORDERING_THEANO = "th";
    public static final String DIM_ORDERING_TENSORFLOW = "tf";
    public static final String LAYER_CLASS_NAME_LOSS = "Loss";
    public static final String LAYER_FIELD_LOSS = "loss";
    public static final String LOSS_SQUARED_LOSS_1 = "mean_squared_error";
    public static final String KERAS_LOSS_SQUARED_LOSS_2 = "mse";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_ERROR_1 = "mean_absolute_error";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_ERROR_2 = "mae";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_1 = "mean_absolute_percentage_error";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_2 = "mape";
    public static final String KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_1 = "mean_squared_logarithmic_error";
    public static final String KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_2 = "msle";
    public static final String KERAS_LOSS_SQUARED_HINGE = "squared_hinge";
    public static final String KERAS_LOSS_HINGE = "hinge";
    public static final String KERAS_LOSS_XENT = "binary_crossentropy";
    public static final String KERAS_LOSS_MCXENT = "categorical_crossentropy";
    public static final String KERAS_LOSS_SP_XE = "sparse_categorical_crossentropy";
    public static final String KERAS_LOSS_KL_DIVERGENCE_1 = "kullback_leibler_divergence";
    public static final String KERAS_LOSS_KL_DIVERGENCE_2 = "kld";
    public static final String KERAS_LOSS_POISSON = "poisson";
    public static final String KERAS_LOSS_COSINE_PROXIMITY = "cosine_proximity";
    private static final Logger log = LoggerFactory.getLogger(KerasLayer.class);
    private Map<String, Object> layerConfig;
    private String className;
    private String layerName;
    private DimOrder dimOrder = DimOrder.NONE;
    private int[] inputShape;
    private List<String> inboundLayerNames = new ArrayList<>();
    private Layer dl4jLayer;
    private boolean train;

    /**
     * Builds a layer in non-training (lenient) mode.
     *
     * @param layerConfig outer Keras layer map containing {@code class_name} and {@code config}
     * @throws InvalidKerasConfigurationException if required fields are missing or malformed
     * @throws UnsupportedKerasConfigurationException if the layer uses an unsupported feature
     */
    public KerasLayer(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, false);
    }

    /**
     * Builds a layer from its outer Keras configuration map.
     *
     * @param layerConfig outer Keras layer map containing {@code class_name} and {@code config};
     *                    its inner {@code config} map is mutated (outer fields are copied into it)
     * @param train       if true, unsupported features raise instead of being ignored
     * @throws InvalidKerasConfigurationException if required fields are missing or malformed
     * @throws UnsupportedKerasConfigurationException if the layer uses an unsupported feature
     */
    public KerasLayer(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.className = (String) checkAndGetField(layerConfig, LAYER_FIELD_CLASS_NAME);
        Object innerField = checkAndGetField(layerConfig, LAYER_FIELD_CONFIG);
        if (!(innerField instanceof Map)) {
            throw new InvalidKerasConfigurationException("Field " + LAYER_FIELD_CONFIG + " is not a map in layer config");
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> innerConfig = (Map<String, Object>) innerField;
        // Flatten the outer fields (class_name, inbound_nodes, ...) into the inner config so
        // that every later lookup can use a single map.
        for (Map.Entry<String, Object> entry : layerConfig.entrySet()) {
            if (!LAYER_FIELD_CONFIG.equals(entry.getKey())) {
                innerConfig.put(entry.getKey(), entry.getValue());
            }
        }
        this.layerConfig = innerConfig;
        this.train = train;
        this.layerName = (String) checkAndGetField(this.layerConfig, LAYER_FIELD_NAME);
        // dl4jLayer must be built before the input shape: getInputShapeFromConfig inspects it.
        this.dl4jLayer = buildLayerFromConfig(this.layerConfig, this.train);
        this.dimOrder = getDimOrderFromConfig(this.layerConfig);
        this.inputShape = getInputShapeFromConfig(this.layerConfig, this.dimOrder);
        this.inboundLayerNames = getInboundLayerNamesFromConfig(this.layerConfig);
    }

    /** Returns the flattened (inner) layer configuration map backing this layer. */
    public Map<String, Object> getConfiguration() {
        return this.layerConfig;
    }

    /** Returns the Keras {@code class_name} of this layer. */
    public String getClassName() {
        return this.className;
    }

    /** Returns the Keras layer name. */
    public String getName() {
        return this.layerName;
    }

    /** Returns the dimension ordering read from {@code dim_ordering} ({@code NONE} if absent). */
    public DimOrder getDimOrder() {
        return this.dimOrder;
    }

    /**
     * Returns the input shape without the batch dimension, or {@code null} if the layer
     * config has no usable {@code batch_input_shape}.
     */
    public int[] getInputShape() {
        return this.inputShape;
    }

    /** Returns the names of the layers feeding into this one (the live internal list). */
    public List<String> getInboundLayerNames() {
        return this.inboundLayerNames;
    }

    /** Replaces the inbound layer names with a copy of the given list. */
    public void setInboundLayerNames(List<String> inboundLayerNames) {
        this.inboundLayerNames = new ArrayList<>(inboundLayerNames);
    }

    /** Appends one inbound layer name. */
    public void addInboundLayer(String layer) {
        this.inboundLayerNames.add(layer);
    }

    /** Returns whether this layer was built in strict (training) mode. */
    public boolean getTrain() {
        return this.train;
    }

    /** Returns true if this layer maps to a DL4J layer or is a Keras input layer. */
    public boolean isValidInboundLayer() {
        return this.dl4jLayer != null || this.className.equals(LAYER_CLASS_NAME_INPUT);
    }

    /** Returns true if this Keras layer was converted into a DL4J layer. */
    public boolean isDl4jLayer() {
        return this.dl4jLayer != null;
    }

    /** Returns the converted DL4J layer, or {@code null} if this Keras layer has no DL4J equivalent. */
    public Layer getDl4jLayer() {
        return this.dl4jLayer;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedKerasConfigurationException always
     */
    public boolean isDl4jPreprocessor() throws UnsupportedKerasConfigurationException {
        throw new UnsupportedKerasConfigurationException("Conversion from Keras layer to DL4J preprocessor not implemented.");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedKerasConfigurationException always
     */
    public PreprocessorVertex getDl4jPreprocessor() throws UnsupportedKerasConfigurationException {
        throw new UnsupportedKerasConfigurationException("Conversion from Keras layer to DL4J preprocessor not implemented.");
    }

    /**
     * Creates a synthetic Keras {@code InputLayer}.
     *
     * @param layerName  name of the new layer
     * @param inputShape input shape excluding the batch dimension
     * @return the new layer (built in non-training mode)
     */
    public static KerasLayer createInputLayer(String layerName, int[] inputShape) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> config = new HashMap<>();
        config.put(LAYER_FIELD_NAME, layerName);
        List<Integer> batchInputShape = new ArrayList<>(inputShape.length + 1);
        // Leading null stands for the (unspecified) batch dimension, as in Keras configs.
        batchInputShape.add(null);
        for (int dim : inputShape) {
            batchInputShape.add(dim);
        }
        config.put(LAYER_FIELD_BATCH_INPUT_SHAPE, batchInputShape);
        Map<String, Object> layerConfig = new HashMap<>();
        layerConfig.put(LAYER_FIELD_CONFIG, config);
        layerConfig.put(LAYER_FIELD_CLASS_NAME, LAYER_CLASS_NAME_INPUT);
        return new KerasLayer(layerConfig, false);
    }

    /** Creates a synthetic loss layer in training mode; see {@link #createLossLayer(String, String, boolean)}. */
    public static KerasLayer createLossLayer(String layerName, String kerasLoss) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return createLossLayer(layerName, kerasLoss, true);
    }

    /**
     * Creates a synthetic loss layer wrapping a DL4J {@link LossLayer}.
     *
     * @param layerName name of the new layer
     * @param kerasLoss Keras loss function name (e.g. {@code "mse"})
     * @param train     if true, an unknown loss raises; otherwise it falls back to squared loss
     */
    public static KerasLayer createLossLayer(String layerName, String kerasLoss, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> config = new HashMap<>();
        config.put(LAYER_FIELD_NAME, layerName);
        config.put(LAYER_FIELD_LOSS, kerasLoss);
        Map<String, Object> layerConfig = new HashMap<>();
        layerConfig.put(LAYER_FIELD_CONFIG, config);
        layerConfig.put(LAYER_FIELD_CLASS_NAME, LAYER_CLASS_NAME_LOSS);
        return new KerasLayer(layerConfig, train);
    }

    /** Builds a DL4J layer in non-training mode; see {@link #buildLayerFromConfig(Map, boolean)}. */
    public static Layer buildLayerFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return buildLayerFromConfig(layerConfig, false);
    }

    /**
     * Builds the DL4J layer matching the Keras {@code class_name} in a flattened layer config.
     *
     * @param layerConfig flattened layer config (must contain {@code class_name})
     * @param train       strictness flag passed to the individual layer builders
     * @return the DL4J layer, or {@code null} for Keras layers that DL4J represents as
     *         preprocessors or inputs (Flatten, Reshape, RepeatVector, Merge, InputLayer)
     * @throws InvalidKerasConfigurationException if {@code class_name} is missing or unknown
     */
    public static Layer buildLayerFromConfig(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String layerClassName = (String) layerConfig.get(LAYER_FIELD_CLASS_NAME);
        if (layerClassName == null) {
            throw new InvalidKerasConfigurationException("Missing class_name field.");
        }
        Layer layer = null;
        switch (layerClassName) {
            case LAYER_CLASS_NAME_ACTIVATION:
                layer = buildActivationLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_DROPOUT:
                layer = buildDropoutLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_DENSE:
            case LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE:
                layer = buildDenseLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_LSTM:
                layer = buildGravesLstmLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_CONVOLUTION_2D:
                layer = buildConvolutionLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_MAX_POOLING_2D:
            case LAYER_CLASS_NAME_AVERAGE_POOLING_2D:
                layer = buildSubsamplingLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_BATCHNORMALIZATION:
                layer = buildBatchNormalizationLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_LOSS:
                layer = buildLossLayer(layerConfig, train);
                break;
            case LAYER_CLASS_NAME_FLATTEN:
            case LAYER_CLASS_NAME_RESHAPE:
            case LAYER_CLASS_NAME_REPEATVECTOR:
            case LAYER_CLASS_NAME_MERGE:
            case LAYER_CLASS_NAME_INPUT:
                // No DL4J layer for these; the result stays null.
                log.warn("Found Keras " + layerClassName + ". DL4J adds \"preprocessor\" layers during model compilation: https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java#L429");
                break;
            default:
                throw new InvalidKerasConfigurationException("Unsupported keras layer type " + layerClassName);
        }
        return layer;
    }

    /**
     * Maps a Keras activation name to its DL4J name.
     *
     * @param kerasActivation Keras activation name (must not be null)
     * @return {@code "identity"} for {@code "linear"}, {@code "hardsigmoid"} for
     *         {@code "hard_sigmoid"}, otherwise the input unchanged
     */
    public static String mapActivation(String kerasActivation) {
        switch (kerasActivation) {
            case KERAS_ACTIVATION_LINEAR:
                return DL4J_ACTIVATION_IDENTITY;
            case KERAS_ACTIVATION_HARD_SIGMOID:
                return DL4J_ACTIVATION_HARDSIGMOID;
            default:
                return kerasActivation;
        }
    }

    /**
     * Maps a Keras weight initializer name to a DL4J {@link WeightInit}.
     *
     * @param kerasInit Keras initializer name; {@code null} maps to {@code XAVIER}
     * @return the matching DL4J initializer
     * @throws UnsupportedKerasConfigurationException for initializers without a mapping
     */
    public static WeightInit mapWeightInitialization(String kerasInit) throws UnsupportedKerasConfigurationException {
        if (kerasInit == null) {
            return WeightInit.XAVIER;
        }
        switch (kerasInit) {
            case INIT_GLOROT_NORMAL:
                return WeightInit.XAVIER;
            case INIT_GLOROT_UNIFORM:
                return WeightInit.XAVIER_UNIFORM;
            case INIT_HE_NORMAL:
                return WeightInit.RELU;
            case INIT_HE_UNIFORM:
                return WeightInit.RELU_UNIFORM;
            case INIT_ZERO:
                return WeightInit.ZERO;
            default:
                // Report the Keras name that failed to map (previously the default XAVIER was printed).
                throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
        }
    }

    /**
     * Maps a Keras loss function name to a DL4J loss function.
     *
     * @param kerasLoss Keras loss name (long or abbreviated form)
     * @return the matching DL4J loss function
     * @throws UnsupportedKerasConfigurationException for null or unknown loss names
     */
    public static LossFunctions.LossFunction mapLossFunction(String kerasLoss) throws UnsupportedKerasConfigurationException {
        if (kerasLoss == null) {
            throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
        }
        switch (kerasLoss) {
            case LOSS_SQUARED_LOSS_1:
            case KERAS_LOSS_SQUARED_LOSS_2:
                return LossFunctions.LossFunction.SQUARED_LOSS;
            case KERAS_LOSS_MEAN_ABSOLUTE_ERROR_1:
            case KERAS_LOSS_MEAN_ABSOLUTE_ERROR_2:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
            case KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_1:
            case KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_2:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
            case KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_1:
            case KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_2:
                return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
            case KERAS_LOSS_SQUARED_HINGE:
                return LossFunctions.LossFunction.SQUARED_HINGE;
            case KERAS_LOSS_HINGE:
                return LossFunctions.LossFunction.HINGE;
            case KERAS_LOSS_XENT:
                return LossFunctions.LossFunction.XENT;
            case KERAS_LOSS_SP_XE:
                // Deliberate approximation: no sparse variant is mapped here.
                log.warn("Sparse cross entropy not implemented, using multiclass cross entropy instead.");
                return LossFunctions.LossFunction.MCXENT;
            case KERAS_LOSS_MCXENT:
                return LossFunctions.LossFunction.MCXENT;
            case KERAS_LOSS_KL_DIVERGENCE_1:
            case KERAS_LOSS_KL_DIVERGENCE_2:
                return LossFunctions.LossFunction.KL_DIVERGENCE;
            case KERAS_LOSS_POISSON:
                return LossFunctions.LossFunction.POISSON;
            case KERAS_LOSS_COSINE_PROXIMITY:
                return LossFunctions.LossFunction.COSINE_PROXIMITY;
            default:
                throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
        }
    }

    /**
     * Reads {@code dim_ordering}.
     *
     * @return {@code NONE} if absent or null, {@code THEANO}/{@code TENSORFLOW} for
     *         {@code "th"}/{@code "tf"}, otherwise {@code UNKNOWN} (with a warning)
     */
    private DimOrder getDimOrderFromConfig(Map<String, Object> layerConfig) {
        String dimOrderStr = (String) layerConfig.get(LAYER_FIELD_DIM_ORDERING);
        if (dimOrderStr == null) {
            return DimOrder.NONE;
        }
        switch (dimOrderStr) {
            case DIM_ORDERING_TENSORFLOW:
                return DimOrder.TENSORFLOW;
            case DIM_ORDERING_THEANO:
                return DimOrder.THEANO;
            default:
                // Log the value actually read (previously the NONE default was printed).
                log.warn("Keras layer has unknown Keras dimension order: " + dimOrderStr);
                return DimOrder.UNKNOWN;
        }
    }

    /**
     * Extracts the input shape from {@code batch_input_shape}, dropping the leading batch
     * dimension and mapping unknown ({@code null}) dimensions to 0.
     *
     * <p>For 3-D shapes of a convolution layer with Theano ordering, the leading (channel)
     * dimension is moved to the end.
     *
     * @return the shape, or {@code null} if the field is absent, not a list, or empty
     */
    private int[] getInputShapeFromConfig(Map<String, Object> layerConfig, DimOrder dimOrder) {
        Object shapeField = layerConfig.get(LAYER_FIELD_BATCH_INPUT_SHAPE);
        if (!(shapeField instanceof List) || ((List<?>) shapeField).isEmpty()) {
            return null;
        }
        List<?> batchInputShape = (List<?>) shapeField;
        int[] inputShape = new int[batchInputShape.size() - 1];
        for (int i = 1; i < batchInputShape.size(); i++) {
            Object dim = batchInputShape.get(i);
            inputShape[i - 1] = dim != null ? ((Number) dim).intValue() : 0;
        }
        if (dimOrder == DimOrder.THEANO && inputShape.length == 3 && this.dl4jLayer instanceof ConvolutionLayer) {
            int numChannels = inputShape[0];
            inputShape[0] = inputShape[1];
            inputShape[1] = inputShape[2];
            inputShape[2] = numChannels;
        }
        return inputShape;
    }

    /**
     * Collects the first element of each entry in the first {@code inbound_nodes} node.
     * Each entry is assumed to be a list whose first element is the inbound layer name.
     *
     * @return the inbound layer names, empty if the field is absent or empty
     */
    private static List<String> getInboundLayerNamesFromConfig(Map<String, Object> layerConfig) {
        List<String> inboundNodeNames = new ArrayList<>();
        Object inboundField = layerConfig.get(LAYER_FIELD_INBOUND_NODES);
        if (inboundField instanceof List && !((List<?>) inboundField).isEmpty()) {
            List<?> firstNode = (List<?>) ((List<?>) inboundField).get(0);
            for (Object inbound : firstNode) {
                inboundNodeNames.add((String) ((List<?>) inbound).get(0));
            }
        }
        return inboundNodeNames;
    }

    /** Returns the {@code l1} coefficient of a regularizer config, or 0 if absent. */
    private static double getL1Regularization(Map<String, Object> regularizerConfig) {
        if (regularizerConfig != null && regularizerConfig.containsKey(REGULARIZATION_TYPE_L1)) {
            return ((Number) regularizerConfig.get(REGULARIZATION_TYPE_L1)).doubleValue();
        }
        return 0.0;
    }

    /** Returns the {@code l2} coefficient of a regularizer config, or 0 if absent. */
    private static double getL2Regularization(Map<String, Object> regularizerConfig) {
        if (regularizerConfig != null && regularizerConfig.containsKey(REGULARIZATION_TYPE_L2)) {
            return ((Number) regularizerConfig.get(REGULARIZATION_TYPE_L2)).doubleValue();
        }
        return 0.0;
    }

    /**
     * Rejects (train) or warns about (otherwise) regularizer fields other than
     * {@code l1}, {@code l2} and {@code name}. The regularizer config is not modified.
     */
    private static void checkForUnknownRegularizer(Map<String, Object> regularizerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        if (regularizerConfig == null) {
            return;
        }
        // Copy: removing from keySet() directly would delete entries from the layer config.
        Set<String> regularizerFields = new HashSet<>(regularizerConfig.keySet());
        regularizerFields.remove(REGULARIZATION_TYPE_L1);
        regularizerFields.remove(REGULARIZATION_TYPE_L2);
        regularizerFields.remove(LAYER_FIELD_NAME);
        if (!regularizerFields.isEmpty()) {
            String unknownField = regularizerFields.iterator().next();
            if (train) {
                throw new UnsupportedKerasConfigurationException("Unknown regularization field " + unknownField);
            }
            log.warn("Ignoring unknown regularization field " + unknownField);
        }
    }

    /** Builds a DL4J activation layer from the shared layer fields. */
    private static ActivationLayer buildActivationLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        ActivationLayer.Builder builder = new ActivationLayer.Builder();
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J dropout layer from the shared layer fields.
     * NOTE(review): only the {@code dropout} field is read (in finishLayerConfig); if Keras
     * Dropout layers store their rate under another key, it is not picked up — confirm.
     */
    private static DropoutLayer buildDropoutLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        DropoutLayer.Builder builder = new DropoutLayer.Builder();
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /** Builds a DL4J dense layer with {@code nOut = output_dim}. */
    private static DenseLayer buildDenseLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        DenseLayer.Builder builder = (DenseLayer.Builder) new DenseLayer.Builder().nOut(((Integer) layerConfig.get(LAYER_FIELD_OUTPUT_DIM)).intValue());
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J convolution layer from {@code subsample}, {@code nb_row}, {@code nb_col},
     * {@code nb_filter} and {@code border_mode}. A "full" border mode is emulated with
     * Truncate mode plus {@code (kernel - 1)} padding.
     */
    private static ConvolutionLayer buildConvolutionLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        List<?> stride = (List<?>) layerConfig.get(LAYER_FIELD_SUBSAMPLE);
        int nbRow = (Integer) layerConfig.get(LAYER_FIELD_NB_ROW);
        int nbCol = (Integer) layerConfig.get(LAYER_FIELD_NB_COL);
        String borderMode = (String) layerConfig.get(LAYER_FIELD_BORDER_MODE);
        ConvolutionLayer.Builder builder = (ConvolutionLayer.Builder) new ConvolutionLayer.Builder()
                .stride(new int[]{(Integer) stride.get(0), (Integer) stride.get(1)})
                .kernelSize(new int[]{nbRow, nbCol})
                .nOut(((Integer) layerConfig.get(LAYER_FIELD_NB_FILTER)).intValue());
        switch (borderMode) {
            case LAYER_BORDER_MODE_SAME:
                builder.convolutionMode(ConvolutionMode.Same);
                break;
            case LAYER_BORDER_MODE_VALID:
                builder.convolutionMode(ConvolutionMode.Truncate);
                break;
            case LAYER_BORDER_MODE_FULL:
                builder.convolutionMode(ConvolutionMode.Truncate).padding(new int[]{nbRow - 1, nbCol - 1});
                break;
            default:
                if (train) {
                    throw new UnsupportedKerasConfigurationException("Unsupported border mode " + borderMode);
                }
                log.warn("Unsupported border mode " + borderMode + ". Leaving convolution mode unset.");
        }
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J max/average subsampling layer from {@code strides}, {@code pool_size}
     * and {@code border_mode}.
     */
    private static SubsamplingLayer buildSubsamplingLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        List<?> stride = (List<?>) layerConfig.get(LAYER_FIELD_STRIDES);
        List<?> pool = (List<?>) layerConfig.get(LAYER_FIELD_POOL_SIZE);
        int poolRows = (Integer) pool.get(0);
        int poolCols = (Integer) pool.get(1);
        SubsamplingLayer.Builder builder = new SubsamplingLayer.Builder()
                .stride(new int[]{(Integer) stride.get(0), (Integer) stride.get(1)})
                .kernelSize(new int[]{poolRows, poolCols});
        String layerClassName = (String) layerConfig.get(LAYER_FIELD_CLASS_NAME);
        switch (layerClassName) {
            case LAYER_CLASS_NAME_MAX_POOLING_2D:
                builder.poolingType(SubsamplingLayer.PoolingType.MAX);
                break;
            case LAYER_CLASS_NAME_AVERAGE_POOLING_2D:
                builder.poolingType(SubsamplingLayer.PoolingType.AVG);
                break;
            default:
                throw new UnsupportedKerasConfigurationException("Unsupported Keras pooling layer " + layerClassName);
        }
        String borderMode = (String) layerConfig.get(LAYER_FIELD_BORDER_MODE);
        switch (borderMode) {
            case LAYER_BORDER_MODE_SAME:
                builder.convolutionMode(ConvolutionMode.Same);
                break;
            case LAYER_BORDER_MODE_VALID:
                builder.convolutionMode(ConvolutionMode.Truncate);
                break;
            case LAYER_BORDER_MODE_FULL:
                builder.convolutionMode(ConvolutionMode.Truncate).padding(new int[]{poolRows - 1, poolCols - 1});
                break;
            default:
                if (train) {
                    throw new UnsupportedKerasConfigurationException("Unsupported border mode " + borderMode);
                }
                log.warn("Unsupported border mode " + borderMode + ". Leaving convolution mode unset.");
        }
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J GravesLSTM layer. Recurrent dropout ({@code dropout_U > 0}) is always
     * rejected; {@code dropout_W} is copied into {@code dropout} (mutating the config) so
     * the shared finishing step applies it.
     */
    private static GravesLSTM buildGravesLstmLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        // Objects.equals: a missing "init" must not cause an NPE.
        if (!Objects.equals(layerConfig.get(LAYER_FIELD_INIT), layerConfig.get(LAYER_FIELD_INNER_INIT))) {
            if (train) {
                throw new UnsupportedKerasConfigurationException("Specifying different initialization for LSTM inner cells not supported.");
            }
            log.warn("Specifying different initialization for LSTM inner cells not supported.");
        }
        if ((Double) layerConfig.get(LAYER_FIELD_DROPOUT_U) > 0.0) {
            throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
        }
        GravesLSTM.Builder builder = new GravesLSTM.Builder();
        builder.nOut(((Integer) layerConfig.get(LAYER_FIELD_OUTPUT_DIM)).intValue());
        builder.gateActivationFunction(mapActivation((String) layerConfig.get(LAYER_FIELD_INNER_ACTIVATION)));
        String forgetBiasInit = (String) layerConfig.get(LAYER_FIELD_FORGET_BIAS_INIT);
        switch (forgetBiasInit) {
            case LSTM_FORGET_BIAS_INIT_ZERO:
                builder.forgetGateBiasInit(0.0);
                break;
            case LSTM_FORGET_BIAS_INIT_ONE:
                builder.forgetGateBiasInit(1.0);
                break;
            default:
                if (train) {
                    throw new UnsupportedKerasConfigurationException("Unsupported bias initialization: " + forgetBiasInit);
                }
                builder.forgetGateBiasInit(1.0);
                log.warn("Unsupported bias initialization: " + forgetBiasInit + ". Using ONE instead");
        }
        layerConfig.put(LAYER_FIELD_DROPOUT, ((Double) layerConfig.get(LAYER_FIELD_DROPOUT_W)).doubleValue());
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J batch normalization layer from {@code epsilon} and {@code momentum}.
     * Gamma/beta regularizers are rejected in training mode and ignored (with a warning)
     * otherwise; Keras modes 1 and 2 are always rejected; {@code axis} is ignored.
     */
    private static BatchNormalization buildBatchNormalizationLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        // Only complain when a regularizer is actually configured (the checks were inverted).
        if (layerConfig.get(LAYER_FIELD_GAMMA_REGULARIZER) != null) {
            if (train) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization gamma parameter not supported");
            }
            log.warn("Regularization for BatchNormalization gamma parameter not supported...ignoring.");
        }
        if (layerConfig.get(LAYER_FIELD_BETA_REGULARIZER) != null) {
            if (train) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization beta parameter not supported");
            }
            log.warn("Regularization for BatchNormalization beta parameter not supported...ignoring.");
        }
        int batchNormMode = (Integer) layerConfig.get(LAYER_FIELD_MODE);
        switch (batchNormMode) {
            case LAYER_BATCHNORM_MODE_1:
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization mode 1 (sample-wise) not supported");
            case LAYER_BATCHNORM_MODE_2:
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization (per-batch statistics during testing) 2 not supported");
            default:
                break;
        }
        int axis = (Integer) layerConfig.get(LAYER_FIELD_AXIS);
        log.warn("Ignoring BatchNormalization axis=" + axis + " config. DL4J BatchNormalization defaults to the \"channels\" axis");
        BatchNormalization.Builder builder = new BatchNormalization.Builder();
        builder.eps(((Double) layerConfig.get(LAYER_FIELD_EPSILON)).doubleValue())
                .momentum(((Double) layerConfig.get(LAYER_FIELD_MOMENTUM)).doubleValue());
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Builds a DL4J loss layer from the {@code loss} field. Unknown losses raise in training
     * mode and fall back to squared loss otherwise.
     */
    private static LossLayer buildLossLayer(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String kerasLoss = (String) checkAndGetField(layerConfig, LAYER_FIELD_LOSS);
        LossFunctions.LossFunction loss;
        try {
            loss = mapLossFunction(kerasLoss);
        } catch (UnsupportedKerasConfigurationException e) {
            if (train) {
                throw e;
            }
            log.warn("Unsupported Keras loss function. Replacing with MSE.");
            loss = LossFunctions.LossFunction.SQUARED_LOSS;
        }
        LossLayer.Builder builder = new LossLayer.Builder(loss);
        finishLayerConfig((Layer.Builder) builder, layerConfig, train);
        return builder.build();
    }

    /**
     * Applies the fields shared by all layers to a builder.
     *
     * <p>Fields applied: {@code dropout} (passed as {@code 1 - dropout}), {@code activation},
     * {@code name}, {@code init} (ZERO init also zeroes the bias) and the L1/L2 terms of
     * {@code W_regularizer}. Any non-zero {@code b_regularizer} raises in training mode and
     * is ignored otherwise.
     *
     * @return the same builder, for chaining
     */
    private static Layer.Builder finishLayerConfig(Layer.Builder builder, Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        if (layerConfig.containsKey(LAYER_FIELD_DROPOUT)) {
            // DL4J receives 1 - (Keras rate), i.e. the value is treated as a retain probability.
            builder.dropOut(1.0 - (Double) layerConfig.get(LAYER_FIELD_DROPOUT));
        }
        if (layerConfig.containsKey(LAYER_FIELD_ACTIVATION)) {
            builder.activation(mapActivation((String) layerConfig.get(LAYER_FIELD_ACTIVATION)));
        }
        builder.name((String) layerConfig.get(LAYER_FIELD_NAME));
        if (layerConfig.containsKey(LAYER_FIELD_INIT)) {
            String kerasInit = (String) layerConfig.get(LAYER_FIELD_INIT);
            WeightInit init;
            try {
                init = mapWeightInitialization(kerasInit);
            } catch (UnsupportedKerasConfigurationException e) {
                if (train) {
                    throw e;
                }
                init = WeightInit.XAVIER;
                log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
            }
            builder.weightInit(init);
            if (init == WeightInit.ZERO) {
                builder.biasInit(0.0);
            }
        }
        if (layerConfig.containsKey(LAYER_FIELD_W_REGULARIZER)) {
            @SuppressWarnings("unchecked")
            Map<String, Object> weightRegularizer = (Map<String, Object>) layerConfig.get(LAYER_FIELD_W_REGULARIZER);
            double l1 = getL1Regularization(weightRegularizer);
            if (l1 > 0.0) {
                builder.l1(l1);
            }
            double l2 = getL2Regularization(weightRegularizer);
            if (l2 > 0.0) {
                builder.l2(l2);
            }
            checkForUnknownRegularizer(weightRegularizer, train);
        }
        if (layerConfig.containsKey(LAYER_FIELD_B_REGULARIZER)) {
            @SuppressWarnings("unchecked")
            Map<String, Object> biasRegularizer = (Map<String, Object>) layerConfig.get(LAYER_FIELD_B_REGULARIZER);
            double l1 = getL1Regularization(biasRegularizer);
            double l2 = getL2Regularization(biasRegularizer);
            if (l1 > 0.0 || l2 > 0.0) {
                if (train) {
                    throw new UnsupportedKerasConfigurationException("Bias regularization not implemented");
                }
                log.warn("Bias regularization not supported. Ignoring.");
            }
        }
        return builder;
    }

    /**
     * Returns {@code map.get(key)}, requiring that the key be present (its value may be null).
     *
     * @throws InvalidKerasConfigurationException if the key is missing
     */
    private static Object checkAndGetField(Map<String, Object> map, String key) throws InvalidKerasConfigurationException {
        if (!map.containsKey(key)) {
            throw new InvalidKerasConfigurationException("Field " + key + " missing from layer config");
        }
        return map.get(key);
    }

    /**
     * Keras {@code dim_ordering} of a layer. {@code NONE} means the field is absent;
     * {@code UNKNOWN} means it held an unrecognized value.
     */
    public enum DimOrder {
        NONE,
        THEANO,
        TENSORFLOW,
        UNKNOWN;

    }
}

