/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.hdf5.Group;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.preprocessors.ReshapePreprocessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers used while importing Keras models: version/backend detection, model
 * configuration parsing, and transfer of HDF5-stored weights onto imported layers.
 */
public class KerasModelUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasModelUtils.class);

    /** HDF5 group name that {@link #importWeights} skips instead of treating as a layer group. */
    private static final String TOP_LEVEL_MODEL_WEIGHTS = "top_level_model_weights";
    /** Detects a TensorFlow-style output index (e.g. {@code kernel:0}) anywhere in a string. */
    private static final Pattern TF_OUTPUT_INDEX = Pattern.compile(":\\d+");
    /** Matches a TensorFlow-style output index at the end of a parameter name. */
    private static final Pattern TF_OUTPUT_INDEX_SUFFIX = Pattern.compile(":\\d+$");
    /** Matches a numeric {@code _N} suffix at the end of a parameter name. */
    private static final Pattern NUMERIC_SUFFIX = Pattern.compile("_\\d+$");
    /** Captures everything after a single leading underscore. */
    private static final Pattern LEADING_UNDERSCORE = Pattern.compile("^_(.+)$");
    /** Target type for JSON/YAML model configuration parsing. */
    private static final TypeReference<HashMap<String, Object>> STRING_OBJECT_MAP =
            new TypeReference<HashMap<String, Object>>() {};

    /** Lazily created JSON mapper; ObjectMapper is thread-safe once configured. */
    private static final class JsonMapperHolder {
        private static final ObjectMapper MAPPER = new ObjectMapper();
    }

    /** Lazily created YAML mapper, so the YAML factory is only loaded when YAML is parsed. */
    private static final class YamlMapperHolder {
        private static final JsonFactory FACTORY = new YAMLFactory();
        private static final ObjectMapper MAPPER = new ObjectMapper(FACTORY);
    }

    /**
     * Copies the data format of {@code currLayer} onto a {@link ReshapePreprocessor}.
     *
     * <p>Nothing happens unless the preprocessor is a {@code ReshapePreprocessor} and the Keras
     * layer wraps a DL4J layer with a non-null dim order. 3D (de)convolutions contribute their 3D
     * data format, other convolution layers their CNN 2D format, recurrent layers their RNN format;
     * any other layer type leaves the preprocessor untouched.
     *
     * @param inputPreProcessor preprocessor feeding {@code currLayer}; may be of any type
     * @param currLayer         Keras layer whose data format should be propagated
     */
    public static void setDataFormatIfNeeded(InputPreProcessor inputPreProcessor, KerasLayer currLayer) {
        if (!(inputPreProcessor instanceof ReshapePreprocessor)) {
            return;
        }
        if (!currLayer.isLayer() || currLayer.getDimOrder() == null) {
            return;
        }
        ReshapePreprocessor reshapePreprocessor = (ReshapePreprocessor) inputPreProcessor;
        Layer layer = currLayer.getLayer();
        // The 3D variants are ConvolutionLayer subclasses, so they must be tested first.
        if (layer instanceof Convolution3D) {
            reshapePreprocessor.setFormat((DataFormat) ((Convolution3D) layer).getDataFormat());
        } else if (layer instanceof Deconvolution3D) {
            reshapePreprocessor.setFormat((DataFormat) ((Deconvolution3D) layer).getDataFormat());
        } else if (layer instanceof ConvolutionLayer) {
            reshapePreprocessor.setFormat((DataFormat) ((ConvolutionLayer) layer).getCnn2dDataFormat());
        } else if (layer instanceof BaseRecurrentLayer) {
            reshapePreprocessor.setFormat((DataFormat) ((BaseRecurrentLayer) layer).getRnnDataFormat());
        }
    }

    /**
     * Copies the weights held by the Keras layers onto the layers of an initialized DL4J model.
     *
     * @param model       a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @param kerasLayers Keras layers keyed by layer name
     * @return the same {@code model} instance, now holding the copied weights
     * @throws InvalidKerasConfigurationException if a model layer has no Keras counterpart, or a
     *     Keras layer with parameters does not appear in the model
     * @throws IllegalArgumentException if {@code model} is null or of an unsupported type
     */
    public static Model copyWeightsToModel(Model model, Map<String, KerasLayer> kerasLayers) throws InvalidKerasConfigurationException {
        org.deeplearning4j.nn.api.Layer[] layersFromModel;
        if (model instanceof MultiLayerNetwork) {
            layersFromModel = ((MultiLayerNetwork) model).getLayers();
        } else if (model instanceof ComputationGraph) {
            layersFromModel = ((ComputationGraph) model).getLayers();
        } else {
            throw new IllegalArgumentException("Unsupported model type: "
                    + (model == null ? "null" : model.getClass().getName()));
        }

        // Keras layers not matched by any model layer; checked for orphaned parameters afterwards.
        HashSet<String> unmatchedLayerNames = new HashSet<>(kerasLayers.keySet());
        for (org.deeplearning4j.nn.api.Layer layer : layersFromModel) {
            String layerName = layer.conf().getLayer().getLayerName();
            if (!kerasLayers.containsKey(layerName)) {
                throw new InvalidKerasConfigurationException("No weights found for layer in model (named " + layerName + ")");
            }
            kerasLayers.get(layerName).copyWeightsToLayer(layer);
            unmatchedLayerNames.remove(layerName);
        }
        for (String layerName : unmatchedLayerNames) {
            if (kerasLayers.get(layerName).getNumParams() > 0) {
                throw new InvalidKerasConfigurationException("Attempting to copy weights for layer not in model (named " + layerName + ")");
            }
        }
        return model;
    }

    /**
     * Determines the major Keras version from the model configuration.
     *
     * <p>The version is the run of leading digits of the configured version string (e.g.
     * {@code "2.2.4"} yields 2). A missing version field is treated as Keras 1.
     *
     * @param modelConfig parsed model configuration
     * @param config      field-name mapping for the Keras configuration
     * @return the major Keras version
     * @throws InvalidKerasConfigurationException if the version value does not start with a digit
     */
    public static int determineKerasMajorVersion(Map<String, Object> modelConfig, KerasModelConfiguration config) throws InvalidKerasConfigurationException {
        String versionField = config.getFieldKerasVersion();
        if (!modelConfig.containsKey(versionField)) {
            log.warn("Could not read keras version used (no {} field found) \nassuming keras version is 1.0.7 or earlier.", versionField);
            return 1;
        }
        Object rawVersion = modelConfig.get(versionField);
        String kerasVersionString = rawVersion == null ? "" : rawVersion.toString().trim();
        int digitsEnd = 0;
        while (digitsEnd < kerasVersionString.length() && Character.isDigit(kerasVersionString.charAt(digitsEnd))) {
            digitsEnd++;
        }
        if (digitsEnd == 0) {
            throw new InvalidKerasConfigurationException("Keras version was not readable (" + rawVersion + " provided)");
        }
        try {
            return Integer.parseInt(kerasVersionString.substring(0, digitsEnd));
        } catch (NumberFormatException e) {
            // Only reachable for absurdly long digit runs that overflow an int.
            throw new InvalidKerasConfigurationException("Keras version was not readable (" + rawVersion + " provided)");
        }
    }

    /**
     * Reads the Keras backend name from the model configuration.
     *
     * @param modelConfig parsed model configuration
     * @param config      field-name mapping for the Keras configuration
     * @return the backend name, or {@code null} when the configuration has no backend field
     */
    public static String determineKerasBackend(Map<String, Object> modelConfig, KerasModelConfiguration config) {
        String backendField = config.getFieldBackend();
        if (!modelConfig.containsKey(backendField)) {
            log.warn("Could not read keras backend used (no {} field found) \n", backendField);
            return null;
        }
        return (String) modelConfig.get(backendField);
    }

    /**
     * Derives a bare parameter name from an HDF5 dataset name.
     *
     * <p>Removes the first literal occurrence of the layer name (last element of
     * {@code fragmentList}), then a leading underscore, then a trailing {@code :N} and finally a
     * trailing {@code _N} suffix.
     */
    private static String findParameterName(String parameter, String[] fragmentList) {
        String layerName = fragmentList[fragmentList.length - 1];
        // Quote the layer name: it is data, not a regex (names may contain '.', '+', '(' ...).
        String parameterName = parameter.replaceFirst(Pattern.quote(layerName), "");
        Matcher underscoreMatcher = LEADING_UNDERSCORE.matcher(parameterName);
        if (underscoreMatcher.find()) {
            parameterName = underscoreMatcher.group(1);
        }
        parameterName = TF_OUTPUT_INDEX_SUFFIX.matcher(parameterName).replaceFirst("");
        parameterName = NUMERIC_SUFFIX.matcher(parameterName).replaceFirst("");
        return parameterName;
    }

    /**
     * Joins the leading path components of a TF-style {@code weight_names} attribute, stopping at
     * the first empty component or the first component carrying a {@code :N} output index.
     */
    private static String tfWeightGroupPath(String attributeStr) {
        List<String> groupParts = new ArrayList<>();
        for (String rawPart : attributeStr.split("/")) {
            String part = rawPart.trim();
            if (part.isEmpty() || TF_OUTPUT_INDEX.matcher(part).find()) {
                break;
            }
            groupParts.add(part);
        }
        return StringUtils.join(groupParts, "/");
    }

    /**
     * Reads layer weights from an HDF5 archive and assigns them to the matching Keras layers.
     *
     * <p>If any layer name contains {@code '/'}, the keys of {@code layers} are used as the HDF5
     * group names. Otherwise the groups are listed under {@code weightsRoot}, or under the archive
     * root when {@code weightsRoot} is {@code null}. The {@code weight_names} attribute of each
     * group determines where its datasets are located. The whole import is serialized on
     * {@code KerasModelUtils.class}.
     *
     * @param weightsArchive archive holding the weights
     * @param weightsRoot    group containing the per-layer weight groups, or {@code null} for root
     * @param layers         Keras layers keyed by name; each receives its weights via {@code setWeights}
     * @param kerasVersion   major Keras version the weights were saved with
     * @param backend        Keras backend name; may be {@code null} when it is unknown
     * @throws InvalidKerasConfigurationException if weights exist for a layer not in {@code layers},
     *     a Keras 2 bidirectional layer has an unexpected number of weights, or a layer with
     *     parameters has no weight group in the archive
     * @throws UnsupportedKerasConfigurationException propagated from the archive reads
     */
    public static void importWeights(Hdf5Archive weightsArchive, String weightsRoot, Map<String, KerasLayer> layers, int kerasVersion, String backend) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        boolean includesSlash = false;
        for (String layerName : layers.keySet()) {
            if (layerName.contains("/")) {
                includesSlash = true;
                break;
            }
        }
        String rootPrefix = weightsRoot != null ? weightsRoot + "/" : "";

        synchronized (KerasModelUtils.class) {
            List<String> layerGroups;
            if (includesSlash) {
                layerGroups = new ArrayList<>(layers.keySet());
            } else if (weightsRoot != null) {
                layerGroups = weightsArchive.getGroups(weightsRoot);
            } else {
                layerGroups = weightsArchive.getGroups();
            }

            for (String layerName : layerGroups) {
                if (TOP_LEVEL_MODEL_WEIGHTS.equals(layerName)) {
                    synchronized (Hdf5Archive.LOCK_OBJECT) {
                        // The group is opened (so a missing group still fails as before) but is
                        // always closed again; previously non-empty groups were leaked.
                        Group[] rootGroup = weightsArchive.openGroups(rootPrefix + layerName);
                        weightsArchive.closeGroups(rootGroup);
                    }
                    continue;
                }

                String[] layerFragments = layerName.split("/");
                String attributeStr = weightsArchive.readAttributeAsString("weight_names", rootPrefix + layerName);
                boolean foundTfGroups = TF_OUTPUT_INDEX.matcher(attributeStr).find();
                String attributeJoinStr = foundTfGroups ? tfWeightGroupPath(attributeStr) : layerFragments[0];
                String baseAttributes = layerName + "/" + attributeJoinStr;
                KerasLayer kerasLayer = layers.get(layerName);

                List<String> layerParamNames;
                if (layerFragments.length > 1) {
                    try {
                        layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
                    } catch (Exception ignored) {
                        // Deliberate fallback: the datasets may sit directly in the layer group.
                        layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
                    }
                } else if (foundTfGroups) {
                    layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
                } else if (kerasVersion == 2) {
                    // Null-safe: the backend is null when the configuration did not record it.
                    if ("theano".equals(backend) && layerName.contains("bidirectional")) {
                        for (String part : attributeStr.split("/")) {
                            if (part.contains("forward")) {
                                baseAttributes = baseAttributes + "/" + part;
                            }
                        }
                    }
                    // A group for an unknown layer is probed like a parameterized layer so that
                    // stray weights are reported below instead of failing with an NPE here.
                    if (kerasLayer == null || kerasLayer.getNumParams() > 0) {
                        try {
                            layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
                        } catch (Exception e) {
                            log.warn("No HDF5 group with weights found for layer with name {}, continuing import.", layerName);
                            layerParamNames = Collections.emptyList();
                        }
                    } else {
                        layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
                    }
                } else {
                    layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
                }

                if (layerParamNames.isEmpty()) {
                    continue;
                }
                if (kerasLayer == null) {
                    throw new InvalidKerasConfigurationException("Found weights for layer not in model (named " + layerName + ")");
                }

                boolean bidirectionalV2 = kerasVersion == 2 && kerasLayer instanceof KerasBidirectional;
                int numParams = kerasLayer.getNumParams();
                int numFound = layerParamNames.size();
                // NOTE(review): only Keras 2 bidirectional layers are validated (a bidirectional
                // layer may store half its parameters per direction); count mismatches for other
                // layers are tolerated exactly as before — confirm whether a stricter check is wanted.
                if (bidirectionalV2 && numFound != numParams && 2 * numFound != numParams) {
                    throw new InvalidKerasConfigurationException("Found " + numFound + " weights for layer with " + numParams + " trainable params (named " + layerName + ")");
                }

                Map<String, INDArray> weights = new HashMap<>();
                for (String layerParamName : layerParamNames) {
                    String paramName = findParameterName(layerParamName, layerFragments);
                    if (bidirectionalV2) {
                        String backwardAttributes = baseAttributes.replace("forward", "backward");
                        INDArray forwardParamValue = weightsArchive.readDataSet(layerParamName, rootPrefix + baseAttributes);
                        INDArray backwardParamValue = weightsArchive.readDataSet(layerParamName, rootPrefix + backwardAttributes);
                        weights.put("forward_" + paramName, forwardParamValue);
                        weights.put("backward_" + paramName, backwardParamValue);
                        continue;
                    }
                    INDArray paramValue;
                    if (foundTfGroups || kerasVersion == 2 && layerFragments.length == 1) {
                        paramValue = weightsArchive.readDataSet(layerParamName, rootPrefix + baseAttributes);
                    } else if (layerFragments.length > 1) {
                        // Single path instead of {rootPrefix, layerName}: avoids an empty group
                        // name when there is no weights root.
                        paramValue = weightsArchive.readDataSet(layerFragments[0] + "/" + layerParamName, rootPrefix + layerName);
                    } else {
                        paramValue = weightsArchive.readDataSet(layerParamName, rootPrefix + layerName);
                    }
                    weights.put(paramName, paramValue);
                }
                kerasLayer.setWeights(weights);
            }

            // Layers with parameters whose group never appeared in the archive are an error.
            HashSet<String> layersWithoutGroup = new HashSet<>(layers.keySet());
            layersWithoutGroup.removeAll(new HashSet<>(layerGroups));
            for (String layerName : layersWithoutGroup) {
                if (layers.get(layerName).getNumParams() > 0) {
                    throw new InvalidKerasConfigurationException("Could not find weights required for layer " + layerName);
                }
            }
        }
    }

    /**
     * Parses a Keras model configuration, preferring JSON over YAML when both are given.
     *
     * @param modelJson model configuration as JSON, or {@code null}
     * @param modelYaml model configuration as YAML, or {@code null}
     * @return the parsed configuration
     * @throws IOException if parsing fails
     * @throws InvalidKerasConfigurationException if neither string is provided
     */
    public static Map<String, Object> parseModelConfig(String modelJson, String modelYaml) throws IOException, InvalidKerasConfigurationException {
        if (modelJson != null) {
            return parseJsonString(modelJson);
        }
        if (modelYaml != null) {
            return parseYamlString(modelYaml);
        }
        throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
    }

    /**
     * Parses a JSON object into a map.
     *
     * @param json JSON text whose top level is an object
     * @return the parsed key/value pairs
     * @throws IOException if the text cannot be parsed
     */
    public static Map<String, Object> parseJsonString(String json) throws IOException {
        return JsonMapperHolder.MAPPER.readValue(json, STRING_OBJECT_MAP);
    }

    /**
     * Parses a YAML mapping into a map.
     *
     * @param yaml YAML text whose top level is a mapping
     * @return the parsed key/value pairs
     * @throws IOException if the text cannot be parsed
     */
    public static Map<String, Object> parseYamlString(String yaml) throws IOException {
        return YamlMapperHolder.MAPPER.readValue(yaml, STRING_OBJECT_MAP);
    }
}

