/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.modelimport.keras.IncompatibleKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that translate a parsed Keras layer configuration (a key/value map) into a DL4J
 * layer configuration.
 *
 * <p>Supported Keras layer classes are {@code Dense}, {@code TimeDistributedDense}, {@code LSTM},
 * {@code Convolution2D}, {@code MaxPooling2D} and {@code Flatten}. {@code Flatten} produces no
 * layer. None of the methods modify the maps passed to them.
 */
public final class LayerConfiguration {
    public static final String KERAS_REGULARIZATION_TYPE_L1 = "l1";
    public static final String KERAS_REGULARIZATION_TYPE_L2 = "l2";
    public static final String KERAS_LAYER_PROPERTY_NAME = "name";
    public static final String KERAS_LAYER_PROPERTY_DROPOUT = "dropout";
    public static final String KERAS_LAYER_PROPERTY_ACTIVATION = "activation";
    public static final String KERAS_LAYER_PROPERTY_INIT = "init";
    public static final String KERAS_LAYER_PROPERTY_W_REGULARIZER = "W_regularizer";
    public static final String KERAS_LAYER_PROPERTY_B_REGULARIZER = "b_regularizer";
    public static final String KERAS_LAYER_PROPERTY_OUTPUT_DIM = "output_dim";
    public static final String KERAS_LAYER_PROPERTY_SUBSAMPLE = "subsample";
    public static final String KERAS_LAYER_PROPERTY_NB_ROW = "nb_row";
    public static final String KERAS_LAYER_PROPERTY_NB_COL = "nb_col";
    public static final String KERAS_LAYER_PROPERTY_NB_FILTER = "nb_filter";
    public static final String KERAS_LAYER_PROPERTY_STRIDES = "strides";
    public static final String KERAS_LAYER_PROPERTY_POOL_SIZE = "pool_size";
    public static final String KERAS_MODEL_PROPERTY_CLASS = "keras_class";
    public static final String KERAS_LAYER_PROPERTY_INNER_ACTIVATION = "inner_activation";
    public static final String KERAS_LAYER_PROPERTY_INNER_INIT = "inner_init";
    public static final String KERAS_LAYER_PROPERTY_DROPOUT_U = "dropout_U";
    public static final String KERAS_LAYER_PROPERTY_FORGET_BIAS_INIT = "forget_bias_init";
    public static final String KERAS_LAYER_PROPERTY_DROPOUT_W = "dropout_W";
    public static final String KERAS_ACTIVATION_LINEAR = "linear";
    public static final String DL4J_ACTIVATION_IDENTITY = "identity";
    public static final String KERAS_LAYER_DENSE = "Dense";
    public static final String KERAS_LAYER_TIME_DISTRIBUTED_DENSE = "TimeDistributedDense";
    public static final String KERAS_LAYER_LSTM = "LSTM";
    public static final String KERAS_LAYER_CONVOLUTION_2D = "Convolution2D";
    public static final String KERAS_LAYER_MAX_POOLING_2D = "MaxPooling2D";
    public static final String KERAS_LAYER_FLATTEN = "Flatten";
    public static final String KERAS_INIT_UNIFORM = "uniform";
    public static final String KERAS_INIT_ZERO = "zero";
    public static final String KERAS_INIT_GLOROT_NORMAL = "glorot_normal";
    public static final String KERAS_INIT_GLOROT_UNIFORM = "glorot_uniform";
    public static final String KERAS_INIT_HE_NORMAL = "he_normal";
    public static final String KERAS_INIT_HE_UNIFORM = "he_uniform";
    public static final String KERAS_INIT_LECUN_UNIFORM = "lecun_uniform";
    public static final String KERAS_INIT_NORMAL = "normal";
    public static final String KERAS_INIT_ORTHOGONAL = "orthogonal";
    public static final String KERAS_INIT_IDENTITY = "identity";
    public static final String KERAS_FORGET_BIAS_ZERO = "zero";
    public static final String KERAS_FORGET_BIAS_ONE = "one";
    private static final Logger log = LoggerFactory.getLogger(LayerConfiguration.class);

