/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a DL4J {@link MultiLayerNetwork} from a Keras Sequential model configuration
 * (and, optionally, its weights).
 *
 * <p>The model configuration is parsed in the constructor; the resulting layers are chained
 * linearly (each layer's only inbound layer is its predecessor) and an input layer is
 * synthesized when the configuration does not start with one.
 */
public class KerasSequentialModel extends KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasSequentialModel.class);

    /** Key under which a map-shaped model config stores its list of layer configurations. */
    private static final String LAYERS_KEY = "layers";

    /** Name given to the input layer that is synthesized when the model does not start with one. */
    private static final String DEFAULT_INPUT_LAYER_NAME = "input1";

    /**
     * Creates a Sequential model from the settings collected in a {@link KerasModelBuilder}.
     *
     * @param modelBuilder builder supplying model/training configuration, weight archives and input shape
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     * @throws IOException                            if reading the configuration or weights fails
     * @throws InvalidKerasConfigurationException     if the configuration is malformed
     */
    public KerasSequentialModel(KerasModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(), modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(), modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
    }

    /**
     * Creates a Sequential model from its serialized configuration.
     *
     * @param modelJson             model configuration as JSON (passed to {@code KerasModelUtils.parseModelConfig})
     * @param modelYaml             model configuration as YAML (passed to {@code KerasModelUtils.parseModelConfig})
     * @param weightsArchive        archive holding the weights; weights are imported only when non-null
     * @param weightsRoot           root path of the weights inside {@code weightsArchive}
     * @param trainingJson          training configuration as JSON; only read when {@code enforceTrainingConfig} is true
     * @param trainingArchive       not used by this constructor
     * @param enforceTrainingConfig whether the training configuration should be imported
     * @param inputShape            input shape handed to output type inference
     * @throws IOException                            if reading the configuration or weights fails
     * @throws InvalidKerasConfigurationException     if the class name or layer configurations are missing,
     *                                                malformed, not Sequential, or contain no layers
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig, int[] inputShape) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
        this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
        this.enforceTrainingConfig = enforceTrainingConfig;

        // Validate that this really is a Sequential model before touching its layers.
        if (!modelConfig.containsKey(config.getFieldClassName())) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
        }
        this.className = (String) modelConfig.get(config.getFieldClassName());
        if (!this.className.equals(config.getFieldClassNameSequential())) {
            throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential() + " (found " + this.className + ")");
        }
        if (!modelConfig.containsKey(config.getModelFieldConfig())) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
        }

        // Raw List on purpose: the element type expected by prepareLayers is declared in the
        // superclass and is not visible here, so the argument is passed through unparameterized.
        List layerList = extractLayerConfigs(modelConfig.get(config.getModelFieldConfig()), config.getModelFieldConfig());
        Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = this.prepareLayers(layerList);
        this.layers = layerPair.getFirst();
        this.layersOrdered = layerPair.getSecond();
        if (this.layersOrdered.isEmpty()) {
            // Fail with a configuration error instead of an IndexOutOfBoundsException below.
            throw new InvalidKerasConfigurationException("Sequential model configuration does not contain any layers");
        }

        // Make sure the model starts with an input layer, synthesizing one from the first
        // layer's input shape if necessary.
        KerasLayer firstLayer = this.layersOrdered.get(0);
        KerasLayer inputLayer;
        if (firstLayer instanceof KerasInput) {
            inputLayer = firstLayer;
        } else {
            int[] firstLayerInputShape = firstLayer.getInputShape();
            Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
            inputLayer = new KerasInput(DEFAULT_INPUT_LAYER_NAME, firstLayerInputShape);
            inputLayer.setDimOrder(firstLayer.getDimOrder());
            this.layers.put(inputLayer.getLayerName(), inputLayer);
            this.layersOrdered.add(0, inputLayer);
        }

        // A Sequential model has exactly one input (the first layer) and one output (the last layer).
        KerasLayer lastLayer = this.layersOrdered.get(this.layersOrdered.size() - 1);
        this.inputLayerNames = new ArrayList<String>(Collections.singletonList(inputLayer.getLayerName()));
        this.outputLayerNames = new ArrayList<String>(Collections.singletonList(lastLayer.getLayerName()));

        // Chain the layers: every layer except the first gets its predecessor as sole inbound layer.
        KerasLayer prevLayer = null;
        for (KerasLayer layer : this.layersOrdered) {
            if (prevLayer != null) {
                layer.setInboundLayerNames(Collections.singletonList(prevLayer.getLayerName()));
            }
            prevLayer = layer;
        }

        if (enforceTrainingConfig) {
            if (trainingJson != null) {
                this.importTrainingConfiguration(trainingJson);
            } else {
                log.warn("If enforceTrainingConfig is true, a training configuration object has to be provided. Usually the only practical way to do this is to store your keras model with `model.save('model_path.h5')`. If you store model config and weights separately no training configuration is attached.");
            }
        }

        this.outputTypes = this.inferOutputTypes(inputShape);
        if (weightsArchive != null) {
            KerasModelUtils.importWeights(weightsArchive, weightsRoot, this.layers, this.kerasMajorVersion, this.kerasBackend);
        }
    }

    /**
     * Creates an empty model without parsing any configuration.
     */
    public KerasSequentialModel() {
    }

    /**
     * Extracts the list of layer configurations from the model config value.
     *
     * <p>Two shapes are accepted: the value is the list itself, or the value is a map that holds
     * the list under the {@code "layers"} key.
     *
     * @param modelFieldConfig value stored under the model config field (may be null)
     * @param fieldName        name of the model config field, used for error reporting only
     * @return the list of layer configurations, never null
     * @throws InvalidKerasConfigurationException if no list of layer configurations can be found
     */
    @SuppressWarnings("rawtypes")
    private static List extractLayerConfigs(Object modelFieldConfig, String fieldName) throws InvalidKerasConfigurationException {
        Object layerConfigs = modelFieldConfig;
        if (layerConfigs instanceof Map) {
            layerConfigs = ((Map) layerConfigs).get(LAYERS_KEY);
        }
        if (!(layerConfigs instanceof List)) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (field " + fieldName + " must be a list of layers or a map containing a \"" + LAYERS_KEY + "\" list)");
        }
        return (List) layerConfigs;
    }

    /**
     * Builds the {@link MultiLayerConfiguration} corresponding to this Sequential model.
     *
     * <p>Layers are added in order. Entries of {@code layersOrdered} that are not DL4J layers
     * (for example the input layer or preprocessor-only layers) are skipped, but are still used
     * to determine the input type and preprocessor of the layer that follows them.
     *
     * @return the multi-layer configuration
     * @throws InvalidKerasConfigurationException     if the model is not Sequential, does not have exactly one
     *                                                input and one output, contains a layer without exactly one
     *                                                inbound layer, or contains a vertex
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public MultiLayerConfiguration getMultiLayerConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!this.className.equals(config.getFieldClassNameSequential())) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
        }
        if (this.inputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
        }
        if (this.outputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
        }

        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        if (this.optimizer != null) {
            modelBuilder.updater(this.optimizer);
        }
        NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list();
        // nIn is set explicitly per layer below, so the builder must not override it on build.
        listBuilder.overrideNinUponBuild(false);

        KerasLayer prevLayer = null;
        int layerIndex = 0;
        for (KerasLayer layer : this.layersOrdered) {
            if (layer.isLayer()) {
                int nbInbound = layer.getInboundLayerNames().size();
                if (nbInbound != 1) {
                    throw new InvalidKerasConfigurationException("Layers in MultiLayerConfiguration must have exactly one inbound layer (found " + nbInbound + " for layer " + layer.getLayerName() + ")");
                }
                if (prevLayer != null) {
                    InputType[] inputTypes = new InputType[1];
                    InputPreProcessor preprocessor;
                    if (prevLayer.isInputPreProcessor()) {
                        // The previous entry is only a preprocessor: feed it the output type of
                        // its own inbound layer and size this layer from the preprocessor output.
                        inputTypes[0] = (InputType) this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
                        preprocessor = prevLayer.getInputPreprocessor(inputTypes);
                        InputType outputType = preprocessor.getOutputType(inputTypes[0]);
                        layer.getLayer().setNIn(outputType, listBuilder.isOverrideNinUponBuild());
                    } else {
                        // Otherwise this layer may supply its own preprocessor for the previous
                        // layer's output type.
                        inputTypes[0] = (InputType) this.outputTypes.get(prevLayer.getLayerName());
                        preprocessor = layer.getInputPreprocessor(inputTypes);
                        InputType nInType = preprocessor != null ? preprocessor.getOutputType(inputTypes[0]) : inputTypes[0];
                        layer.getLayer().setNIn(nInType, listBuilder.isOverrideNinUponBuild());
                    }
                    if (preprocessor != null) {
                        listBuilder.inputPreProcessor(layerIndex, preprocessor);
                    }
                }
                listBuilder.layer(layerIndex++, layer.getLayer());
            } else if (layer.getVertex() != null) {
                throw new InvalidKerasConfigurationException("Cannot add vertex to MultiLayerConfiguration (class name " + layer.getClassName() + ", layer name " + layer.getLayerName() + ")");
            }
            prevLayer = layer;
        }

        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            listBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            listBuilder.backpropType(BackpropType.Standard);
        }
        return listBuilder.build();
    }

    /**
     * Builds an initialized {@link MultiLayerNetwork} from this model, including its weights.
     *
     * @return the initialized network with weights copied in
     * @throws InvalidKerasConfigurationException     if the configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public MultiLayerNetwork getMultiLayerNetwork() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getMultiLayerNetwork(true);
    }

    /**
     * Builds an initialized {@link MultiLayerNetwork} from this model.
     *
     * @param importWeights whether to copy the weights held by this model's layers into the network
     * @return the initialized network
     * @throws InvalidKerasConfigurationException     if the configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        MultiLayerNetwork model = new MultiLayerNetwork(this.getMultiLayerConfiguration());
        model.init();
        if (importWeights) {
            model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel((Model) model, this.layers);
        }
        return model;
    }
}

