/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Keras {@code Sequential} model imported from a JSON or YAML model configuration (optionally with
 * a training configuration and weights) that can be converted to a DL4J
 * {@link MultiLayerConfiguration} or {@link MultiLayerNetwork}.
 */
public class KerasSequentialModel extends KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasSequentialModel.class);

    /** Value of the Keras {@code class_name} field that identifies a sequential model. */
    private static final String SEQUENTIAL_CLASS_NAME = "Sequential";

    /** Name given to the input layer synthesized when the configuration has no explicit one. */
    private static final String DEFAULT_INPUT_LAYER_NAME = "input1";

    /**
     * Creates a sequential model from the settings collected by a {@link KerasModel.ModelBuilder}.
     *
     * @param modelBuilder builder supplying the model JSON/YAML, weights archive and root,
     *                     training configuration and the enforce-training-config flag
     * @throws UnsupportedKerasConfigurationException propagated from the import helpers
     * @throws IOException propagated from configuration parsing or the import helpers
     * @throws InvalidKerasConfigurationException if the configuration is missing or malformed
     */
    public KerasSequentialModel(KerasModel.ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.weightsArchive, modelBuilder.weightsRoot,
                modelBuilder.trainingJson, modelBuilder.trainingArchive, modelBuilder.enforceTrainingConfig);
    }

    /**
     * Creates a sequential model from its Keras configuration.
     *
     * <p>If the first configured layer is not a {@link KerasInput}, an input layer named
     * {@code "input1"} is synthesized from that layer's input shape and dimension ordering and
     * prepended to the layer list. Every layer after the first is then given the preceding layer
     * as its only inbound layer.
     *
     * @param modelJson             model configuration as JSON; takes precedence over {@code modelYaml}
     * @param modelYaml             model configuration as YAML; used only when {@code modelJson} is null
     * @param weightsArchive        archive to import weights from, or null to skip weight import
     * @param weightsRoot           root passed to the weight import together with {@code weightsArchive}
     * @param trainingJson          training configuration as JSON, or null to skip it
     * @param trainingArchive       NOTE(review): not used by this constructor
     * @param enforceTrainingConfig stored in the {@code enforceTrainingConfig} field before layers are prepared
     * @throws IOException                            propagated from configuration parsing or the import helpers
     * @throws InvalidKerasConfigurationException     if no configuration string is given, the model class is not
     *                                                {@code Sequential}, or the layer configuration is missing,
     *                                                not a list, or yields no layers
     * @throws UnsupportedKerasConfigurationException propagated from the import helpers
     */
    public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig;
        if (modelJson != null) {
            modelConfig = KerasSequentialModel.parseJsonString(modelJson);
        } else if (modelYaml != null) {
            modelConfig = KerasSequentialModel.parseYamlString(modelYaml);
        } else {
            throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
        }
        this.enforceTrainingConfig = enforceTrainingConfig;

        if (!modelConfig.containsKey("class_name")) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no class_name field found)");
        }
        this.className = (String) modelConfig.get("class_name");
        // Constant-first comparison: an explicit null class_name is reported instead of raising an NPE.
        if (!SEQUENTIAL_CLASS_NAME.equals(this.className)) {
            throw new InvalidKerasConfigurationException("Model class name must be Sequential (found " + this.className + ")");
        }

        if (!modelConfig.containsKey("config")) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no config field found)");
        }
        Object layerConfigs = modelConfig.get("config");
        // Report a non-list "config" value as a configuration error rather than a ClassCastException.
        if (!(layerConfigs instanceof List)) {
            throw new InvalidKerasConfigurationException("Layer configurations (config field) must be a list (found "
                    + (layerConfigs == null ? "null" : layerConfigs.getClass().getSimpleName()) + ")");
        }
        this.helperPrepareLayers((List) layerConfigs);
        // Guard the get(0) below: an empty layer list would otherwise fail with IndexOutOfBoundsException.
        if (this.layersOrdered.isEmpty()) {
            throw new InvalidKerasConfigurationException("Sequential model must contain at least one layer");
        }

        KerasLayer firstLayer = this.layersOrdered.get(0);
        KerasLayer inputLayer;
        if (firstLayer instanceof KerasInput) {
            inputLayer = firstLayer;
        } else {
            // No explicit input layer: synthesize one from the first layer's input shape so the
            // model has a single named input.
            inputLayer = new KerasInput(DEFAULT_INPUT_LAYER_NAME, firstLayer.getInputShape());
            inputLayer.setDimOrder(firstLayer.getDimOrder());
            this.layers.put(inputLayer.getLayerName(), inputLayer);
            this.layersOrdered.add(0, inputLayer);
        }

        this.inputLayerNames = new ArrayList<String>(Arrays.asList(inputLayer.getLayerName()));
        KerasLayer lastLayer = this.layersOrdered.get(this.layersOrdered.size() - 1);
        this.outputLayerNames = new ArrayList<String>(Arrays.asList(lastLayer.getLayerName()));

        // Chain the layers: each layer's sole inbound layer is the one immediately before it.
        KerasLayer prevLayer = null;
        for (KerasLayer layer : this.layersOrdered) {
            if (prevLayer != null) {
                layer.setInboundLayerNames(Arrays.asList(prevLayer.getLayerName()));
            }
            prevLayer = layer;
        }

        if (trainingJson != null) {
            this.helperImportTrainingConfiguration(trainingJson);
        }
        this.helperInferOutputTypes();
        if (weightsArchive != null) {
            this.helperImportWeights(weightsArchive, weightsRoot);
        }
    }

    /** No-argument constructor for subclasses; performs no configuration parsing or import. */
    protected KerasSequentialModel() {
    }

    /**
     * Builds a DL4J {@link MultiLayerConfiguration} from the imported layers.
     *
     * <p>Each imported layer that wraps a DL4J layer is added in order. An input preprocessor is
     * attached where one is produced, either by a preceding preprocessor-only layer (using the
     * output type of that layer's inbound layer) or by the layer itself (using the previous
     * layer's output type). Truncated BPTT is enabled when {@code useTruncatedBPTT} is set and
     * {@code truncatedBPTT} is positive; otherwise standard backprop is used.
     *
     * @return the multi-layer configuration
     * @throws InvalidKerasConfigurationException     if the model is not {@code Sequential}, does not have exactly
     *                                                one input and one output, has a layer without exactly one
     *                                                inbound layer, or contains a graph vertex
     * @throws UnsupportedKerasConfigurationException propagated from layer conversion
     */
    public MultiLayerConfiguration getMultiLayerConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!SEQUENTIAL_CLASS_NAME.equals(this.className)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
        }
        if (this.inputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
        }
        if (this.outputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
        }

        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list();
        KerasLayer prevLayer = null;
        int layerIndex = 0;
        for (KerasLayer layer : this.layersOrdered) {
            if (layer.usesRegularization()) {
                modelBuilder.setUseRegularization(true);
            }
            if (layer.isLayer()) {
                int nbInbound = layer.getInboundLayerNames().size();
                if (nbInbound != 1) {
                    throw new InvalidKerasConfigurationException("Layers in MultiLayerConfiguration must have exactly one inbound layer (found " + nbInbound + " for layer " + layer.getLayerName() + ")");
                }
                if (prevLayer != null) {
                    InputType[] inputTypes = new InputType[1];
                    InputPreProcessor preprocessor;
                    if (prevLayer.isInputPreProcessor()) {
                        // The previous entry only reshapes its input, so it supplies the preprocessor,
                        // computed from the output type of its own inbound layer.
                        inputTypes[0] = (InputType) this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
                        preprocessor = prevLayer.getInputPreprocessor(inputTypes);
                    } else {
                        inputTypes[0] = (InputType) this.outputTypes.get(prevLayer.getLayerName());
                        preprocessor = layer.getInputPreprocessor(inputTypes);
                    }
                    if (preprocessor != null) {
                        listBuilder.inputPreProcessor(Integer.valueOf(layerIndex), preprocessor);
                    }
                }
                listBuilder.layer(layerIndex++, layer.getLayer());
                if (this.outputLayerNames.contains(layer.getLayerName()) && !(layer.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer {} is not an IOutputLayer (no loss function specified)", layer.getLayerName());
                }
            } else if (layer.getVertex() != null) {
                throw new InvalidKerasConfigurationException("Cannot add vertex to MultiLayerConfiguration (class name " + layer.getClassName() + ", layer name " + layer.getLayerName() + ")");
            }
            prevLayer = layer;
        }

        InputType inputType = this.layersOrdered.get(0).getOutputType(new InputType[0]);
        if (inputType != null) {
            listBuilder.setInputType(inputType);
        }
        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            listBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            listBuilder.backpropType(BackpropType.Standard);
        }
        return listBuilder.build();
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork} with the imported weights copied in.
     *
     * @return the initialized network
     * @throws InvalidKerasConfigurationException     propagated from {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException propagated from {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getMultiLayerNetwork(true);
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork}, optionally copying in imported weights.
     *
     * @param importWeights whether to copy the imported weights into the initialized network
     * @return the initialized network, with weights copied when {@code importWeights} is true
     * @throws InvalidKerasConfigurationException     propagated from configuration building or weight copying
     * @throws UnsupportedKerasConfigurationException propagated from configuration building or weight copying
     */
    public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        MultiLayerNetwork model = new MultiLayerNetwork(this.getMultiLayerConfiguration());
        model.init();
        if (importWeights) {
            model = (MultiLayerNetwork) this.helperCopyWeightsToModel((Model) model);
        }
        return model;
    }
}

