/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.modelimport.keras.IncompatibleKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.Model;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Translates Keras model configuration JSON into DL4J network configurations.
 *
 * <p>Only Keras {@code Sequential} models are supported. Keras {@code Dropout} and
 * {@code Activation} pseudo-layers are folded into neighbouring layers before the remaining
 * layers are converted via {@link LayerConfiguration#buildLayer}. Import of Functional API
 * models always throws {@link NotImplementedException}.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public final class ModelConfiguration {
    private static final Logger log = LoggerFactory.getLogger(ModelConfiguration.class);

    /** Shared JSON parser; Jackson's ObjectMapper is thread-safe once configured, so it is reused. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Target type for parsing a JSON object into a string-keyed map. */
    private static final TypeReference<HashMap<String, Object>> JSON_OBJECT_TYPE =
            new TypeReference<HashMap<String, Object>>() {};

    private ModelConfiguration() {
    }

    /**
     * Reads a Keras {@code Sequential} model configuration from a UTF-8 JSON file and converts it.
     *
     * @param configJsonFilename path of the Keras model configuration JSON file
     * @return the equivalent DL4J multi-layer configuration
     * @throws IOException if the file cannot be read or does not contain valid JSON
     */
    public static MultiLayerConfiguration importSequentialModelConfigFromFile(String configJsonFilename) throws IOException {
        return importSequentialModelConfig(readUtf8File(configJsonFilename));
    }

    /**
     * Reads a Keras Functional API model configuration from a UTF-8 JSON file and converts it.
     *
     * @param configJsonFilename path of the Keras model configuration JSON file
     * @return never returns normally; Functional API import is not implemented
     * @throws IOException if the file cannot be read or does not contain valid JSON
     * @throws NotImplementedException always, once the JSON has been parsed
     */
    public static ComputationGraphConfiguration importFunctionalApiConfigFromFile(String configJsonFilename) throws IOException {
        return importFunctionalApiConfig(readUtf8File(configJsonFilename));
    }

    /**
     * Converts a Keras {@code Sequential} model configuration given as a JSON string.
     *
     * @param configJson Keras model configuration JSON
     * @return the equivalent DL4J multi-layer configuration
     * @throws IOException if {@code configJson} is not valid JSON
     * @throws IncompatibleKerasConfigurationException if the model cannot be represented in DL4J
     */
    public static MultiLayerConfiguration importSequentialModelConfig(String configJson) throws IOException {
        return importSequentialModelConfig(parseJsonString(configJson));
    }

    /**
     * Converts a Keras Functional API model configuration given as a JSON string.
     *
     * @param configJson Keras model configuration JSON
     * @return never returns normally; Functional API import is not implemented
     * @throws IOException if {@code configJson} is not valid JSON
     * @throws NotImplementedException always, once the JSON has been parsed
     */
    public static ComputationGraphConfiguration importFunctionalApiConfig(String configJson) throws IOException {
        return importFunctionalApiConfig(parseJsonString(configJson));
    }

    /**
     * Builds a DL4J {@link MultiLayerConfiguration} from a parsed Keras {@code Sequential} config.
     *
     * <p>The first layer must carry {@code batch_input_shape}; it determines the network input
     * type (recurrent, convolutional or feed-forward) depending on which kinds of layers were built.
     */
    private static MultiLayerConfiguration importSequentialModelConfig(Map<String, Object> kerasConfig) throws IOException, IncompatibleKerasConfigurationException {
        String arch = (String) kerasConfig.get("class_name");
        if (!"Sequential".equals(arch)) {
            throw new IncompatibleKerasConfigurationException("Expected \"Sequential\" model config, found " + arch);
        }
        List<Map<String, Object>> layerConfigs = mergePseudoLayers((List<?>) kerasConfig.get("config"));
        if (layerConfigs.isEmpty()) {
            throw new IncompatibleKerasConfigurationException("Sequential model config contains no layers");
        }

        List<?> batchInputShape = null;
        String dimOrdering = null;
        boolean isRecurrent = false;
        boolean isConvolutional = false;
        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list();
        int lastConfigIndex = layerConfigs.size() - 1;
        // configIndex walks the Keras layer configs; layerIndex counts only the DL4J layers actually
        // built. They diverge whenever buildLayer returns null, so positional checks (input layer,
        // last layer) must use configIndex while the DL4J layer position must use layerIndex.
        int layerIndex = 0;
        for (int configIndex = 0; configIndex <= lastConfigIndex; configIndex++) {
            Map<String, Object> layerConfig = layerConfigs.get(configIndex);
            boolean hasInputShape = layerConfig.containsKey("batch_input_shape");
            if (configIndex == 0 && !hasInputShape) {
                throw new IncompatibleKerasConfigurationException("Input layer must specify \"batch_input_shape\" field");
            }
            if (configIndex > 0 && hasInputShape) {
                throw new IncompatibleKerasConfigurationException("Non-input layer should not specify \"batch_input_shape\" field");
            }
            if (hasInputShape) {
                batchInputShape = (List<?>) layerConfig.get("batch_input_shape");
            }
            dimOrdering = resolveDimOrdering(layerConfig, dimOrdering);

            String kerasLayerName = (String) layerConfig.get("keras_class");
            Layer layer = LayerConfiguration.buildLayer(kerasLayerName, layerConfig, configIndex == lastConfigIndex);
            if (layer == null) {
                // buildLayer produced no DL4J layer for this config; nothing to add.
                continue;
            }
            if (layer instanceof BaseRecurrentLayer) {
                isRecurrent = true;
            } else if (layer instanceof ConvolutionLayer) {
                isConvolutional = true;
            }
            if (layer.getL1() > 0.0 || layer.getL2() > 0.0) {
                modelBuilder.regularization(true);
            }
            listBuilder.layer(layerIndex, layer);
            layerIndex++;
        }

        if (isRecurrent && isConvolutional) {
            throw new IncompatibleKerasConfigurationException("Recurrent convolutional architecture not supported.");
        }
        if (isRecurrent) {
            listBuilder.setInputType(InputType.recurrent(toInt(batchInputShape.get(2))));
            Object sequenceLength = batchInputShape.get(1);
            if (sequenceLength == null) {
                log.warn("Input sequence length must be specified manually for truncated BPTT!");
            } else {
                int tbpttLength = toInt(sequenceLength);
                listBuilder.tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength);
            }
        } else if (isConvolutional) {
            listBuilder.setInputType(convolutionalInputType(batchInputShape, dimOrdering));
        } else {
            listBuilder.setInputType(InputType.feedForward(toInt(batchInputShape.get(1))));
        }
        return listBuilder.build();
    }

    /**
     * Folds Keras pseudo-layers into their neighbours and returns the remaining layer configs.
     *
     * <p>A {@code Dropout} layer's {@code p} is merged into the {@code "dropout"} entry of the next
     * non-pseudo layer (consecutive Dropouts compose). An {@code Activation} layer overwrites the
     * {@code "activation"} entry of the previous layer. Every returned config is tagged with its
     * Keras class name under {@code "keras_class"}. The input maps are mutated in place.
     *
     * @param kerasLayers the {@code "config"} list of a Sequential model; each element is a map
     *     with {@code "class_name"} and {@code "config"} entries
     * @throws IncompatibleKerasConfigurationException if a layer has no class name, or an
     *     Activation layer appears before any other layer
     */
    private static List<Map<String, Object>> mergePseudoLayers(List<?> kerasLayers) throws IOException, IncompatibleKerasConfigurationException {
        List<Map<String, Object>> layerConfigs = new ArrayList<Map<String, Object>>();
        double pendingDropout = 0.0;
        for (Object entry : kerasLayers) {
            Map<String, Object> kerasLayer = asMap(entry);
            String kerasLayerName = (String) kerasLayer.get("class_name");
            Map<String, Object> layerConfig = asMap(kerasLayer.get("config"));
            if (kerasLayerName == null) {
                throw new IncompatibleKerasConfigurationException("Keras layer is missing \"class_name\" field");
            }
            if (kerasLayerName.equals("Dropout")) {
                // Consecutive dropouts compose: the retain probabilities (1 - p) multiply.
                double p = toDouble(layerConfig.get("p"));
                pendingDropout = 1.0 - (1.0 - pendingDropout) * (1.0 - p);
                continue;
            }
            if (kerasLayerName.equals("Activation")) {
                if (layerConfigs.isEmpty()) {
                    throw new IncompatibleKerasConfigurationException("Plain activation layer applied to input not supported.");
                }
                String activation = LayerConfiguration.mapActivation((String) layerConfig.get("activation"));
                layerConfigs.get(layerConfigs.size() - 1).put("activation", activation);
                continue;
            }

            layerConfig.put("keras_class", kerasLayerName);
            if (pendingDropout > 0.0) {
                double oldDropout = layerConfig.containsKey("dropout") ? toDouble(layerConfig.get("dropout")) : 0.0;
                double newDropout = 1.0 - (1.0 - pendingDropout) * (1.0 - oldDropout);
                layerConfig.put("dropout", newDropout);
                // Only warn when the layer really had its own dropout that is now being changed.
                if (oldDropout > 0.0 && oldDropout != newDropout) {
                    log.warn("Changed layer-defined dropout {} to {} because of previous Dropout={} layer",
                            oldDropout, newDropout, pendingDropout);
                }
                pendingDropout = 0.0;
            }
            layerConfigs.add(layerConfig);
        }
        if (pendingDropout > 0.0) {
            log.warn("Trailing Dropout={} layer is not followed by any layer and was ignored", pendingDropout);
        }
        return layerConfigs;
    }

    /**
     * Validates a layer's {@code dim_ordering} (if present) against the ordering seen so far.
     *
     * @param layerConfig the layer config to inspect
     * @param dimOrdering the ordering established by earlier layers, or {@code null} if none yet
     * @return the ordering to use from now on
     * @throws IncompatibleKerasConfigurationException if the ordering is neither {@code "th"} nor
     *     {@code "tf"}, or conflicts with an earlier layer
     */
    private static String resolveDimOrdering(Map<String, Object> layerConfig, String dimOrdering) throws IncompatibleKerasConfigurationException {
        if (!layerConfig.containsKey("dim_ordering")) {
            return dimOrdering;
        }
        String layerDimOrdering = (String) layerConfig.get("dim_ordering");
        if (!"th".equals(layerDimOrdering) && !"tf".equals(layerDimOrdering)) {
            throw new IncompatibleKerasConfigurationException("Unknown Keras backend: " + layerDimOrdering);
        }
        if (dimOrdering != null && !layerDimOrdering.equals(dimOrdering)) {
            throw new IncompatibleKerasConfigurationException("Found layers with conflicting Keras backends.");
        }
        return layerDimOrdering;
    }

    /**
     * Derives the convolutional input type from a Keras {@code batch_input_shape}.
     *
     * <p>For {@code "tf"} the spatial dims are at indices 1-2 and channels at 3; for {@code "th"}
     * channels are at index 1 and the spatial dims at 2-3. Index 0 (batch size) is ignored.
     *
     * @throws IncompatibleKerasConfigurationException if {@code dimOrdering} is neither
     *     {@code "tf"} nor {@code "th"} (including {@code null}, i.e. never specified)
     */
    private static InputType convolutionalInputType(List<?> batchInputShape, String dimOrdering) throws IncompatibleKerasConfigurationException {
        if ("tf".equals(dimOrdering)) {
            return InputType.convolutional(toInt(batchInputShape.get(1)), toInt(batchInputShape.get(2)), toInt(batchInputShape.get(3)));
        }
        if ("th".equals(dimOrdering)) {
            return InputType.convolutional(toInt(batchInputShape.get(2)), toInt(batchInputShape.get(3)), toInt(batchInputShape.get(1)));
        }
        throw new IncompatibleKerasConfigurationException("Unknown keras backend " + dimOrdering);
    }

    /**
     * Placeholder for Keras Functional API import.
     *
     * @throws NotImplementedException always; Functional API configs are not supported yet
     */
    private static ComputationGraphConfiguration importFunctionalApiConfig(Map<String, Object> kerasConfig) throws IOException, NotImplementedException, IncompatibleKerasConfigurationException {
        throw new NotImplementedException("Import of Keras Functional API model configs not supported.");
    }

    /**
     * Extracts metadata needed to import weights from a Keras Sequential model configuration.
     *
     * <p>Currently the only entry is {@code "keras_backend"}: the {@code dim_ordering} of the first
     * layer that declares one. The map is empty if no layer declares it.
     *
     * @param configJson Keras model configuration JSON whose {@code "config"} is a list of layers
     * @return a mutable map of weight metadata
     * @throws IOException if {@code configJson} is not valid JSON
     */
    public static Map<String, Object> extractWeightsMetadataFromConfig(String configJson) throws IOException {
        Map<String, Object> weightsMetadata = new HashMap<String, Object>();
        Map<String, Object> kerasConfig = parseJsonString(configJson);
        for (Object layer : (List<?>) kerasConfig.get("config")) {
            Map<String, Object> layerConfig = asMap(asMap(layer).get("config"));
            if (layerConfig != null && layerConfig.containsKey("dim_ordering")) {
                weightsMetadata.put("keras_backend", layerConfig.get("dim_ordering"));
                break;
            }
        }
        return weightsMetadata;
    }

    /** Parses a JSON object into a string-keyed map of Jackson's default Java representations. */
    private static Map<String, Object> parseJsonString(String json) throws IOException {
        return MAPPER.readValue(json, JSON_OBJECT_TYPE);
    }

    /** Reads a whole file as UTF-8 text (the JSON encoding), independent of the platform charset. */
    private static String readUtf8File(String filename) throws IOException {
        return new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
    }

    /** Unchecked view of a parsed JSON object; Jackson yields string-keyed maps for JSON objects. */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> asMap(Object value) {
        return (Map<String, Object>) value;
    }

    /** Converts a parsed JSON number to double; Jackson may yield Integer for values like {@code 0}. */
    private static double toDouble(Object value) {
        return ((Number) value).doubleValue();
    }

    /** Converts a parsed JSON number to int, whatever boxed numeric type Jackson chose. */
    private static int toInt(Object value) {
        return ((Number) value).intValue();
    }
}

