/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Builds a DL4J model from a Keras model configuration whose {@code class_name}
 * is {@code "Sequential"}.
 *
 * <p>The constructor parses the configuration (JSON preferred, YAML as fallback),
 * prepares the layers via the inherited helpers, prepends a synthetic input layer
 * named {@code "input1"}, and wires every layer to the one immediately before it.
 * The resulting model can then be turned into a {@link MultiLayerConfiguration}
 * or an initialized {@link MultiLayerNetwork}.
 */
public class KerasSequentialModel extends KerasModel {
    /**
     * Creates a Sequential model from the fields collected by a model builder.
     *
     * @param modelBuilder builder holding the model JSON/YAML, training JSON, weights and train flag
     * @throws UnsupportedKerasConfigurationException propagated from the delegated constructor
     * @throws IOException propagated from the delegated constructor
     * @throws InvalidKerasConfigurationException propagated from the delegated constructor
     */
    public KerasSequentialModel(KerasModel.ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.trainingJson, modelBuilder.weights, modelBuilder.train);
    }

    /**
     * Creates a Sequential model from a raw configuration string.
     *
     * @param modelJson model configuration as JSON; used when non-null
     * @param modelYaml model configuration as YAML; used only when {@code modelJson} is null
     * @param trainingJson training configuration as JSON; imported only when non-null
     * @param weights weight arrays, stored as-is for later use by {@link #getMultiLayerNetwork(boolean)}
     * @param train value stored in the {@code train} field before the layers are prepared
     * @throws IOException declared for the parsing helpers
     * @throws InvalidKerasConfigurationException if neither JSON nor YAML is given, the class name is
     *         not {@code "Sequential"}, or no layers are available after layer preparation
     * @throws UnsupportedKerasConfigurationException declared for the layer/graph preparation helpers
     */
    public KerasSequentialModel(String modelJson, String modelYaml, String trainingJson, Map<String, Map<String, INDArray>> weights, boolean train) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> classNameAndLayers;
        if (modelJson != null) {
            classNameAndLayers = KerasSequentialModel.parseJsonString(modelJson);
        } else if (modelYaml != null) {
            classNameAndLayers = KerasSequentialModel.parseYamlString(modelYaml);
        } else {
            throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
        }

        this.className = (String)KerasSequentialModel.checkAndGetModelField(classNameAndLayers, "class_name");
        // Constant-first comparison: a null class name yields the descriptive exception, not an NPE.
        if (!"Sequential".equals(this.className)) {
            throw new InvalidKerasConfigurationException("Model class name must be Sequential (found " + this.className + ")");
        }

        // The train flag is assigned before layer preparation, matching the original statement order.
        this.train = train;
        // NOTE(review): raw List cast kept because the parameter type of helperPrepareLayers
        // is not visible from this file - tighten once confirmed.
        this.helperPrepareLayers((List)KerasSequentialModel.checkAndGetModelField(classNameAndLayers, "config"));

        // The first layer supplies the input shape, so an empty layer list cannot be imported.
        if (this.layerNamesOrdered == null || this.layerNamesOrdered.isEmpty()) {
            throw new InvalidKerasConfigurationException("Sequential model configuration must contain at least one layer.");
        }

        // Synthesize an input layer whose shape is taken from the first prepared layer.
        int[] inputShape = ((KerasLayer)this.layers.get(this.layerNamesOrdered.get(0))).getInputShape();
        KerasLayer inputLayer = KerasLayer.createInputLayer("input1", inputShape);
        this.layers.put(inputLayer.getName(), inputLayer);
        this.inputLayerNames = new ArrayList<String>(Arrays.asList(inputLayer.getName()));
        // The last layer in order is the single output.
        this.outputLayerNames = new ArrayList<String>(Arrays.asList((String)this.layerNamesOrdered.get(this.layerNamesOrdered.size() - 1)));
        this.layerNamesOrdered.add(0, inputLayer.getName());

        // Chain the layers: each layer's only inbound layer is its predecessor in the ordered list.
        String prevLayerName = null;
        for (String layerName : this.layerNamesOrdered) {
            if (prevLayerName != null) {
                ((KerasLayer)this.layers.get(layerName)).setInboundLayerNames(Arrays.asList(prevLayerName));
            }
            prevLayerName = layerName;
        }

        this.helperPrepareGraph();
        if (trainingJson != null) {
            this.helperImportTrainingConfiguration(trainingJson);
        }
        this.weights = weights;
    }

    /**
     * Creates an empty instance without parsing any configuration; no fields are set here.
     */
    protected KerasSequentialModel() {
    }

    /**
     * Builds a {@link MultiLayerConfiguration} from the prepared layers.
     *
     * <p>Only layers for which {@code isDl4jLayer()} is true are added, indexed consecutively
     * from zero in {@code layerNamesOrdered} order. The input type inferred for the single
     * input layer is applied when it is non-null.
     *
     * @return the built configuration
     * @throws InvalidKerasConfigurationException if the class name is not {@code "Sequential"} or the
     *         model does not have exactly one input and one output
     * @throws UnsupportedKerasConfigurationException if {@code truncatedBPTT} is {@code 0}
     */
    public MultiLayerConfiguration getMultiLayerConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // Null-safe: className is not assigned by the protected no-arg constructor.
        if (!"Sequential".equals(this.className)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
        }
        if (this.inputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
        }
        if (this.outputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
        }

        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list();

        // Layers without a DL4J counterpart (e.g. the synthetic input layer, presumably) are skipped.
        int layerIndex = 0;
        for (String layerName : this.layerNamesOrdered) {
            KerasLayer layer = (KerasLayer)this.layers.get(layerName);
            if (layer.isDl4jLayer()) {
                listBuilder.layer(layerIndex++, layer.getDl4jLayer());
            }
        }

        InputType inputType = this.inferInputType((String)this.inputLayerNames.get(0));
        if (inputType != null) {
            listBuilder.setInputType(inputType);
        }

        // truncatedBPTT == 0 is rejected; a positive value sets both tBPTT lengths;
        // a negative value leaves the builder's settings untouched.
        if (this.truncatedBPTT == 0) {
            throw new UnsupportedKerasConfigurationException("Cannot import recurrent models without fixed length sequence input.");
        }
        if (this.truncatedBPTT > 0) {
            listBuilder.tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        }
        return listBuilder.build();
    }

    /**
     * Builds an initialized {@link MultiLayerNetwork} with the stored weights copied in.
     *
     * @return the initialized network
     * @throws InvalidKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getMultiLayerNetwork(true);
    }

    /**
     * Builds an initialized {@link MultiLayerNetwork}, optionally copying the stored weights into it.
     *
     * @param importWeights when true, the stored weights are copied into the network after {@code init()}
     * @return the initialized network
     * @throws InvalidKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        MultiLayerNetwork model = new MultiLayerNetwork(this.getMultiLayerConfiguration());
        model.init();
        if (importWeights) {
            model = (MultiLayerNetwork)KerasSequentialModel.copyWeightsToModel((Model)model, this.weights, this.layers);
        }
        return model;
    }
}