    private LayerConfiguration() {
    }

    /**
     * Builds a DL4J layer from a Keras layer configuration.
     * Equivalent to {@code buildLayer(kerasLayerClass, kerasConfig, false)}.
     *
     * @param kerasLayerClass Keras layer class name, e.g. {@code "Dense"}
     * @param kerasConfig     Keras layer configuration map
     * @return the DL4J layer, or {@code null} for {@code Flatten}
     */
    public static Layer buildLayer(String kerasLayerClass, Map<String, Object> kerasConfig) {
        return buildLayer(kerasLayerClass, kerasConfig, false);
    }

    /**
     * Builds a DL4J layer from a Keras layer configuration by dispatching on the Keras class name.
     *
     * @param kerasLayerClass Keras layer class name, e.g. {@code "Dense"}
     * @param kerasConfig     Keras layer configuration map (not modified)
     * @param isOutput        currently unused
     * @return the DL4J layer, or {@code null} for {@code Flatten}; DL4J inserts the reshaping itself
     * @throws IncompatibleKerasConfigurationException if {@code kerasLayerClass} is null or unsupported
     */
    public static Layer buildLayer(String kerasLayerClass, Map<String, Object> kerasConfig, boolean isOutput) {
        if (kerasLayerClass == null) {
            // A switch on a null String would throw NullPointerException; report it as unsupported instead.
            throw new IncompatibleKerasConfigurationException("Unsupported keras layer type " + kerasLayerClass);
        }
        switch (kerasLayerClass) {
            case KERAS_LAYER_DENSE:
            case KERAS_LAYER_TIME_DISTRIBUTED_DENSE:
                return buildDenseLayer(kerasConfig);
            case KERAS_LAYER_LSTM:
                return buildGravesLstmLayer(kerasConfig);
            case KERAS_LAYER_CONVOLUTION_2D:
                return buildConvolutionLayer(kerasConfig);
            case KERAS_LAYER_MAX_POOLING_2D:
                return buildSubsamplingLayer(kerasConfig);
            case KERAS_LAYER_FLATTEN:
                log.warn("DL4J adds reshaping layers during model compilation: https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java#L429");
                return null;
            default:
                throw new IncompatibleKerasConfigurationException("Unsupported keras layer type " + kerasLayerClass);
        }
    }

    /**
     * Maps a Keras activation name to the DL4J name.
     *
     * @param kerasActivation Keras activation name
     * @return {@code "identity"} for Keras {@code "linear"}; otherwise the input unchanged (including null)
     */
    public static String mapActivation(String kerasActivation) {
        return KERAS_ACTIVATION_LINEAR.equals(kerasActivation) ? DL4J_ACTIVATION_IDENTITY : kerasActivation;
    }

    /**
     * Maps a Keras weight-initialization name to a DL4J {@link WeightInit}.
     *
     * <p>Names without a mapping here fall back to {@link WeightInit#XAVIER} and log a warning.
     * This includes {@code lecun_uniform}, {@code normal}, {@code orthogonal} and {@code identity}.
     *
     * @param kerasInit Keras init name; null yields {@link WeightInit#XAVIER} without a warning
     * @return the corresponding DL4J weight initialization
     */
    public static WeightInit mapWeightInitialization(String kerasInit) {
        if (kerasInit == null) {
            return WeightInit.XAVIER;
        }
        switch (kerasInit) {
            case KERAS_INIT_UNIFORM:
                return WeightInit.UNIFORM;
            case KERAS_INIT_ZERO:
                return WeightInit.ZERO;
            case KERAS_INIT_GLOROT_NORMAL:
                return WeightInit.XAVIER;
            case KERAS_INIT_GLOROT_UNIFORM:
                return WeightInit.XAVIER_UNIFORM;
            case KERAS_INIT_HE_NORMAL:
                return WeightInit.RELU;
            case KERAS_INIT_HE_UNIFORM:
                return WeightInit.RELU_UNIFORM;
            default:
                // Report the unrecognised Keras name, not the fallback value.
                log.warn("Unknown keras weight distribution " + kerasInit);
                return WeightInit.XAVIER;
        }
    }

