/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLoss;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLstm;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Imports a Keras model configuration (and optionally its training configuration and weights)
 * and translates it into a DL4J {@link ComputationGraph}.
 *
 * <p>Construction parses the model JSON/YAML, creates one {@link KerasLayer} per configured layer,
 * optionally appends {@link KerasLoss} layers taken from the Keras training configuration, infers
 * the output {@link InputType} of every layer, and imports weights when an HDF5 archive is given.
 * {@link #getComputationGraphConfiguration()} and {@link #getComputationGraph()} then build the
 * corresponding DL4J objects.
 */
public class KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasModel.class);

    /** Names of the Keras configuration fields read by this importer; shared by all instances. */
    protected static KerasModelConfiguration config = new KerasModelConfiguration();

    /** Builder created from {@link #config}; returned by {@link #modelBuilder()}. */
    KerasModelBuilder modelBuilder = new KerasModelBuilder(config);

    /** Keras model class name read from the model configuration. */
    protected String className;

    /** Forwarded to every layer factory call; also gates import of the training configuration. */
    protected boolean enforceTrainingConfig;

    /** Imported layers keyed by layer name (including any loss layers appended from training config). */
    protected Map<String, KerasLayer> layers;

    /** Imported layers in configuration order, followed by any appended loss layers. */
    List<KerasLayer> layersOrdered;

    /** Output type inferred for each layer, keyed by layer name. */
    Map<String, InputType> outputTypes;

    /** Names of the model's input layers, in configuration order. */
    ArrayList<String> inputLayerNames;

    /** Names of the model's outputs; outputs with a loss are replaced by their loss-layer names. */
    ArrayList<String> outputLayerNames;

    /** True when at least one LSTM or SimpleRNN layer reports {@code unroll}. */
    boolean useTruncatedBPTT = false;

    /** Truncation length reported by the input layer(s); the last input layer processed wins. */
    int truncatedBPTT = 0;

    /** Keras major version detected from the model configuration. */
    int kerasMajorVersion;

    /** Keras backend detected from the model configuration; may be {@code null}. */
    String kerasBackend;

    /** Creates an uninitialized model; no configuration is parsed. */
    public KerasModel() {
    }

    /**
     * Returns the builder associated with this model.
     *
     * @return the {@link KerasModelBuilder} created from the shared {@link #config}
     */
    public KerasModelBuilder modelBuilder() {
        return this.modelBuilder;
    }

    /**
     * Imports a model from the sources held by {@code modelBuilder}.
     *
     * @param modelBuilder supplies model JSON/YAML, weights archive and root, training JSON and
     *                     archive, and the enforce-training-config flag
     * @throws UnsupportedKerasConfigurationException if a layer or loss cannot be mapped to DL4J
     * @throws IOException                            if parsing or weight import fails
     * @throws InvalidKerasConfigurationException     if required configuration fields are missing
     */
    public KerasModel(KerasModelBuilder modelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
                modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
                modelBuilder.isEnforceTrainingConfig());
    }

    /**
     * Parses a Keras model configuration and imports its layers, training loss and weights.
     *
     * @param modelJson             model configuration as JSON (passed to {@code parseModelConfig})
     * @param modelYaml             model configuration as YAML (passed to {@code parseModelConfig})
     * @param weightsArchive        HDF5 archive holding weights; {@code null} skips weight import
     * @param weightsRoot           root passed to {@code KerasModelUtils.importWeights}
     * @param trainingJson          training configuration JSON; imported only when
     *                              {@code enforceTrainingConfig} is true
     * @param trainingArchive       not used by this constructor
     * @param enforceTrainingConfig forwarded to the layer factory; gates training-config import
     * @throws IOException                            if parsing or weight import fails
     * @throws InvalidKerasConfigurationException     if the class name is not the expected model
     *                                                class or a required field is missing
     * @throws UnsupportedKerasConfigurationException if a layer or loss cannot be mapped to DL4J
     */
    protected KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                         String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
        this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
        this.enforceTrainingConfig = enforceTrainingConfig;

        if (!modelConfig.containsKey(config.getFieldClassName())) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no "
                    + config.getFieldClassName() + " field found)");
        }
        this.className = (String) modelConfig.get(config.getFieldClassName());
        // Compare from the constant side so a null class name is reported rather than throwing an NPE.
        if (!config.getFieldClassNameModel().equals(this.className)) {
            throw new InvalidKerasConfigurationException("Expected model class name "
                    + config.getFieldClassNameModel() + " (found " + this.className + ")");
        }

        if (!modelConfig.containsKey(config.getModelFieldConfig())) {
            throw new InvalidKerasConfigurationException("Could not find model configuration details (no "
                    + config.getModelFieldConfig() + " in model config)");
        }
        @SuppressWarnings("unchecked") // parsed JSON/YAML objects are string-keyed maps
        Map<String, Object> layerLists = (Map<String, Object>) modelConfig.get(config.getModelFieldConfig());

        this.inputLayerNames = parseLayerNameList(layerLists, config.getModelFieldInputLayers(), "input layers");
        this.outputLayerNames = parseLayerNameList(layerLists, config.getModelFieldOutputLayers(), "output layers");

        if (!layerLists.containsKey(config.getModelFieldLayers())) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no "
                    + config.getModelFieldLayers() + " field found)");
        }
        @SuppressWarnings("unchecked")
        List<Object> layerConfigs = (List<Object>) layerLists.get(config.getModelFieldLayers());
        this.prepareLayers(layerConfigs);

        if (trainingJson != null && enforceTrainingConfig) {
            this.importTrainingConfiguration(trainingJson);
        }
        this.inferOutputTypes();
        if (weightsArchive != null) {
            KerasModelUtils.importWeights(weightsArchive, weightsRoot, this.layers, this.kerasMajorVersion,
                    this.kerasBackend);
        }
    }

    /**
     * Reads a list of layer references from the model configuration and returns their names.
     *
     * @param layerLists model "config" section
     * @param field      key holding the list of layer references
     * @param what       human-readable description used in the error message
     * @return layer names in configuration order
     * @throws InvalidKerasConfigurationException if {@code field} is absent
     */
    private static ArrayList<String> parseLayerNameList(Map<String, Object> layerLists, String field, String what)
            throws InvalidKerasConfigurationException {
        if (!layerLists.containsKey(field)) {
            throw new InvalidKerasConfigurationException(
                    "Could not find list of " + what + " (no " + field + " field found)");
        }
        ArrayList<String> names = new ArrayList<>();
        for (Object entry : (List<?>) layerLists.get(field)) {
            // Each entry is a list whose first element is the layer name; other elements are not used here.
            names.add((String) ((List<?>) entry).get(0));
        }
        return names;
    }

    /**
     * Creates a {@link KerasLayer} for each layer configuration, in order.
     *
     * <p>Each configuration map is tagged (in place) with the Keras major version and, for Keras 2,
     * the backend before being handed to the layer factory. Unrolled LSTM/SimpleRNN layers make
     * the graph eligible for truncated BPTT.
     *
     * @param layerConfigs per-layer configuration maps from the model configuration
     */
    void prepareLayers(List<Object> layerConfigs)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.layersOrdered = new ArrayList<>();
        this.layers = new HashMap<>();
        for (Object layerConfig : layerConfigs) {
            @SuppressWarnings("unchecked")
            Map<String, Object> layerConfigMap = (Map<String, Object>) layerConfig;
            layerConfigMap.put(config.getFieldKerasVersion(), this.kerasMajorVersion);
            if (this.kerasMajorVersion == 2 && this.kerasBackend != null) {
                layerConfigMap.put(config.getFieldBackend(), this.kerasBackend);
            }
            KerasLayerConfiguration kerasLayerConf = new KerasLayer(Integer.valueOf(this.kerasMajorVersion)).conf;
            KerasLayer layer = KerasLayerUtils.getKerasLayerFromConfig(layerConfigMap, this.enforceTrainingConfig,
                    kerasLayerConf, KerasLayer.customLayers, this.layers);
            this.layersOrdered.add(layer);
            this.layers.put(layer.getLayerName(), layer);

            // Truncated BPTT is only applied later if an input layer also supplies a positive length.
            if (layer instanceof KerasLstm) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasLstm) layer).getUnroll();
            }
            if (layer instanceof KerasSimpleRnn) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasSimpleRnn) layer).getUnroll();
            }
        }
    }

    /**
     * Appends a {@link KerasLoss} layer for each output that has a loss in the training configuration.
     *
     * <p>The loss may be a single string (applied to every output), a map keyed by output layer
     * name, or a list with one entry per output in output order. Outputs with a loss are replaced
     * in {@link #outputLayerNames} by their loss layer (same position); outputs without a loss
     * entry are kept as-is and a warning is logged.
     *
     * @param trainingConfigJson Keras training configuration as JSON
     * @throws InvalidKerasConfigurationException     if the loss field is missing, names an unknown
     *                                                output, has the wrong length, or is not a string
     * @throws UnsupportedKerasConfigurationException if the loss has an unsupported form
     */
    void importTrainingConfiguration(String trainingConfigJson)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> trainingConfig = KerasModelUtils.parseJsonString(trainingConfigJson);
        if (!trainingConfig.containsKey(config.getTrainingLoss())) {
            throw new InvalidKerasConfigurationException("Could not determine training loss function (no "
                    + config.getTrainingLoss() + " field found in training config)");
        }
        Object kerasLossObj = trainingConfig.get(config.getTrainingLoss());
        int numOutputs = this.outputLayerNames.size();

        // One entry per current output, in output order: the loss name, or null if that output has none.
        List<String> lossPerOutput = new ArrayList<>(numOutputs);
        if (kerasLossObj instanceof String) {
            for (int i = 0; i < numOutputs; i++) {
                lossPerOutput.add((String) kerasLossObj);
            }
        } else if (kerasLossObj instanceof Map) {
            Map<?, ?> kerasLossMap = (Map<?, ?>) kerasLossObj;
            for (Object key : kerasLossMap.keySet()) {
                if (!this.outputLayerNames.contains(key)) {
                    throw new InvalidKerasConfigurationException(
                            "Keras loss specified for unknown output layer " + key);
                }
            }
            // Iterate outputs (not map keys) so the model's output order is preserved.
            for (String outputLayerName : this.outputLayerNames) {
                lossPerOutput.add(kerasLossMap.containsKey(outputLayerName)
                        ? lossName(kerasLossMap.get(outputLayerName)) : null);
            }
        } else if (kerasLossObj instanceof List) {
            List<?> kerasLossList = (List<?>) kerasLossObj;
            if (kerasLossList.size() != numOutputs) {
                throw new InvalidKerasConfigurationException("Keras loss list has " + kerasLossList.size()
                        + " entries but the model has " + numOutputs + " outputs");
            }
            for (Object kerasLoss : kerasLossList) {
                lossPerOutput.add(lossName(kerasLoss));
            }
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported Keras loss configuration: " + kerasLossObj);
        }

        // Build every loss layer before touching model state so a failure leaves no partial update.
        List<KerasLoss> lossLayers = new ArrayList<>();
        ArrayList<String> newOutputNames = new ArrayList<>(numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            String outputLayerName = this.outputLayerNames.get(i);
            String kerasLoss = lossPerOutput.get(i);
            if (kerasLoss == null) {
                log.warn("No Keras loss specified for output {}; keeping it as an output without a loss layer",
                        outputLayerName);
                newOutputNames.add(outputLayerName);
                continue;
            }
            KerasLoss lossLayer = new KerasLoss(outputLayerName + "_loss", outputLayerName, kerasLoss);
            lossLayers.add(lossLayer);
            newOutputNames.add(lossLayer.getLayerName());
        }
        for (KerasLoss lossLayer : lossLayers) {
            this.layersOrdered.add(lossLayer);
            this.layers.put(lossLayer.getLayerName(), lossLayer);
        }
        // Update in place so existing references to the list observe the new outputs.
        this.outputLayerNames.clear();
        this.outputLayerNames.addAll(newOutputNames);
    }

    /**
     * Returns {@code kerasLoss} as a loss name.
     *
     * @throws InvalidKerasConfigurationException if {@code kerasLoss} is not a string (including null)
     */
    private static String lossName(Object kerasLoss) throws InvalidKerasConfigurationException {
        if (kerasLoss instanceof String) {
            return (String) kerasLoss;
        }
        throw new InvalidKerasConfigurationException("Unknown Keras loss " + kerasLoss);
    }

    /**
     * Infers each layer's output type by walking {@link #layersOrdered} in order.
     *
     * <p>Input layers are asked for their type with no inbound types and also supply
     * {@link #truncatedBPTT}; every other layer receives the previously inferred types of its
     * inbound layers (an inbound layer not yet processed contributes {@code null}).
     */
    void inferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.outputTypes = new HashMap<>();
        for (KerasLayer layer : this.layersOrdered) {
            InputType outputType;
            if (layer instanceof KerasInput) {
                outputType = layer.getOutputType(new InputType[0]);
                this.truncatedBPTT = ((KerasInput) layer).getTruncatedBptt();
            } else {
                List<String> inboundLayerNames = layer.getInboundLayerNames();
                InputType[] inputTypes = new InputType[inboundLayerNames.size()];
                int i = 0;
                for (String inboundLayerName : inboundLayerNames) {
                    inputTypes[i++] = this.outputTypes.get(inboundLayerName);
                }
                outputType = layer.getOutputType(inputTypes);
            }
            this.outputTypes.put(layer.getLayerName(), outputType);
        }
    }

    /**
     * Builds a DL4J {@link ComputationGraphConfiguration} from the imported layers.
     *
     * @return the graph configuration, using truncated BPTT when an unrolled recurrent layer is
     *         present and an input layer supplied a positive truncation length
     * @throws InvalidKerasConfigurationException     if the model class is incompatible or an input
     *                                                layer is missing
     * @throws UnsupportedKerasConfigurationException if a layer maps to neither a layer, a vertex,
     *                                                nor a usable preprocessor
     */
    public ComputationGraphConfiguration getComputationGraphConfiguration()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!config.getFieldClassNameModel().equals(this.className)
                && !config.getFieldClassNameSequential().equals(this.className)) {
            throw new InvalidKerasConfigurationException(
                    "Keras model class name " + this.className + " incompatible with ComputationGraph");
        }
        NeuralNetConfiguration.Builder netConfBuilder = new NeuralNetConfiguration.Builder();
        ComputationGraphConfiguration.GraphBuilder graphBuilder = netConfBuilder.graphBuilder();

        graphBuilder.addInputs(this.inputLayerNames.toArray(new String[0]));
        InputType[] inputTypes = new InputType[this.inputLayerNames.size()];
        for (int i = 0; i < inputTypes.length; i++) {
            String inputLayerName = this.inputLayerNames.get(i);
            KerasLayer inputLayer = this.layers.get(inputLayerName);
            if (inputLayer == null) {
                throw new InvalidKerasConfigurationException("Could not find input layer " + inputLayerName);
            }
            inputTypes[i] = inputLayer.getOutputType(new InputType[0]);
        }
        graphBuilder.setInputTypes(inputTypes);
        graphBuilder.setOutputs(this.outputLayerNames.toArray(new String[0]));

        HashMap<String, InputPreProcessor> preprocessors = new HashMap<>();
        for (KerasLayer layer : this.layersOrdered) {
            String layerName = layer.getLayerName();
            List<String> inboundLayerNames = layer.getInboundLayerNames();
            String[] inboundLayerNamesArray = inboundLayerNames.toArray(new String[0]);
            InputType[] inboundTypeArray = new InputType[inboundLayerNames.size()];
            int i = 0;
            for (String inboundLayerName : inboundLayerNames) {
                inboundTypeArray[i++] = this.outputTypes.get(inboundLayerName);
            }
            InputPreProcessor preprocessor = layer.getInputPreprocessor(inboundTypeArray);
            boolean isOutput = this.outputLayerNames.contains(layerName);

            // NOTE(review): IOutputLayer is org.deeplearning4j.nn.api.layers.IOutputLayer, while getLayer()
            // and getVertex() are handed to the GraphBuilder as configuration objects. Confirm these
            // instanceof checks can ever succeed; if not, every layer/vertex output triggers the warning.
            if (layer.isLayer()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addLayer(layerName, layer.getLayer(), inboundLayerNamesArray);
                if (isOutput && !(layer.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer " + layerName
                            + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (layer.isVertex()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addVertex(layerName, layer.getVertex(), inboundLayerNamesArray);
                if (isOutput && !(layer.getVertex() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output vertex " + layerName
                            + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (layer.isInputPreProcessor()) {
                if (preprocessor == null) {
                    throw new UnsupportedKerasConfigurationException("Layer " + layerName
                            + " could not be mapped to Layer, Vertex, or InputPreProcessor");
                }
                graphBuilder.addVertex(layerName, new PreprocessorVertex(preprocessor), inboundLayerNamesArray);
                // A preprocessor vertex carries no loss function, so as an output it cannot be trained.
                if (isOutput) {
                    log.warn("Model cannot be trained: output " + layerName
                            + " is not an IOutputLayer (no loss function specified)");
                }
            }
        }
        graphBuilder.setInputPreProcessors(preprocessors);

        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            graphBuilder.backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTForwardLength(this.truncatedBPTT)
                    .tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            graphBuilder.backpropType(BackpropType.Standard);
        }
        return graphBuilder.build();
    }

    /**
     * Builds and initializes a {@link ComputationGraph} with the imported weights copied in.
     *
     * @return the initialized graph
     */
    public ComputationGraph getComputationGraph()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getComputationGraph(true);
    }

    /**
     * Builds and initializes a {@link ComputationGraph} from {@link #getComputationGraphConfiguration()}.
     *
     * @param importWeights when true, copies weights from the imported layers into the graph
     * @return the initialized graph
     */
    public ComputationGraph getComputationGraph(boolean importWeights)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph model = new ComputationGraph(this.getComputationGraphConfiguration());
        model.init();
        if (importWeights) {
            model = (ComputationGraph) KerasModelUtils.copyWeightsToModel(model, this.layers);
        }
        return model;
    }
}

