/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLoss;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLstm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Imports a Keras functional-API model (Keras class name {@code "Model"}) into a DL4J
 * {@link ComputationGraph}.
 *
 * <p>The model configuration is parsed from JSON or YAML. An optional training configuration adds
 * one {@link KerasLoss} layer per output. Optional HDF5 weights are attached to the imported layers.
 * Instances are normally created through {@link ModelBuilder}.
 */
public class KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasModel.class);

    public static final String MODEL_FIELD_CLASS_NAME = "class_name";
    public static final String MODEL_CLASS_NAME_SEQUENTIAL = "Sequential";
    public static final String MODEL_CLASS_NAME_MODEL = "Model";
    public static final String MODEL_FIELD_CONFIG = "config";
    public static final String MODEL_CONFIG_FIELD_LAYERS = "layers";
    public static final String MODEL_CONFIG_FIELD_INPUT_LAYERS = "input_layers";
    public static final String MODEL_CONFIG_FIELD_OUTPUT_LAYERS = "output_layers";
    public static final String TRAINING_CONFIG_FIELD_LOSS = "loss";
    public static final String HDF5_MODEL_WEIGHTS_ROOT = "model_weights";
    public static final String HDF5_MODEL_CONFIG_ATTRIBUTE = "model_config";
    public static final String HDF5_TRAINING_CONFIG_ATTRIBUTE = "training_config";

    /** Leading "_" separating the (already stripped) layer name from the parameter name. */
    private static final Pattern PARAM_NAME_SEPARATOR = Pattern.compile("^_(.+)$");
    /** Trailing TensorFlow-style ":<digits>" suffix on a stored parameter name. */
    private static final Pattern TF_SUFFIX = Pattern.compile(":\\d+?$");
    /** Trailing "_<digits>" numbering suffix on a stored parameter name. */
    private static final Pattern TF_PARAM_NUMBER = Pattern.compile("_\\d+$");

    /** Keras model class name read from the configuration's "class_name" field. */
    protected String className;
    /** Passed to each layer when it is created from its configuration. */
    protected boolean enforceTrainingConfig;
    /** Layers in configuration order, followed by any loss layers added from the training config. */
    protected List<KerasLayer> layersOrdered;
    /** Layers keyed by layer name. */
    protected Map<String, KerasLayer> layers;
    /** Inferred output type of every layer, keyed by layer name. */
    protected Map<String, InputType> outputTypes;
    protected ArrayList<String> inputLayerNames;
    protected ArrayList<String> outputLayerNames;
    /** Set when any LSTM layer in the model reports {@code unroll}. */
    protected boolean useTruncatedBPTT = false;
    /** Truncated BPTT length, taken from the input layer(s). */
    protected int truncatedBPTT = 0;

    /**
     * Builds a model from the sources collected by a {@link ModelBuilder}.
     *
     * @param modelBuilder holds the model configuration, plus optional training configuration and weights
     * @throws IOException if a configuration string cannot be parsed
     * @throws InvalidKerasConfigurationException if the configuration is malformed
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public KerasModel(ModelBuilder modelBuilder)
                    throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.weightsArchive, modelBuilder.weightsRoot,
                        modelBuilder.trainingJson, modelBuilder.trainingArchive, modelBuilder.enforceTrainingConfig);
    }

    /**
     * Parses a functional-API model configuration and, when provided, its training config and weights.
     *
     * @param modelJson model configuration as JSON; takes precedence over {@code modelYaml}
     * @param modelYaml model configuration as YAML, used only when {@code modelJson} is null
     * @param weightsArchive HDF5 archive holding the weights, or null to skip weight import
     * @param weightsRoot group inside {@code weightsArchive} holding the layer groups, or null for the root
     * @param trainingJson training configuration JSON (supplies loss functions), or null
     * @param trainingArchive not read by this constructor
     * @param enforceTrainingConfig passed through to layer construction
     * @throws IOException if a configuration string cannot be parsed
     * @throws InvalidKerasConfigurationException if required fields are missing or the class is not "Model"
     * @throws UnsupportedKerasConfigurationException if a layer or setting is not supported
     */
    protected KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                    String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig)
                    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig;
        if (modelJson != null) {
            modelConfig = parseJsonString(modelJson);
        } else if (modelYaml != null) {
            modelConfig = parseYamlString(modelYaml);
        } else {
            throw new InvalidKerasConfigurationException("Requires model configuration not found.");
        }
        this.enforceTrainingConfig = enforceTrainingConfig;

        if (!modelConfig.containsKey(MODEL_FIELD_CLASS_NAME)) {
            throw new InvalidKerasConfigurationException(
                            "Could not determine Keras model class (no class_name field found)");
        }
        this.className = (String) modelConfig.get(MODEL_FIELD_CLASS_NAME);
        if (!MODEL_CLASS_NAME_MODEL.equals(this.className)) {
            throw new InvalidKerasConfigurationException(
                            "Expected model class name Model (found " + this.className + ")");
        }

        if (!modelConfig.containsKey(MODEL_FIELD_CONFIG)) {
            throw new InvalidKerasConfigurationException(
                            "Could not find model configuration details (no config in model config)");
        }
        Map<?, ?> layerLists = (Map<?, ?>) modelConfig.get(MODEL_FIELD_CONFIG);

        if (!layerLists.containsKey(MODEL_CONFIG_FIELD_INPUT_LAYERS)) {
            throw new InvalidKerasConfigurationException(
                            "Could not find list of input layers (no input_layers field found)");
        }
        this.inputLayerNames = extractLayerNames((List<?>) layerLists.get(MODEL_CONFIG_FIELD_INPUT_LAYERS));

        if (!layerLists.containsKey(MODEL_CONFIG_FIELD_OUTPUT_LAYERS)) {
            throw new InvalidKerasConfigurationException(
                            "Could not find list of output layers (no output_layers field found)");
        }
        this.outputLayerNames = extractLayerNames((List<?>) layerLists.get(MODEL_CONFIG_FIELD_OUTPUT_LAYERS));

        if (!layerLists.containsKey(MODEL_CONFIG_FIELD_LAYERS)) {
            throw new InvalidKerasConfigurationException(
                            "Could not find layer configurations (no layers field found)");
        }
        @SuppressWarnings("unchecked") // JSON array of layer config objects
        List<Object> layerConfigs = (List<Object>) layerLists.get(MODEL_CONFIG_FIELD_LAYERS);
        this.helperPrepareLayers(layerConfigs);

        if (trainingJson != null) {
            this.helperImportTrainingConfiguration(trainingJson);
        }
        this.helperInferOutputTypes();
        if (weightsArchive != null) {
            this.helperImportWeights(weightsArchive, weightsRoot);
        }
    }

    /**
     * Takes the first element of each entry as a layer name.
     *
     * <p>Each entry must be a list whose first element is the layer name. Presumably this is Keras'
     * [name, node_index, tensor_index] triple; TODO confirm.
     */
    private static ArrayList<String> extractLayerNames(List<?> entries) {
        ArrayList<String> names = new ArrayList<String>(entries.size());
        for (Object entry : entries) {
            names.add((String) ((List<?>) entry).get(0));
        }
        return names;
    }

    /**
     * Creates a {@link KerasLayer} for every layer configuration and makes dim ordering consistent.
     *
     * <p>Layers without a dim ordering inherit the first one found. A layer whose ordering differs
     * from that one is rejected.
     *
     * @param layerConfigs layer configuration maps, in model order
     * @throws UnsupportedKerasConfigurationException on conflicting dim orderings
     */
    protected void helperPrepareLayers(List<Object> layerConfigs)
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.layersOrdered = new ArrayList<KerasLayer>();
        this.layers = new HashMap<String, KerasLayer>();
        KerasLayer.DimOrder dimOrder = KerasLayer.DimOrder.NONE;
        for (Object layerConfig : layerConfigs) {
            KerasLayer layer = KerasLayer.getKerasLayerFromConfig((Map) layerConfig, this.enforceTrainingConfig);
            if (dimOrder == KerasLayer.DimOrder.NONE && layer.getDimOrder() != KerasLayer.DimOrder.NONE) {
                dimOrder = layer.getDimOrder();
            }
            this.layersOrdered.add(layer);
            this.layers.put(layer.getLayerName(), layer);
            if (layer instanceof KerasLstm) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasLstm) layer).getUnroll();
            }
        }
        for (KerasLayer layer : this.layersOrdered) {
            if (layer.getDimOrder() == KerasLayer.DimOrder.NONE) {
                layer.setDimOrder(dimOrder);
            } else if (layer.getDimOrder() != dimOrder) {
                throw new UnsupportedKerasConfigurationException("Keras layer " + layer.getLayerName()
                                + " has conflicting dim_ordering " + layer.getDimOrder() + " (vs. " + dimOrder + ")");
            }
        }
    }

    /**
     * Adds one {@link KerasLoss} layer per model output from the training configuration, and makes
     * those loss layers the new outputs.
     *
     * <p>The "loss" field may be:
     * <ul>
     *   <li>a single string, applied to every output;</li>
     *   <li>a map from output layer name to loss;</li>
     *   <li>a list of losses matched positionally to the output layers.</li>
     * </ul>
     *
     * @param trainingConfigJson training configuration JSON
     * @throws InvalidKerasConfigurationException if "loss" is missing, of an unsupported type, or its
     *     entries are not strings, or if a loss list's length differs from the number of outputs
     */
    protected void helperImportTrainingConfiguration(String trainingConfigJson)
                    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> trainingConfig = parseJsonString(trainingConfigJson);
        if (!trainingConfig.containsKey(TRAINING_CONFIG_FIELD_LOSS)) {
            throw new InvalidKerasConfigurationException(
                            "Could not determine training loss function (no loss field found in training config)");
        }
        Object kerasLossObj = trainingConfig.get(TRAINING_CONFIG_FIELD_LOSS);
        List<KerasLoss> lossLayers = new ArrayList<KerasLoss>();
        if (kerasLossObj instanceof String) {
            String kerasLoss = (String) kerasLossObj;
            for (String outputLayerName : this.outputLayerNames) {
                lossLayers.add(new KerasLoss(outputLayerName + "_loss", outputLayerName, kerasLoss));
            }
        } else if (kerasLossObj instanceof Map) {
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) kerasLossObj).entrySet()) {
                lossLayers.add(createLossLayer((String) entry.getKey(), entry.getValue()));
            }
        } else if (kerasLossObj instanceof List) {
            List<?> kerasLossList = (List<?>) kerasLossObj;
            if (kerasLossList.size() != this.outputLayerNames.size()) {
                throw new InvalidKerasConfigurationException("Found " + kerasLossList.size() + " losses for "
                                + this.outputLayerNames.size() + " output layers");
            }
            for (int i = 0; i < kerasLossList.size(); i++) {
                lossLayers.add(createLossLayer(this.outputLayerNames.get(i), kerasLossList.get(i)));
            }
        } else {
            // Previously ignored, which silently left the model with no outputs at all.
            throw new InvalidKerasConfigurationException("Unsupported Keras loss configuration " + kerasLossObj);
        }

        // The loss layers replace the original outputs.
        this.outputLayerNames.clear();
        for (KerasLoss lossLayer : lossLayers) {
            this.layersOrdered.add(lossLayer);
            this.layers.put(lossLayer.getLayerName(), lossLayer);
            this.outputLayerNames.add(lossLayer.getLayerName());
        }
    }

    /**
     * Creates the loss layer named "&lt;output&gt;_loss" for one output layer.
     *
     * @throws InvalidKerasConfigurationException if {@code kerasLoss} is not a string
     */
    private static KerasLoss createLossLayer(String outputLayerName, Object kerasLoss)
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!(kerasLoss instanceof String)) {
            throw new InvalidKerasConfigurationException("Unknown Keras loss " + kerasLoss);
        }
        return new KerasLoss(outputLayerName + "_loss", outputLayerName, (String) kerasLoss);
    }

    /**
     * Computes every layer's output type in layer order.
     *
     * <p>Input layers use no input types, and each input layer overwrites {@link #truncatedBPTT}.
     * Every other layer is fed the already-inferred output types of its inbound layers.
     */
    protected void helperInferOutputTypes()
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.outputTypes = new HashMap<String, InputType>();
        for (KerasLayer layer : this.layersOrdered) {
            InputType outputType;
            if (layer instanceof KerasInput) {
                outputType = layer.getOutputType(new InputType[0]);
                this.truncatedBPTT = ((KerasInput) layer).getTruncatedBptt();
            } else {
                List<String> inboundLayerNames = layer.getInboundLayerNames();
                InputType[] inputTypes = new InputType[inboundLayerNames.size()];
                int i = 0;
                for (String inboundLayerName : inboundLayerNames) {
                    inputTypes[i++] = this.outputTypes.get(inboundLayerName);
                }
                outputType = layer.getOutputType(inputTypes);
            }
            this.outputTypes.put(layer.getLayerName(), outputType);
        }
    }

    /** Stub: always returns an empty list. */
    protected List<String> helperRecurseWeightsArchive(Hdf5Archive weightsArchive, String weightsRoot,
                    String layerName) {
        return new LinkedList<String>();
    }

    /**
     * Reads each layer's parameter datasets from the HDF5 archive and assigns them to the layer.
     *
     * <p>If any layer name contains "/", the model's own layer names are visited. Otherwise the
     * groups listed under {@code weightsRoot} (or the archive root) are visited. For multi-fragment
     * names, parameters are first looked up in the nested "&lt;name&gt;/&lt;first fragment&gt;"
     * group, falling back to the layer's own group. Datasets are then read from whichever group was
     * actually listed.
     *
     * @param weightsArchive archive holding the weights
     * @param weightsRoot group holding the layer groups, or null for the archive root
     * @throws InvalidKerasConfigurationException if weights exist for an unknown layer, the number of
     *     parameters mismatches, a parameter name cannot be parsed, or a layer that needs weights
     *     gets none
     */
    protected void helperImportWeights(Hdf5Archive weightsArchive, String weightsRoot)
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        boolean includesSlash = false;
        for (String name : this.layers.keySet()) {
            if (name.contains("/")) {
                includesSlash = true;
                break;
            }
        }
        List<String> layerGroups;
        if (includesSlash) {
            layerGroups = new ArrayList<String>(this.layers.keySet());
        } else if (weightsRoot != null) {
            layerGroups = weightsArchive.getGroups(weightsRoot);
        } else {
            layerGroups = weightsArchive.getGroups(new String[0]);
        }

        for (String layerName : layerGroups) {
            String[] layerFragments = layerName.split("/");
            // Group (below weightsRoot, if any) whose datasets were listed; parameters are read from it.
            String paramGroup = layerName;
            List<String> layerParamNames;
            if (layerFragments.length > 1) {
                String nestedGroup = layerName + "/" + layerFragments[0];
                try {
                    layerParamNames = listDataSets(weightsArchive, weightsRoot, nestedGroup);
                    paramGroup = nestedGroup;
                } catch (Exception e) {
                    log.debug("Could not list datasets in group " + nestedGroup + "; falling back to " + layerName, e);
                    layerParamNames = listDataSets(weightsArchive, weightsRoot, layerName);
                }
            } else {
                layerParamNames = listDataSets(weightsArchive, weightsRoot, layerName);
            }

            if (layerParamNames.isEmpty()) {
                continue;
            }
            if (!this.layers.containsKey(layerName)) {
                throw new InvalidKerasConfigurationException(
                                "Found weights for layer not in model (named " + layerName + ")");
            }
            KerasLayer layer = this.layers.get(layerName);
            if (layerParamNames.size() != layer.getNumParams()) {
                throw new InvalidKerasConfigurationException("Found " + layerParamNames.size()
                                + " weights for layer with " + layer.getNumParams() + " trainable params (named "
                                + layerName + ")");
            }

            // Quote the name: layer names are literal text, not regular expressions.
            Pattern layerNamePattern = Pattern.compile(Pattern.quote(layerFragments[layerFragments.length - 1]));
            HashMap<String, INDArray> weights = new HashMap<String, INDArray>();
            for (String layerParamName : layerParamNames) {
                String paramName = parseParamName(layerNamePattern, layerParamName);
                INDArray paramValue = weightsRoot != null
                                ? weightsArchive.readDataSet(layerParamName, weightsRoot, paramGroup)
                                : weightsArchive.readDataSet(layerParamName, paramGroup);
                weights.put(paramName, paramValue);
            }
            layer.setWeights(weights);
        }

        HashSet<String> unvisited = new HashSet<String>(this.layers.keySet());
        unvisited.removeAll(layerGroups);
        for (String layerName : unvisited) {
            if (this.layers.get(layerName).getNumParams() > 0) {
                throw new InvalidKerasConfigurationException("Could not find weights required for layer " + layerName);
            }
        }
    }

    /** Lists the datasets in {@code group}, below {@code root} when root is non-null. */
    private static List<String> listDataSets(Hdf5Archive archive, String root, String group)
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return root != null ? archive.getDataSets(root, group) : archive.getDataSets(group);
    }

    /**
     * Derives the DL4J parameter name from a stored dataset name.
     *
     * <p>Removes, in order:
     * <ul>
     *   <li>the first match of the layer name;</li>
     *   <li>a leading "_";</li>
     *   <li>a trailing ":&lt;digits&gt;";</li>
     *   <li>a trailing "_&lt;digits&gt;".</li>
     * </ul>
     *
     * @throws InvalidKerasConfigurationException if the dataset name does not contain the layer name
     */
    private static String parseParamName(Pattern layerNamePattern, String layerParamName)
                    throws InvalidKerasConfigurationException {
        Matcher layerNameMatcher = layerNamePattern.matcher(layerParamName);
        if (!layerNameMatcher.find()) {
            throw new InvalidKerasConfigurationException(
                            "Unable to parse layer/parameter name " + layerParamName + " for stored weights.");
        }
        String paramName = layerNameMatcher.replaceFirst("");
        Matcher separatorMatcher = PARAM_NAME_SEPARATOR.matcher(paramName);
        if (separatorMatcher.find()) {
            paramName = separatorMatcher.group(1);
        }
        // replaceFirst leaves the string unchanged when the pattern does not match.
        paramName = TF_SUFFIX.matcher(paramName).replaceFirst("");
        paramName = TF_PARAM_NUMBER.matcher(paramName).replaceFirst("");
        return paramName;
    }

    /** For subclasses that populate the fields themselves. */
    protected KerasModel() {
    }

    /**
     * Builds a DL4J {@link ComputationGraphConfiguration} from the imported layers.
     *
     * <p>Layers become graph layers, vertices, or preprocessor vertices. Truncated BPTT is enabled
     * when an unrolled LSTM is present and a positive length was found.
     *
     * @throws InvalidKerasConfigurationException if the class name is neither "Model" nor "Sequential"
     * @throws UnsupportedKerasConfigurationException if a layer cannot be mapped to a layer, vertex, or
     *     preprocessor
     */
    public ComputationGraphConfiguration getComputationGraphConfiguration()
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!MODEL_CLASS_NAME_MODEL.equals(this.className) && !MODEL_CLASS_NAME_SEQUENTIAL.equals(this.className)) {
            throw new InvalidKerasConfigurationException(
                            "Keras model class name " + this.className + " incompatible with ComputationGraph");
        }
        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        ComputationGraphConfiguration.GraphBuilder graphBuilder = modelBuilder.graphBuilder();

        graphBuilder.addInputs(this.inputLayerNames.toArray(new String[this.inputLayerNames.size()]));
        InputType[] inputTypes = new InputType[this.inputLayerNames.size()];
        for (int i = 0; i < inputTypes.length; i++) {
            inputTypes[i] = this.layers.get(this.inputLayerNames.get(i)).getOutputType(new InputType[0]);
        }
        graphBuilder.setInputTypes(inputTypes);
        graphBuilder.setOutputs(this.outputLayerNames.toArray(new String[this.outputLayerNames.size()]));

        HashMap<String, InputPreProcessor> preprocessors = new HashMap<String, InputPreProcessor>();
        for (KerasLayer layer : this.layersOrdered) {
            String layerName = layer.getLayerName();
            List<String> inboundLayerNames = layer.getInboundLayerNames();
            String[] inboundLayerNamesArray = inboundLayerNames.toArray(new String[inboundLayerNames.size()]);
            InputType[] inboundTypeArray = new InputType[inboundLayerNames.size()];
            for (int i = 0; i < inboundTypeArray.length; i++) {
                inboundTypeArray[i] = this.outputTypes.get(inboundLayerNames.get(i));
            }
            InputPreProcessor preprocessor = layer.getInputPreprocessor(inboundTypeArray);
            if (layer.usesRegularization()) {
                modelBuilder.setUseRegularization(true);
            }
            boolean isOutput = this.outputLayerNames.contains(layerName);

            if (layer.isLayer()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addLayer(layerName, layer.getLayer(), inboundLayerNamesArray);
                if (isOutput && !(layer.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer " + layerName
                                    + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (layer.isVertex()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addVertex(layerName, layer.getVertex(), inboundLayerNamesArray);
                if (isOutput && !(layer.getVertex() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output vertex " + layerName
                                    + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (layer.isInputPreProcessor()) {
                if (preprocessor == null) {
                    throw new UnsupportedKerasConfigurationException("Layer " + layerName
                                    + " could not be mapped to Layer, Vertex, or InputPreProcessor");
                }
                graphBuilder.addVertex(layerName, new PreprocessorVertex(preprocessor), inboundLayerNamesArray);
                // A preprocessor vertex can never act as a loss layer. This warning used to fire for
                // every output, including valid output layers.
                if (isOutput) {
                    log.warn("Model cannot be trained: output " + layerName
                                    + " is not an IOutputLayer (no loss function specified)");
                }
            }
        }
        graphBuilder.setInputPreProcessors(preprocessors);

        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            graphBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(this.truncatedBPTT)
                            .tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            graphBuilder.backpropType(BackpropType.Standard);
        }
        return graphBuilder.build();
    }

    /** Builds and initializes a {@link ComputationGraph} with the imported weights copied in. */
    public ComputationGraph getComputationGraph()
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return this.getComputationGraph(true);
    }

    /**
     * Builds and initializes a {@link ComputationGraph}.
     *
     * @param importWeights whether to copy the imported weights into the initialized graph
     */
    public ComputationGraph getComputationGraph(boolean importWeights)
                    throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph model = new ComputationGraph(this.getComputationGraphConfiguration());
        model.init();
        if (importWeights) {
            model = (ComputationGraph) this.helperCopyWeightsToModel(model);
        }
        return model;
    }

    /** Parses a JSON object into a map; nested objects and arrays are left as Jackson produced them. */
    public static Map<String, Object> parseJsonString(String json) throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
        return mapper.readValue(json, typeRef);
    }

    /** Parses a YAML document into a map using Jackson's YAML factory. */
    public static Map<String, Object> parseYamlString(String yaml) throws IOException {
        ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
        TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
        return mapper.readValue(yaml, typeRef);
    }

    /**
     * Copies the imported weights into each layer of an initialized DL4J model.
     *
     * @param model a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @return the same model instance
     * @throws InvalidKerasConfigurationException if a model layer has no imported counterpart, or an
     *     imported layer with parameters is absent from the model
     */
    protected Model helperCopyWeightsToModel(Model model) throws InvalidKerasConfigurationException {
        Layer[] layersFromModel = model instanceof MultiLayerNetwork
                        ? ((MultiLayerNetwork) model).getLayers()
                        : ((ComputationGraph) model).getLayers();
        HashSet<String> unassigned = new HashSet<String>(this.layers.keySet());
        for (Layer layer : layersFromModel) {
            String layerName = layer.conf().getLayer().getLayerName();
            if (!this.layers.containsKey(layerName)) {
                throw new InvalidKerasConfigurationException(
                                "No weights found for layer in model (named " + layerName + ")");
            }
            this.layers.get(layerName).copyWeightsToLayer(layer);
            unassigned.remove(layerName);
        }
        for (String layerName : unassigned) {
            if (this.layers.get(layerName).getNumParams() > 0) {
                throw new InvalidKerasConfigurationException(
                                "Attempting to copy weights for layer not in model (named " + layerName + ")");
            }
        }
        return model;
    }

    /**
     * Collects the model configuration, training configuration and weights sources for a
     * {@link KerasModel} or {@link KerasSequentialModel}.
     */
    static class ModelBuilder implements Cloneable {
        protected String modelJson = null;
        protected String modelYaml = null;
        protected String trainingJson = null;
        protected Hdf5Archive weightsArchive = null;
        protected String weightsRoot = null;
        protected Hdf5Archive trainingArchive = null;
        protected boolean enforceTrainingConfig = false;

        public ModelBuilder modelJson(String modelJson) {
            this.modelJson = modelJson;
            return this;
        }

        /** Reads the model JSON from a UTF-8 file. */
        public ModelBuilder modelJsonFilename(String modelJsonFilename) throws IOException {
            this.modelJson = readUtf8File(modelJsonFilename);
            return this;
        }

        /** Reads the model JSON from a UTF-8 stream; the stream is not closed. */
        public ModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException {
            this.modelJson = readUtf8Stream(modelJsonInputStream);
            return this;
        }

        public ModelBuilder modelYaml(String modelYaml) {
            this.modelYaml = modelYaml;
            return this;
        }

        /** Reads the model YAML from a UTF-8 file (previously stored wrongly as JSON). */
        public ModelBuilder modelYamlFilename(String modelYamlFilename) throws IOException {
            this.modelYaml = readUtf8File(modelYamlFilename);
            return this;
        }

        /** Reads the model YAML from a UTF-8 stream; the stream is not closed. */
        public ModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException {
            this.modelYaml = readUtf8Stream(modelYamlInputStream);
            return this;
        }

        public ModelBuilder trainingJson(String trainingJson) {
            this.trainingJson = trainingJson;
            return this;
        }

        /** Reads the training JSON from a UTF-8 stream; the stream is not closed. */
        public ModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException {
            this.trainingJson = readUtf8Stream(trainingJsonInputStream);
            return this;
        }

        /**
         * Uses a full Keras HDF5 model file.
         *
         * <p>The file supplies the model config attribute (required), the training config attribute
         * (optional) and the weights under {@code model_weights}.
         */
        public ModelBuilder modelHdf5Filename(String modelHdf5Filename)
                        throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
            this.weightsArchive = this.trainingArchive = new Hdf5Archive(modelHdf5Filename);
            this.weightsRoot = KerasModel.HDF5_MODEL_WEIGHTS_ROOT;
            if (!this.weightsArchive.hasAttribute(KerasModel.HDF5_MODEL_CONFIG_ATTRIBUTE, new String[0])) {
                throw new InvalidKerasConfigurationException(
                                "Model configuration attribute missing from " + modelHdf5Filename + " archive.");
            }
            this.modelJson = this.weightsArchive.readAttributeAsJson(KerasModel.HDF5_MODEL_CONFIG_ATTRIBUTE,
                            new String[0]);
            if (this.trainingArchive.hasAttribute(KerasModel.HDF5_TRAINING_CONFIG_ATTRIBUTE, new String[0])) {
                this.trainingJson = this.trainingArchive.readAttributeAsJson(
                                KerasModel.HDF5_TRAINING_CONFIG_ATTRIBUTE, new String[0]);
            }
            return this;
        }

        /** Uses a weights-only HDF5 file; the weights root is left unchanged. */
        public ModelBuilder weightsHdf5Filename(String weightsHdf5Filename) {
            this.weightsArchive = new Hdf5Archive(weightsHdf5Filename);
            return this;
        }

        public ModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig) {
            this.enforceTrainingConfig = enforceTrainingConfig;
            return this;
        }

        public static ModelBuilder builder() {
            return new ModelBuilder();
        }

        public KerasModel buildModel()
                        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasModel(this);
        }

        public KerasSequentialModel buildSequential()
                        throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasSequentialModel(this);
        }

        /** Reads a whole file as UTF-8 text. */
        private static String readUtf8File(String filename) throws IOException {
            return new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
        }

        /** Drains a stream and decodes it as UTF-8; the stream is left open. */
        private static String readUtf8Stream(InputStream in) throws IOException {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            IOUtils.copy(in, buffer);
            return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
        }

        public String getModelJson() {
            return this.modelJson;
        }

        public String getModelYaml() {
            return this.modelYaml;
        }

        public String getTrainingJson() {
            return this.trainingJson;
        }

        public Hdf5Archive getWeightsArchive() {
            return this.weightsArchive;
        }

        public String getWeightsRoot() {
            return this.weightsRoot;
        }

        public Hdf5Archive getTrainingArchive() {
            return this.trainingArchive;
        }

        public boolean isEnforceTrainingConfig() {
            return this.enforceTrainingConfig;
        }

        public void setModelJson(String modelJson) {
            this.modelJson = modelJson;
        }

        public void setModelYaml(String modelYaml) {
            this.modelYaml = modelYaml;
        }

        public void setTrainingJson(String trainingJson) {
            this.trainingJson = trainingJson;
        }

        public void setWeightsArchive(Hdf5Archive weightsArchive) {
            this.weightsArchive = weightsArchive;
        }

        public void setWeightsRoot(String weightsRoot) {
            this.weightsRoot = weightsRoot;
        }

        public void setTrainingArchive(Hdf5Archive trainingArchive) {
            this.trainingArchive = trainingArchive;
        }

        public void setEnforceTrainingConfig(boolean enforceTrainingConfig) {
            this.enforceTrainingConfig = enforceTrainingConfig;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ModelBuilder)) {
                return false;
            }
            ModelBuilder other = (ModelBuilder) o;
            return other.canEqual(this)
                            && Objects.equals(this.getModelJson(), other.getModelJson())
                            && Objects.equals(this.getModelYaml(), other.getModelYaml())
                            && Objects.equals(this.getTrainingJson(), other.getTrainingJson())
                            && Objects.equals(this.getWeightsArchive(), other.getWeightsArchive())
                            && Objects.equals(this.getWeightsRoot(), other.getWeightsRoot())
                            && Objects.equals(this.getTrainingArchive(), other.getTrainingArchive())
                            && this.isEnforceTrainingConfig() == other.isEnforceTrainingConfig();
        }

        protected boolean canEqual(Object other) {
            return other instanceof ModelBuilder;
        }

        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            result = result * PRIME + hashOrDefault(this.getModelJson());
            result = result * PRIME + hashOrDefault(this.getModelYaml());
            result = result * PRIME + hashOrDefault(this.getTrainingJson());
            result = result * PRIME + hashOrDefault(this.getWeightsArchive());
            result = result * PRIME + hashOrDefault(this.getWeightsRoot());
            result = result * PRIME + hashOrDefault(this.getTrainingArchive());
            result = result * PRIME + (this.isEnforceTrainingConfig() ? 79 : 97);
            return result;
        }

        /** Hash of {@code o}, or 43 for null, matching the previous Lombok-style scheme. */
        private static int hashOrDefault(Object o) {
            return o == null ? 43 : o.hashCode();
        }

        @Override
        public String toString() {
            return "KerasModel.ModelBuilder(modelJson=" + this.getModelJson() + ", modelYaml=" + this.getModelYaml()
                            + ", trainingJson=" + this.getTrainingJson() + ", weightsArchive="
                            + this.getWeightsArchive() + ", weightsRoot=" + this.getWeightsRoot()
                            + ", trainingArchive=" + this.getTrainingArchive() + ", enforceTrainingConfig="
                            + this.isEnforceTrainingConfig() + ")";
        }
    }
}