    /**
     * Returns the L1 coefficient from a Keras regularizer config.
     *
     * @param regularizerConfig regularizer config map, may be null
     * @return the {@code "l1"} value, or 0.0 when the map is null or lacks the key
     */
    public static double getL1Regularization(Map<String, Object> regularizerConfig) {
        return getRegularizationCoefficient(regularizerConfig, KERAS_REGULARIZATION_TYPE_L1);
    }

    /**
     * Returns the L2 coefficient from a Keras regularizer config.
     *
     * @param regularizerConfig regularizer config map, may be null
     * @return the {@code "l2"} value, or 0.0 when the map is null or lacks the key
     */
    public static double getL2Regularization(Map<String, Object> regularizerConfig) {
        return getRegularizationCoefficient(regularizerConfig, KERAS_REGULARIZATION_TYPE_L2);
    }

    /** Reads a numeric regularization coefficient under {@code key}, defaulting to 0.0 when absent. */
    private static double getRegularizationCoefficient(Map<String, Object> regularizerConfig, String key) {
        if (regularizerConfig != null && regularizerConfig.containsKey(key)) {
            return toDouble(regularizerConfig.get(key));
        }
        return 0.0;
    }

    /**
     * Logs a warning for every regularizer field other than {@code l1}, {@code l2} and {@code name}.
     * The supplied map is not modified.
     *
     * @param regularizerConfig regularizer config map, may be null
     */
    public static void checkForUnknownRegularizer(Map<String, Object> regularizerConfig) {
        if (regularizerConfig == null) {
            return;
        }
        // keySet() is a live view: removing from it directly would delete the caller's entries.
        Set<String> unknownFields = new HashSet<>(regularizerConfig.keySet());
        unknownFields.remove(KERAS_REGULARIZATION_TYPE_L1);
        unknownFields.remove(KERAS_REGULARIZATION_TYPE_L2);
        unknownFields.remove(KERAS_LAYER_PROPERTY_NAME);
        for (String unknownField : unknownFields) {
            log.warn("Unknown regularization field: " + unknownField);
        }
    }

    /**
     * Applies the settings shared by all layer types to a DL4J layer builder.
     * These are dropout, activation, name, weight init and weight regularization.
     *
     * @param builder     DL4J layer builder; mutated through its setter calls and returned
     * @param kerasConfig Keras layer configuration map (not modified)
     * @return {@code builder}
     * @throws NotImplementedException if the bias regularizer has a non-zero L1 or L2 coefficient
     */
    public static Layer.Builder finishLayerConfig(Layer.Builder builder, Map<String, Object> kerasConfig) throws NotImplementedException {
        if (kerasConfig.containsKey(KERAS_LAYER_PROPERTY_DROPOUT)) {
            // NOTE(review): the Keras value is passed through unchanged; confirm that DL4J's dropOut
            // semantics (drop vs. retain probability) match Keras for the DL4J version in use.
            builder.dropOut(toDouble(kerasConfig.get(KERAS_LAYER_PROPERTY_DROPOUT)));
        }
        if (kerasConfig.containsKey(KERAS_LAYER_PROPERTY_ACTIVATION)) {
            builder.activation(mapActivation((String) kerasConfig.get(KERAS_LAYER_PROPERTY_ACTIVATION)));
        }
        builder.name((String) kerasConfig.get(KERAS_LAYER_PROPERTY_NAME));
        if (kerasConfig.containsKey(KERAS_LAYER_PROPERTY_INIT)) {
            WeightInit init = mapWeightInitialization((String) kerasConfig.get(KERAS_LAYER_PROPERTY_INIT));
            builder.weightInit(init);
            if (init == WeightInit.ZERO) {
                builder.biasInit(0.0);
            }
        }
        if (kerasConfig.containsKey(KERAS_LAYER_PROPERTY_W_REGULARIZER)) {
            Map<String, Object> weightRegularizer = getNestedConfig(kerasConfig, KERAS_LAYER_PROPERTY_W_REGULARIZER);
            double l1 = getL1Regularization(weightRegularizer);
            if (l1 > 0.0) {
                builder.l1(l1);
            }
            double l2 = getL2Regularization(weightRegularizer);
            if (l2 > 0.0) {
                builder.l2(l2);
            }
            checkForUnknownRegularizer(weightRegularizer);
        }
        if (kerasConfig.containsKey(KERAS_LAYER_PROPERTY_B_REGULARIZER)) {
            Map<String, Object> biasRegularizer = getNestedConfig(kerasConfig, KERAS_LAYER_PROPERTY_B_REGULARIZER);
            double l1 = getL1Regularization(biasRegularizer);
            double l2 = getL2Regularization(biasRegularizer);
            if (l1 > 0.0 || l2 > 0.0) {
                throw new NotImplementedException("Bias regularization not implemented");
            }
            checkForUnknownRegularizer(biasRegularizer);
        }
        return builder;
    }

    /**
     * Builds a DL4J {@link DenseLayer} from a Keras {@code Dense}/{@code TimeDistributedDense} config.
     *
     * @param kerasConfig Keras layer config; {@code output_dim} is required
     * @return the configured dense layer
     * @throws NotImplementedException if bias regularization is requested
     */
    public static DenseLayer buildDenseLayer(Map<String, Object> kerasConfig) throws NotImplementedException {
        DenseLayer.Builder builder = (DenseLayer.Builder) new DenseLayer.Builder()
                .nOut(toInt(kerasConfig.get(KERAS_LAYER_PROPERTY_OUTPUT_DIM)));
        finishLayerConfig((Layer.Builder) builder, kerasConfig);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link ConvolutionLayer} from a Keras {@code Convolution2D} config.
     *
     * @param kerasConfig Keras layer config. Requires {@code subsample} (a list of at least two
     *                    numbers), plus {@code nb_row}, {@code nb_col} and {@code nb_filter}.
     * @return the configured convolution layer
     * @throws NotImplementedException if bias regularization is requested
     */
    public static ConvolutionLayer buildConvolutionLayer(Map<String, Object> kerasConfig) throws NotImplementedException {
        int[] stride = toIntPair((List<?>) kerasConfig.get(KERAS_LAYER_PROPERTY_SUBSAMPLE));
        int[] kernelSize = new int[]{
                toInt(kerasConfig.get(KERAS_LAYER_PROPERTY_NB_ROW)),
                toInt(kerasConfig.get(KERAS_LAYER_PROPERTY_NB_COL))};
        ConvolutionLayer.Builder builder = (ConvolutionLayer.Builder) new ConvolutionLayer.Builder()
                .stride(stride)
                .kernelSize(kernelSize)
                .nOut(toInt(kerasConfig.get(KERAS_LAYER_PROPERTY_NB_FILTER)));
        finishLayerConfig((Layer.Builder) builder, kerasConfig);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link SubsamplingLayer} from a Keras pooling config.
     *
     * @param kerasConfig Keras layer config. Requires {@code strides} and {@code pool_size} (each
     *                    a list of at least two numbers), plus {@code keras_class}.
     * @return the configured max-pooling layer
     * @throws NotImplementedException if {@code keras_class} is missing or not {@code "MaxPooling2D"},
     *                                 or bias regularization is requested
     */
    public static SubsamplingLayer buildSubsamplingLayer(Map<String, Object> kerasConfig) throws NotImplementedException {
        int[] stride = toIntPair((List<?>) kerasConfig.get(KERAS_LAYER_PROPERTY_STRIDES));
        int[] poolSize = toIntPair((List<?>) kerasConfig.get(KERAS_LAYER_PROPERTY_POOL_SIZE));
        SubsamplingLayer.Builder builder = new SubsamplingLayer.Builder().stride(stride).kernelSize(poolSize);
        // Null-safe comparison: a missing keras_class previously caused a NullPointerException in switch.
        String poolingClass = (String) kerasConfig.get(KERAS_MODEL_PROPERTY_CLASS);
        if (KERAS_LAYER_MAX_POOLING_2D.equals(poolingClass)) {
            builder.poolingType(SubsamplingLayer.PoolingType.MAX);
        } else {
            throw new NotImplementedException("Other pooling types and shapes not supported.");
        }
        finishLayerConfig((Layer.Builder) builder, kerasConfig);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link GravesLSTM} layer from a Keras {@code LSTM} config.
     *
     * <p>Keras {@code dropout_W} is applied as the layer's dropout. A non-zero {@code dropout_U}
     * is rejected. The input map is not modified.
     *
     * @param kerasConfig Keras layer config; {@code output_dim} is required
     * @return the configured LSTM layer
     * @throws IncompatibleKerasConfigurationException if {@code activation} and
     *         {@code inner_activation} differ, or {@code dropout_U > 0}
     * @throws NotImplementedException if bias regularization is requested
     */
    public static GravesLSTM buildGravesLstmLayer(Map<String, Object> kerasConfig) throws IncompatibleKerasConfigurationException, NotImplementedException {
        if (!Objects.equals(kerasConfig.get(KERAS_LAYER_PROPERTY_ACTIVATION), kerasConfig.get(KERAS_LAYER_PROPERTY_INNER_ACTIVATION))) {
            throw new IncompatibleKerasConfigurationException("Specifying different activation for LSTM inner cells not supported.");
        }
        if (!Objects.equals(kerasConfig.get(KERAS_LAYER_PROPERTY_INIT), kerasConfig.get(KERAS_LAYER_PROPERTY_INNER_INIT))) {
            log.warn("Specifying different initialization for inner cells not supported.");
        }
        Object dropoutU = kerasConfig.get(KERAS_LAYER_PROPERTY_DROPOUT_U);
        if (dropoutU != null && toDouble(dropoutU) > 0.0) {
            throw new IncompatibleKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
        }
        GravesLSTM.Builder builder = new GravesLSTM.Builder();
        builder.nOut(toInt(kerasConfig.get(KERAS_LAYER_PROPERTY_OUTPUT_DIM)));
        String forgetBiasInit = (String) kerasConfig.get(KERAS_LAYER_PROPERTY_FORGET_BIAS_INIT);
        if (KERAS_FORGET_BIAS_ZERO.equals(forgetBiasInit)) {
            builder.forgetGateBiasInit(0.0);
        } else if (KERAS_FORGET_BIAS_ONE.equals(forgetBiasInit)) {
            builder.forgetGateBiasInit(1.0);
        } else {
            log.warn("Unsupported bias initialization: " + forgetBiasInit + ".");
        }
        // Work on a copy so mapping dropout_W onto the generic "dropout" key does not modify the caller's map.
        Map<String, Object> layerConfig = new HashMap<>(kerasConfig);
        Object dropoutW = kerasConfig.get(KERAS_LAYER_PROPERTY_DROPOUT_W);
        if (dropoutW != null) {
            layerConfig.put(KERAS_LAYER_PROPERTY_DROPOUT, toDouble(dropoutW));
        }
        finishLayerConfig((Layer.Builder) builder, layerConfig);
        return builder.build();
    }

    /** Returns the nested config map stored under {@code key} (e.g. a regularizer), or null. */
    @SuppressWarnings("unchecked") // Parsed Keras JSON stores nested objects as String-keyed maps.
    private static Map<String, Object> getNestedConfig(Map<String, Object> kerasConfig, String key) {
        return (Map<String, Object>) kerasConfig.get(key);
    }

    /** Converts any JSON-parsed number (Integer, Long, Double, ...) to double; null throws NPE. */
    private static double toDouble(Object value) {
        return ((Number) value).doubleValue();
    }

    /** Converts any JSON-parsed number to int; null throws NPE. */
    private static int toInt(Object value) {
        return ((Number) value).intValue();
    }

    /** Converts the first two elements of a JSON-parsed numeric list to an {@code int[2]}. */
    private static int[] toIntPair(List<?> values) {
        return new int[]{toInt(values.get(0)), toInt(values.get(1))};
    }
}

