/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ComputationGraphConfiguration
implements Serializable,
Cloneable {
    /** Class logger; declared final because it is never reassigned. */
    private static final Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);
    /** Graph vertices keyed by vertex name, in insertion order. Network inputs are not stored here. */
    protected Map<String, GraphVertex> vertices = new LinkedHashMap<String, GraphVertex>();
    /** For each vertex name, the ordered names of the vertices or network inputs that feed it. */
    protected Map<String, List<String>> vertexInputs = new LinkedHashMap<String, List<String>>();
    /** Training workspace mode; defaults to ENABLED. */
    protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
    /** Inference workspace mode; defaults to ENABLED. */
    protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
    /** Graph-level cache mode; null unless set. */
    protected CacheMode cacheMode;
    /** Names of the network inputs; their order defines the order of InputType arguments. */
    protected List<String> networkInputs;
    /** Names of the vertices used as network outputs. */
    protected List<String> networkOutputs;
    /** Pretraining flag; defaults to false. */
    protected boolean pretrain = false;
    /** Backprop flag; defaults to true. */
    protected boolean backprop = true;
    /** Backprop type; defaults to Standard. */
    protected BackpropType backpropType = BackpropType.Standard;
    /** Truncated-BPTT forward length (default 20); presumably only relevant for truncated BPTT - confirm. */
    protected int tbpttFwdLength = 20;
    /** Truncated-BPTT backward length (default 20); presumably only relevant for truncated BPTT - confirm. */
    protected int tbpttBackLength = 20;
    /** Default (global) NeuralNetConfiguration shared by the graph. */
    protected NeuralNetConfiguration defaultConfiguration;
    /** Iteration counter; part of equals/hashCode. Presumably updated during training elsewhere. */
    protected int iterationCount = 0;
    /** Epoch counter; part of equals/hashCode. Presumably updated during training elsewhere. */
    protected int epochCount = 0;
    /** Cached topological order as vertex indices; not computed in this part of the file. */
    protected int[] topologicalOrder;
    /** Cached topological order as vertex names; not computed in this part of the file. */
    protected List<String> topologicalOrderStr;

    /**
     * Serializes this configuration to YAML using the shared YAML mapper.
     *
     * @return the YAML representation of this configuration
     * @throws RuntimeException wrapping any Jackson serialization failure
     */
    public String toYaml() {
        final ObjectMapper yamlMapper = NeuralNetConfiguration.mapperYaml();
        // The mapper instance is shared, so writes are serialized on it.
        synchronized (yamlMapper) {
            try {
                return yamlMapper.writeValueAsString(this);
            } catch (JsonProcessingException ex) {
                throw new RuntimeException(ex);
            }
        }
    }

    /**
     * Deserializes a configuration from YAML.
     *
     * @param json the YAML text (parameter name kept for compatibility)
     * @return the deserialized configuration
     * @throws RuntimeException wrapping any I/O or parse failure
     */
    public static ComputationGraphConfiguration fromYaml(String json) {
        final ObjectMapper yamlMapper = NeuralNetConfiguration.mapperYaml();
        try {
            ComputationGraphConfiguration parsed = yamlMapper.readValue(json, ComputationGraphConfiguration.class);
            return parsed;
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Serializes this configuration to JSON using the shared JSON mapper.
     *
     * @return the JSON representation of this configuration
     * @throws RuntimeException wrapping any Jackson serialization failure
     */
    public String toJson() {
        final ObjectMapper jsonMapper = NeuralNetConfiguration.mapper();
        // The mapper instance is shared, so writes are serialized on it.
        synchronized (jsonMapper) {
            try {
                return jsonMapper.writeValueAsString(this);
            } catch (JsonProcessingException ex) {
                throw new RuntimeException(ex);
            }
        }
    }

    /**
     * Deserializes a ComputationGraphConfiguration from JSON.
     * <p>
     * After normal Jackson deserialization, any {@link BaseLayer} whose activation function is still
     * {@code null} is patched from the raw JSON. If the serialized layer carries a legacy string-valued
     * {@code "activationFunction"} field (the pre-0.7.2 format), that string is converted with
     * {@link Activation#fromString(String)}. The raw JSON tree is parsed only if such a layer exists.
     * Layers whose raw JSON entry is missing or malformed are left untouched.
     *
     * @param json the JSON to deserialize
     * @return the deserialized configuration
     * @throws RuntimeException if Jackson cannot deserialize the configuration; when the failure message
     *                          mentions "legacy", the message explains how to register legacy custom classes
     */
    public static ComputationGraphConfiguration fromJson(String json) {
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        ComputationGraphConfiguration conf;
        try {
            conf = mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (Exception e) {
            String msg = e.getMessage();
            if (msg != null && msg.contains("legacy")) {
                throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e);
            }
            throw new RuntimeException(e);
        }

        Map<String, GraphVertex> vertexMap = conf.getVertices();
        if (vertexMap == null) {
            return conf;
        }
        // Raw "vertices" object of the JSON; parsed lazily because most configurations never need it.
        JsonNode verticesNode = null;
        for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
            BaseLayer layer = findLayerWithoutActivation(entry.getValue());
            if (layer == null) {
                continue;
            }
            try {
                if (verticesNode == null) {
                    JsonNode root = mapper.readTree(json);
                    verticesNode = root == null ? null : root.get("vertices");
                    if (verticesNode == null) {
                        // Nothing to recover from: the raw JSON has no "vertices" object.
                        break;
                    }
                }
                // The serialized vertices object is keyed by vertex name (the map key), which is what we
                // look up here; the layer's own name is not guaranteed to match it.
                JsonNode activationFunction = findLegacyActivationNode(verticesNode.get(entry.getKey()));
                if (activationFunction != null) {
                    IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                    layer.setActivationFn(ia);
                }
            } catch (IOException e) {
                log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", (Throwable) e);
            }
        }
        return conf;
    }

    /**
     * Returns the vertex's layer if the vertex is a LayerVertex whose layer is a BaseLayer with no
     * activation function set; otherwise returns null.
     */
    private static BaseLayer findLayerWithoutActivation(GraphVertex vertex) {
        if (!(vertex instanceof LayerVertex)) {
            return null;
        }
        LayerVertex lv = (LayerVertex) vertex;
        if (lv.getLayerConf() == null) {
            return null;
        }
        Layer layer = lv.getLayerConf().getLayer();
        if (!(layer instanceof BaseLayer)) {
            return null; // also covers a null layer
        }
        BaseLayer baseLayer = (BaseLayer) layer;
        return baseLayer.getActivationFn() == null ? baseLayer : null;
    }

    /**
     * Walks a serialized vertex node shaped like
     * {"LayerVertex": {"layerConf": {"layer": {"&lt;type&gt;": {..., "activationFunction": ...}}}}}
     * and returns the "activationFunction" node. Returns null if any step is missing or if the layer
     * wrapper does not contain exactly one entry.
     */
    private static JsonNode findLegacyActivationNode(JsonNode vertexNode) {
        if (vertexNode == null) {
            return null;
        }
        JsonNode layerVertexNode = vertexNode.get("LayerVertex");
        if (layerVertexNode == null) {
            return null;
        }
        JsonNode layerConfNode = layerVertexNode.get("layerConf");
        if (layerConfNode == null) {
            return null;
        }
        JsonNode layerWrapperNode = layerConfNode.get("layer");
        if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
            return null;
        }
        return layerWrapperNode.elements().next().get("activationFunction");
    }

    /**
     * Returns the JSON form of this configuration.
     *
     * @return the same string that toJson() produces
     */
    @Override
    public String toString() {
        String asJson = toJson();
        return asJson;
    }

    /**
     * Creates a deep copy of this configuration.
     * <p>
     * Each vertex and the default configuration are cloned. Vertex input lists, network input/output
     * lists and the cached topological order are copied into new containers; null stays null.
     * <p>
     * Scalar settings are copied by value, so that {@code clone().equals(this)} holds. These are
     * pretrain, backprop, backprop type, TBPTT lengths, workspace and cache modes, and the
     * iteration/epoch counters.
     * <p>
     * The cloned default configuration's cache mode is overwritten with this graph's cache mode.
     *
     * @return an independent copy of this configuration
     */
    @Override
    public ComputationGraphConfiguration clone() {
        ComputationGraphConfiguration conf = new ComputationGraphConfiguration();

        conf.vertices = new LinkedHashMap<String, GraphVertex>();
        for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
            conf.vertices.put(entry.getKey(), entry.getValue().clone());
        }

        conf.vertexInputs = new LinkedHashMap<String, List<String>>();
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            List<String> inputs = entry.getValue();
            conf.vertexInputs.put(entry.getKey(), inputs == null ? null : new ArrayList<String>(inputs));
        }

        conf.networkInputs = this.networkInputs == null ? null : new ArrayList<String>(this.networkInputs);
        conf.networkOutputs = this.networkOutputs == null ? null : new ArrayList<String>(this.networkOutputs);
        conf.pretrain = this.pretrain;
        conf.backprop = this.backprop;
        conf.backpropType = this.backpropType;
        conf.tbpttFwdLength = this.tbpttFwdLength;
        conf.tbpttBackLength = this.tbpttBackLength;
        conf.trainingWorkspaceMode = this.trainingWorkspaceMode;
        conf.inferenceWorkspaceMode = this.inferenceWorkspaceMode;
        conf.cacheMode = this.cacheMode;
        // equals()/hashCode() include these fields, so they must be copied for clone().equals(this).
        conf.iterationCount = this.iterationCount;
        conf.epochCount = this.epochCount;
        conf.topologicalOrder = this.topologicalOrder == null ? null : this.topologicalOrder.clone();
        conf.topologicalOrderStr = this.topologicalOrderStr == null ? null : new ArrayList<String>(this.topologicalOrderStr);

        if (this.defaultConfiguration != null) {
            conf.defaultConfiguration = this.defaultConfiguration.clone();
            // The graph-level cache mode takes precedence over the default configuration's own value.
            conf.defaultConfiguration.cacheMode = this.cacheMode;
        }
        return conf;
    }

    /**
     * Validates the configuration with the strictest settings: disconnected vertices and a missing
     * network output are both treated as errors.
     *
     * @throws IllegalStateException if the configuration is invalid
     */
    public void validate() {
        final boolean allowDisconnected = false;
        final boolean allowNoOutput = false;
        validate(allowDisconnected, allowNoOutput);
    }

    /**
     * Checks the structural consistency of the graph configuration.
     * <p>
     * Verifies the following:
     * <ul>
     *   <li>there is at least one network input;</li>
     *   <li>there is at least one output, unless {@code allowNoOutput};</li>
     *   <li>no name is both a network input and a vertex;</li>
     *   <li>every vertex has at least one input, and each input exists;</li>
     *   <li>every output names a vertex;</li>
     *   <li>unless {@code allowDisconnected} or {@code allowNoOutput}, every input and vertex is consumed
     *       by some vertex or is a network output.</li>
     * </ul>
     *
     * @param allowDisconnected if true, skip the disconnected-vertex check
     * @param allowNoOutput     if true, permit a configuration without outputs; this also implies
     *                          disconnected vertices are acceptable
     * @throws IllegalStateException describing the first problem found
     */
    public void validate(boolean allowDisconnected, boolean allowNoOutput) {
        if (this.networkInputs == null || this.networkInputs.isEmpty()) {
            throw new IllegalStateException("Invalid configuration: network has no inputs. Use .addInputs(String...) to label (and give an ordering to) the network inputs");
        }
        if ((this.networkOutputs == null || this.networkOutputs.isEmpty()) && !allowNoOutput) {
            throw new IllegalStateException("Invalid configuration: network has no outputs. Use .setOutput(String...) to specify (and give an ordering to) the output vertices, or use allowNoOutputs(true) to disable this check");
        }
        for (String inputName : this.networkInputs) {
            if (this.vertices.containsKey(inputName)) {
                throw new IllegalStateException("Invalid configuration: name \"" + inputName + "\" is present in both network inputs and graph vertices/layers");
            }
        }
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            String vertexName = entry.getKey();
            List<String> inputs = entry.getValue();
            if (inputs == null || inputs.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: vertex \"" + vertexName + "\" has no inputs");
            }
            for (String inputName : inputs) {
                if (!this.vertices.containsKey(inputName) && !this.networkInputs.contains(inputName)) {
                    throw new IllegalStateException("Invalid configuration: Vertex \"" + vertexName + "\" has input \"" + inputName + "\" that does not exist");
                }
            }
        }
        if (this.networkOutputs != null) {
            for (String outputName : this.networkOutputs) {
                if (!this.vertices.containsKey(outputName)) {
                    throw new IllegalStateException("Invalid configuration: Output name \"" + outputName + "\" is not a valid vertex");
                }
            }
        }
        // With allowNoOutput, unconsumed vertices are expected, so the disconnected check would never
        // fire. Skipping it also guarantees networkOutputs is non-null here, avoiding a former NPE.
        if (!allowDisconnected && !allowNoOutput) {
            Set<String> seenAsInput = new HashSet<String>(this.networkOutputs);
            // Every vertexInputs value was verified non-null above.
            for (List<String> inputs : this.vertexInputs.values()) {
                seenAsInput.addAll(inputs);
            }
            Set<String> disconnected = new HashSet<String>(this.networkInputs);
            disconnected.addAll(this.vertices.keySet());
            disconnected.removeAll(seenAsInput);
            if (!disconnected.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected + ". Disconnected vertices are those that do not connect to either another vertex, and are also not a network output. To disable this error (i.e., allow network configurations with disconnected vertices) use GraphBuilder.allowDisconnected(true)");
            }
        }
    }

    /**
     * Propagates the given input types through the graph. Along the way, each layer vertex without a
     * preprocessor gets the one returned by its layer's getPreProcessorForInputType, and setNIn is
     * called on every layer.
     *
     * @param inputTypes one InputType per network input, in network-input order
     */
    public void addPreProcessors(InputType ... inputTypes) {
        final boolean addPreprocessors = true;
        getLayerActivationTypes(addPreprocessors, inputTypes);
    }

    /**
     * Computes the output InputType of every vertex and network input.
     * <p>
     * Note: this overload also adds missing preprocessors to layer vertices and sets nIn on layers.
     *
     * @param inputTypes one InputType per network input, in network-input order
     * @return map from vertex/input name to its output type
     */
    public Map<String, InputType> getLayerActivationTypes(InputType ... inputTypes) {
        Map<String, InputType> activationTypes = getLayerActivationTypes(true, inputTypes);
        return activationTypes;
    }

    /**
     * Walks the graph in topological order and computes the output InputType of every vertex.
     * <p>
     * For layer vertices, the single (first) input type is taken. If {@code addPreprocIfNecessary} is
     * true and the vertex has no preprocessor, one is requested from the layer. nIn is then set on the
     * layer from the post-preprocessor type.
     *
     * @param addPreprocIfNecessary whether to attach preprocessors to layer vertices that lack one
     * @param inputTypes            one InputType per network input, in network-input order
     * @return map from vertex/input name to its output type
     * @throws IllegalArgumentException if inputTypes is null or its length differs from the number of inputs
     */
    public Map<String, InputType> getLayerActivationTypes(boolean addPreprocIfNecessary, InputType ... inputTypes) {
        if (inputTypes == null || inputTypes.length != this.networkInputs.size()) {
            throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType objects differs from number of network inputs");
        }
        HashMap<String, InputType> activationTypes = new HashMap<String, InputType>();
        int layerIndex = -1;
        for (String vertexName : topologicalOrdering()) {
            int networkInputIndex = networkInputs.indexOf(vertexName);
            if (networkInputIndex >= 0) {
                activationTypes.put(vertexName, inputTypes[networkInputIndex]);
                continue;
            }
            GraphVertex vertex = vertices.get(vertexName);
            List<InputType> vertexInputTypes = new ArrayList<InputType>();
            if (vertex instanceof LayerVertex) {
                InputType rawInput = activationTypes.get(vertexInputs.get(vertexName).get(0));
                vertexInputTypes.add(rawInput);
                LayerVertex layerVertex = (LayerVertex) vertex;
                Layer layer = layerVertex.getLayerConf().getLayer();
                if (addPreprocIfNecessary && layerVertex.getPreProcessor() == null) {
                    layerVertex.setPreProcessor(layer.getPreProcessorForInputType(rawInput));
                }
                InputPreProcessor preProcessor = layerVertex.getPreProcessor();
                InputType effectiveInput = preProcessor == null ? rawInput : preProcessor.getOutputType(rawInput);
                layer.setNIn(effectiveInput, false);
                layerIndex++;
            } else {
                List<String> inputNames = vertexInputs.get(vertexName);
                if (inputNames != null) {
                    for (String inputName : inputNames) {
                        vertexInputTypes.add(activationTypes.get(inputName));
                    }
                }
            }
            InputType[] inputArray = vertexInputTypes.toArray(new InputType[0]);
            activationTypes.put(vertexName, vertex.getOutputType(layerIndex, inputArray));
        }
        return activationTypes;
    }

    /**
     * Inverts the vertex-input relation: for each input name, lists the vertices that consume it.
     * <p>
     * A vertex that lists the same input twice appears twice in that input's list.
     *
     * @return map from input/vertex name to the names of the vertices it feeds
     */
    private Map<String, List<String>> verticesOutputTo() {
        Map<String, List<String>> consumersByInput = new HashMap<String, List<String>>();
        for (String consumer : vertices.keySet()) {
            List<String> inputNames = vertexInputs.get(consumer);
            if (inputNames == null) {
                continue;
            }
            for (String inputName : inputNames) {
                List<String> consumers = consumersByInput.get(inputName);
                if (consumers == null) {
                    consumers = new ArrayList<String>();
                    consumersByInput.put(inputName, consumers);
                }
                consumers.add(consumer);
            }
        }
        return consumersByInput;
    }

    /**
     * Computes a topological ordering of network inputs and vertices (Kahn's algorithm), starting from
     * the network inputs in their declared order.
     *
     * @return names in an order where every vertex follows all of its inputs; each name appears once
     * @throws IllegalStateException if the graph contains a cycle
     */
    private List<String> topologicalOrdering() {
        Map<String, List<String>> verticesOutputTo = this.verticesOutputTo();
        LinkedList<String> noIncomingEdges = new LinkedList<String>(this.networkInputs);
        List<String> topologicalOrdering = new ArrayList<String>();

        // Working copy of the remaining (not yet satisfied) input edges of each vertex.
        Map<String, Set<String>> inputEdges = new HashMap<String, Set<String>>();
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            inputEdges.put(entry.getKey(), new HashSet<String>(entry.getValue()));
        }

        while (!noIncomingEdges.isEmpty()) {
            String next = noIncomingEdges.removeFirst();
            topologicalOrdering.add(next);
            List<String> consumers = verticesOutputTo.get(next);
            if (consumers == null) {
                continue;
            }
            for (String consumer : consumers) {
                Set<String> remaining = inputEdges.get(consumer);
                // A vertex may list the same input more than once, which shows up here as a repeated
                // consumer. Only the removal that actually drops the edge may enqueue it; otherwise
                // the vertex would be emitted twice.
                if (remaining.remove(next) && remaining.isEmpty()) {
                    noIncomingEdges.add(consumer);
                }
            }
        }

        for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
            Set<String> remaining = entry.getValue();
            if (remaining != null && !remaining.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + entry.getKey() + "\")");
            }
        }
        return topologicalOrdering;
    }

    /**
     * Builds a per-vertex memory report by walking the graph in topological order.
     * <p>
     * Each vertex's output type is computed from its input types and fed forward. For layer vertices,
     * only the first input's type is passed, and the preprocessor is not applied here. Presumably
     * LayerVertex accounts for it inside getMemoryReport — TODO confirm.
     *
     * @param inputTypes one InputType per network input, in network-input order
     * @return memory report keyed by vertex name, in topological order
     * @throws IllegalArgumentException if inputTypes is null or its length differs from the number of inputs
     */
    public NetworkMemoryReport getMemoryReport(InputType ... inputTypes) {
        // Same precondition as getLayerActivationTypes; previously a mismatch surfaced as an
        // ArrayIndexOutOfBoundsException or was silently ignored.
        if (inputTypes == null || inputTypes.length != this.networkInputs.size()) {
            throw new IllegalArgumentException("Invalid number of InputTypes: expected " + this.networkInputs.size()
                    + " (one per network input) but got " + (inputTypes == null ? "null" : String.valueOf(inputTypes.length)));
        }
        LinkedHashMap<String, MemoryReport> memoryReportMap = new LinkedHashMap<String, MemoryReport>();
        Map<String, InputType> vertexOutputs = new HashMap<String, InputType>();
        int currLayerIdx = -1;
        for (String s : this.topologicalOrdering()) {
            int inputIdx = this.networkInputs.indexOf(s);
            if (inputIdx != -1) {
                vertexOutputs.put(s, inputTypes[inputIdx]);
                continue;
            }
            GraphVertex gv = this.vertices.get(s);
            List<InputType> inputTypeList = new ArrayList<InputType>();
            if (gv instanceof LayerVertex) {
                String in = this.vertexInputs.get(s).get(0);
                inputTypeList.add(vertexOutputs.get(in));
                ++currLayerIdx;
            } else {
                List<String> inputs = this.vertexInputs.get(s);
                if (inputs != null) {
                    for (String inputVertexName : inputs) {
                        inputTypeList.add(vertexOutputs.get(inputVertexName));
                    }
                }
            }
            InputType[] vertexInputTypes = inputTypeList.toArray(new InputType[inputTypeList.size()]);
            vertexOutputs.put(s, gv.getOutputType(currLayerIdx, vertexInputTypes));
            memoryReportMap.put(s, gv.getMemoryReport(vertexInputTypes));
        }
        return new NetworkMemoryReport(memoryReportMap, ComputationGraphConfiguration.class, "ComputationGraph", inputTypes);
    }

    /** @return the vertex map keyed by vertex name (the live map, not a copy) */
    public Map<String, GraphVertex> getVertices() {
        return vertices;
    }

    /** @return the vertex-input map (the live map, not a copy) */
    public Map<String, List<String>> getVertexInputs() {
        return vertexInputs;
    }

    /** @return the network input names (the live list) */
    public List<String> getNetworkInputs() {
        return networkInputs;
    }

    /** @return the network output names (the live list) */
    public List<String> getNetworkOutputs() {
        return networkOutputs;
    }

    /** @return the pretrain flag */
    public boolean isPretrain() {
        return pretrain;
    }

    /** @return the backprop flag */
    public boolean isBackprop() {
        return backprop;
    }

    /** @return the backprop type */
    public BackpropType getBackpropType() {
        return backpropType;
    }

    /** @return the truncated-BPTT forward length */
    public int getTbpttFwdLength() {
        return tbpttFwdLength;
    }

    /** @return the truncated-BPTT backward length */
    public int getTbpttBackLength() {
        return tbpttBackLength;
    }

    /** @return the default NeuralNetConfiguration */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /** @return the iteration counter */
    public int getIterationCount() {
        return iterationCount;
    }

    /** @return the epoch counter */
    public int getEpochCount() {
        return epochCount;
    }

    /** @return the cached topological order as indices (the live array; may be null) */
    public int[] getTopologicalOrder() {
        return topologicalOrder;
    }

    /** @return the cached topological order as vertex names (the live list; may be null) */
    public List<String> getTopologicalOrderStr() {
        return topologicalOrderStr;
    }

    /** @param vertices replacement vertex map; stored by reference */
    public void setVertices(Map<String, GraphVertex> vertices) {
        this.vertices = vertices;
    }

    /** @param vertexInputs replacement vertex-input map; stored by reference */
    public void setVertexInputs(Map<String, List<String>> vertexInputs) {
        this.vertexInputs = vertexInputs;
    }

    /** @param networkInputs replacement network input names; stored by reference */
    public void setNetworkInputs(List<String> networkInputs) {
        this.networkInputs = networkInputs;
    }

    /** @param networkOutputs replacement network output names; stored by reference */
    public void setNetworkOutputs(List<String> networkOutputs) {
        this.networkOutputs = networkOutputs;
    }

    /** @param pretrain new pretrain flag */
    public void setPretrain(boolean pretrain) {
        this.pretrain = pretrain;
    }

    /** @param backprop new backprop flag */
    public void setBackprop(boolean backprop) {
        this.backprop = backprop;
    }

    /** @param backpropType new backprop type */
    public void setBackpropType(BackpropType backpropType) {
        this.backpropType = backpropType;
    }

    /** @param tbpttFwdLength new truncated-BPTT forward length */
    public void setTbpttFwdLength(int tbpttFwdLength) {
        this.tbpttFwdLength = tbpttFwdLength;
    }

    /** @param tbpttBackLength new truncated-BPTT backward length */
    public void setTbpttBackLength(int tbpttBackLength) {
        this.tbpttBackLength = tbpttBackLength;
    }

    /** @param defaultConfiguration new default configuration; stored by reference */
    public void setDefaultConfiguration(NeuralNetConfiguration defaultConfiguration) {
        this.defaultConfiguration = defaultConfiguration;
    }

    /** @param iterationCount new iteration counter value */
    public void setIterationCount(int iterationCount) {
        this.iterationCount = iterationCount;
    }

    /** @param epochCount new epoch counter value */
    public void setEpochCount(int epochCount) {
        this.epochCount = epochCount;
    }

    /** @param topologicalOrder cached topological order as indices; stored by reference */
    public void setTopologicalOrder(int[] topologicalOrder) {
        this.topologicalOrder = topologicalOrder;
    }

    /** @param topologicalOrderStr cached topological order as names; stored by reference */
    public void setTopologicalOrderStr(List<String> topologicalOrderStr) {
        this.topologicalOrderStr = topologicalOrderStr;
    }

    /**
     * Field-by-field equality over every configuration field, including the iteration/epoch counters
     * and the cached topological order. Symmetry with subclasses is enforced via canEqual.
     *
     * @param o the object to compare with
     * @return true if o is a compatible configuration with equal fields
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ComputationGraphConfiguration)) {
            return false;
        }
        ComputationGraphConfiguration that = (ComputationGraphConfiguration) o;
        if (!that.canEqual(this)) {
            return false;
        }
        // Cheap primitive comparisons first.
        boolean scalarsMatch = isPretrain() == that.isPretrain()
                && isBackprop() == that.isBackprop()
                && getTbpttFwdLength() == that.getTbpttFwdLength()
                && getTbpttBackLength() == that.getTbpttBackLength()
                && getIterationCount() == that.getIterationCount()
                && getEpochCount() == that.getEpochCount();
        if (!scalarsMatch) {
            return false;
        }
        // Reference-typed fields, compared pairwise with null-safe equals.
        Object[] mine = {
                getVertices(), getVertexInputs(), getTrainingWorkspaceMode(), getInferenceWorkspaceMode(),
                getCacheMode(), getNetworkInputs(), getNetworkOutputs(), getBackpropType(),
                getDefaultConfiguration(), getTopologicalOrderStr()
        };
        Object[] theirs = {
                that.getVertices(), that.getVertexInputs(), that.getTrainingWorkspaceMode(), that.getInferenceWorkspaceMode(),
                that.getCacheMode(), that.getNetworkInputs(), that.getNetworkOutputs(), that.getBackpropType(),
                that.getDefaultConfiguration(), that.getTopologicalOrderStr()
        };
        for (int i = 0; i < mine.length; i++) {
            Object left = mine[i];
            Object right = theirs[i];
            if (left == null ? right != null : !left.equals(right)) {
                return false;
            }
        }
        return Arrays.equals(getTopologicalOrder(), that.getTopologicalOrder());
    }

    /**
     * Lets subclasses opt out of equality with this class (the canEqual idiom used by equals).
     *
     * @param other the candidate object
     * @return true if other is a ComputationGraphConfiguration
     */
    protected boolean canEqual(Object other) {
        boolean compatible = other instanceof ComputationGraphConfiguration;
        return compatible;
    }

    /**
     * Hash code consistent with equals: it covers exactly the same 17 fields, in declaration order.
     * <p>
     * Lombok-style scheme: multiplier 59, 43 for a null reference, 79/97 for true/false. Primitive ints
     * are added directly. The int[] topological order uses Arrays.hashCode, which yields 0 (not 43) for
     * null. The exact numeric result depends on this statement order.
     */
    public int hashCode() {
        // Declared but unused: the literal 59 below is the multiplier.
        int PRIME = 59;
        int result = 1;
        Map<String, GraphVertex> $vertices = this.getVertices();
        result = result * 59 + ($vertices == null ? 43 : ((Object)$vertices).hashCode());
        Map<String, List<String>> $vertexInputs = this.getVertexInputs();
        result = result * 59 + ($vertexInputs == null ? 43 : ((Object)$vertexInputs).hashCode());
        WorkspaceMode $trainingWorkspaceMode = this.getTrainingWorkspaceMode();
        result = result * 59 + ($trainingWorkspaceMode == null ? 43 : ((Object)((Object)$trainingWorkspaceMode)).hashCode());
        WorkspaceMode $inferenceWorkspaceMode = this.getInferenceWorkspaceMode();
        result = result * 59 + ($inferenceWorkspaceMode == null ? 43 : ((Object)((Object)$inferenceWorkspaceMode)).hashCode());
        CacheMode $cacheMode = this.getCacheMode();
        result = result * 59 + ($cacheMode == null ? 43 : ((Object)((Object)$cacheMode)).hashCode());
        List<String> $networkInputs = this.getNetworkInputs();
        result = result * 59 + ($networkInputs == null ? 43 : ((Object)$networkInputs).hashCode());
        List<String> $networkOutputs = this.getNetworkOutputs();
        result = result * 59 + ($networkOutputs == null ? 43 : ((Object)$networkOutputs).hashCode());
        // Booleans contribute 79 (true) or 97 (false).
        result = result * 59 + (this.isPretrain() ? 79 : 97);
        result = result * 59 + (this.isBackprop() ? 79 : 97);
        BackpropType $backpropType = this.getBackpropType();
        result = result * 59 + ($backpropType == null ? 43 : ((Object)((Object)$backpropType)).hashCode());
        result = result * 59 + this.getTbpttFwdLength();
        result = result * 59 + this.getTbpttBackLength();
        NeuralNetConfiguration $defaultConfiguration = this.getDefaultConfiguration();
        result = result * 59 + ($defaultConfiguration == null ? 43 : ((Object)$defaultConfiguration).hashCode());
        result = result * 59 + this.getIterationCount();
        result = result * 59 + this.getEpochCount();
        // Content-based hash of the int[]; Arrays.hashCode(null) is 0.
        result = result * 59 + Arrays.hashCode(this.getTopologicalOrder());
        List<String> $topologicalOrderStr = this.getTopologicalOrderStr();
        result = result * 59 + ($topologicalOrderStr == null ? 43 : ((Object)$topologicalOrderStr).hashCode());
        return result;
    }

    /**
     * All-arguments constructor; every argument is stored by reference without copying or validation.
     */
    private ComputationGraphConfiguration(Map<String, GraphVertex> vertices, Map<String, List<String>> vertexInputs, WorkspaceMode trainingWorkspaceMode, WorkspaceMode inferenceWorkspaceMode, CacheMode cacheMode, List<String> networkInputs, List<String> networkOutputs, boolean pretrain, boolean backprop, BackpropType backpropType, int tbpttFwdLength, int tbpttBackLength, NeuralNetConfiguration defaultConfiguration, int iterationCount, int epochCount, int[] topologicalOrder, List<String> topologicalOrderStr) {
        // Graph structure
        this.vertices = vertices;
        this.vertexInputs = vertexInputs;
        this.networkInputs = networkInputs;
        this.networkOutputs = networkOutputs;
        this.topologicalOrder = topologicalOrder;
        this.topologicalOrderStr = topologicalOrderStr;
        // Training settings
        this.pretrain = pretrain;
        this.backprop = backprop;
        this.backpropType = backpropType;
        this.tbpttFwdLength = tbpttFwdLength;
        this.tbpttBackLength = tbpttBackLength;
        this.defaultConfiguration = defaultConfiguration;
        // Memory management
        this.trainingWorkspaceMode = trainingWorkspaceMode;
        this.inferenceWorkspaceMode = inferenceWorkspaceMode;
        this.cacheMode = cacheMode;
        // Counters
        this.iterationCount = iterationCount;
        this.epochCount = epochCount;
    }

    /** Creates a configuration with the field defaults; presumably also used by Jackson — confirm. */
    public ComputationGraphConfiguration() {
        super();
    }

    /** @return the training workspace mode */
    public WorkspaceMode getTrainingWorkspaceMode() {
        return trainingWorkspaceMode;
    }

    /** @param trainingWorkspaceMode new training workspace mode */
    public void setTrainingWorkspaceMode(WorkspaceMode trainingWorkspaceMode) {
        this.trainingWorkspaceMode = trainingWorkspaceMode;
    }

    /** @return the inference workspace mode */
    public WorkspaceMode getInferenceWorkspaceMode() {
        return inferenceWorkspaceMode;
    }

    /** @param inferenceWorkspaceMode new inference workspace mode */
    public void setInferenceWorkspaceMode(WorkspaceMode inferenceWorkspaceMode) {
        this.inferenceWorkspaceMode = inferenceWorkspaceMode;
    }

    /** @return the graph-level cache mode (may be null) */
    public CacheMode getCacheMode() {
        return cacheMode;
    }

    /** @param cacheMode new graph-level cache mode */
    public void setCacheMode(CacheMode cacheMode) {
        this.cacheMode = cacheMode;
    }

    public static class GraphBuilder {
        /** Default truncated-BPTT forward and backward length. */
        private static final int DEFAULT_TBPTT_LENGTH = 20;
        /** Vertices added so far, keyed by vertex name (insertion-ordered). */
        protected Map<String, GraphVertex> vertices = new LinkedHashMap<String, GraphVertex>();
        /** Input names of each added vertex, keyed by vertex name. */
        protected Map<String, List<String>> vertexInputs = new LinkedHashMap<String, List<String>>();
        /** Network input names, in order. */
        protected List<String> networkInputs = new ArrayList<String>();
        /** Network input types; presumably populated by a setter outside this view — confirm. */
        protected List<InputType> networkInputTypes = new ArrayList<InputType>();
        /** Network output names, in order. */
        protected List<String> networkOutputs = new ArrayList<String>();
        /** Pretrain flag; defaults to false. */
        protected boolean pretrain = false;
        /** Backprop flag; defaults to true. */
        protected boolean backprop = true;
        /** Backprop type; defaults to Standard. */
        protected BackpropType backpropType = BackpropType.Standard;
        /** Truncated-BPTT forward length. */
        protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH;
        /** Truncated-BPTT backward length. */
        protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH;
        /** Explicit preprocessors registered via inputPreProcessor(...), keyed by layer name. */
        protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<String, InputPreProcessor>();
        /** Global configuration builder; cloned for every layer added. */
        protected NeuralNetConfiguration.Builder globalConfiguration;
        /** Whether disconnected vertices are permitted; presumably forwarded to validation — confirm. */
        protected boolean allowDisconnected = false;
        /** Whether a graph without outputs is permitted; presumably forwarded to validation — confirm. */
        protected boolean allowNoOutput = false;

        /**
         * Creates an empty graph builder.
         *
         * @param globalConfiguration base configuration builder; it is cloned for every layer added via addLayer
         */
        public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
            super();
            this.globalConfiguration = globalConfiguration;
        }

        /**
         * Creates a builder pre-populated from a clone of an existing configuration.
         * <p>
         * The following are copied from the clone: vertices, vertex inputs, network inputs/outputs,
         * pretrain/backprop flags, backprop type and TBPTT lengths. inputPreProcessors and
         * networkInputTypes are not populated from the configuration.
         *
         * @param newConf             configuration to start from (cloned, not modified)
         * @param globalConfiguration global configuration builder used for subsequently added layers
         */
        public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) {
            ComputationGraphConfiguration copy = newConf.clone();
            this.globalConfiguration = globalConfiguration;
            // Structure
            this.vertices = copy.getVertices();
            this.vertexInputs = copy.getVertexInputs();
            this.networkInputs = copy.getNetworkInputs();
            this.networkOutputs = copy.getNetworkOutputs();
            // Training settings
            this.pretrain = copy.isPretrain();
            this.backprop = copy.isBackprop();
            this.backpropType = copy.getBackpropType();
            this.tbpttFwdLength = copy.getTbpttFwdLength();
            this.tbpttBackLength = copy.getTbpttBackLength();
        }

        public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
            this.inputPreProcessors.put(layer, processor);
            return this;
        }

        public GraphBuilder backprop(boolean backprop) {
            this.backprop = backprop;
            return this;
        }

        public GraphBuilder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        public GraphBuilder backpropType(BackpropType type) {
            this.backpropType = type;
            return this;
        }

        public GraphBuilder tBPTTForwardLength(int forwardLength) {
            this.tbpttFwdLength = forwardLength;
            return this;
        }

        /**
         * Sets the truncated-BPTT backward length (only effective with BackpropType.TruncatedBPTT).
         *
         * @param backwardLength backward length
         * @return this builder
         */
        public GraphBuilder tBPTTBackwardLength(int backwardLength) {
            tbpttBackLength = backwardLength;
            return this;
        }

        /**
         * Convenience setter: uses the same value for both the forward and backward truncated-BPTT lengths.
         *
         * @param tbpttLength length applied to both directions
         * @return the builder returned by {@link #tBPTTBackwardLength(int)}
         */
        public GraphBuilder tBPTTLength(int tbpttLength) {
            tBPTTForwardLength(tbpttLength);
            GraphBuilder self = tBPTTBackwardLength(tbpttLength);
            return self;
        }

        /**
         * Adds a layer vertex without an explicit input pre-processor.
         *
         * @param layerName   unique vertex name
         * @param layer       layer configuration
         * @param layerInputs names of the vertices/inputs feeding this layer
         * @return this builder
         */
        public GraphBuilder addLayer(String layerName, Layer layer, String ... layerInputs) {
            InputPreProcessor noPreProcessor = null;
            return addLayer(layerName, layer, noPreProcessor, layerInputs);
        }

        /**
         * Adds a layer vertex whose name is the decimal form of an integer index.
         *
         * @param layerName   numeric name, converted to its string form
         * @param layer       layer configuration
         * @param layerInputs names of the vertices/inputs feeding this layer
         * @return this builder
         */
        public GraphBuilder layer(int layerName, Layer layer, String ... layerInputs) {
            String name = Integer.toString(layerName);
            InputPreProcessor noPreProcessor = null;
            return addLayer(name, layer, noPreProcessor, layerInputs);
        }

        /**
         * Alias of {@code addLayer(String, Layer, String...)}: adds a layer vertex without a pre-processor.
         *
         * @param layerName   unique vertex name
         * @param layer       layer configuration
         * @param layerInputs names of the vertices/inputs feeding this layer
         * @return this builder
         */
        public GraphBuilder layer(String layerName, Layer layer, String ... layerInputs) {
            InputPreProcessor noPreProcessor = null;
            return addLayer(layerName, layer, noPreProcessor, layerInputs);
        }

        /**
         * Adds a layer vertex.
         *
         * <p>The layer gets its own copy of the global configuration. The layer's name is set to
         * {@code layerName} after the vertex has been registered.
         *
         * @param layerName    unique vertex name
         * @param layer        layer configuration
         * @param preProcessor optional input pre-processor (may be null)
         * @param layerInputs  names of the vertices/inputs feeding this layer
         * @return this builder
         */
        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, String ... layerInputs) {
            NeuralNetConfiguration.Builder layerConfBuilder = globalConfiguration.clone();
            layerConfBuilder.layer(layer);
            LayerVertex layerVertex = new LayerVertex(layerConfBuilder.build(), preProcessor);
            addVertex(layerName, layerVertex, layerInputs);
            layer.setLayerName(layerName);
            return this;
        }

        /**
         * Alias of {@code addLayer(String, Layer, InputPreProcessor, String...)}.
         *
         * @param layerName    unique vertex name
         * @param layer        layer configuration
         * @param preProcessor optional input pre-processor (may be null)
         * @param layerInputs  names of the vertices/inputs feeding this layer
         * @return this builder
         */
        public GraphBuilder layer(String layerName, Layer layer, InputPreProcessor preProcessor, String ... layerInputs) {
            GraphBuilder self = addLayer(layerName, layer, preProcessor, layerInputs);
            return self;
        }

        /**
         * Removes a vertex together with every reference to it (equivalent to {@code removeVertex(name, true)}).
         *
         * @param vertexName name of the vertex to remove
         * @return this builder
         */
        public GraphBuilder removeVertex(String vertexName) {
            final boolean alsoRemoveConnections = true;
            removeVertex(vertexName, alsoRemoveConnections);
            return this;
        }

        /**
         * Removes a vertex from the graph being built.
         *
         * <p>The vertex and its own input list are always removed, and the name is dropped from the network
         * inputs if present. When {@code removeConnections} is true, the name is also removed from the network
         * outputs, from every other vertex's input list (all occurrences), and from the pre-processor map.
         *
         * @param vertexName        name of the vertex to remove
         * @param removeConnections whether to also remove references to the vertex elsewhere in the graph
         * @return this builder
         */
        public GraphBuilder removeVertex(String vertexName, boolean removeConnections) {
            this.vertices.remove(vertexName);
            this.vertexInputs.remove(vertexName);
            this.networkInputs.remove(vertexName);
            if (removeConnections) {
                this.networkOutputs.remove(vertexName);
                for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
                    List<String> inputs = entry.getValue();
                    if (inputs == null || !inputs.contains(vertexName)) {
                        continue;
                    }
                    // Input lists may be fixed-size (Arrays.asList) or immutable (Collections.singletonList),
                    // so replace the list with a filtered mutable copy instead of removing in place.
                    List<String> remaining = new ArrayList<>(inputs);
                    remaining.removeAll(Collections.singleton(vertexName));
                    entry.setValue(remaining);
                }
                this.inputPreProcessors.remove(vertexName);
            }
            return this;
        }

        /**
         * Appends network input names, in the given order.
         *
         * @param inputNames names of the network inputs to add
         * @return this builder
         */
        public GraphBuilder addInputs(String ... inputNames) {
            for (String inputName : inputNames) {
                networkInputs.add(inputName);
            }
            return this;
        }

        /**
         * Appends every name from the given collection to the network inputs.
         *
         * @param inputNames names of the network inputs to add
         * @return this builder
         */
        public GraphBuilder addInputs(Collection<String> inputNames) {
            networkInputs.addAll(inputNames);
            return this;
        }

        /**
         * Appends input types for the network inputs. Null or empty arguments are ignored.
         *
         * <p>Note that the types are appended to any previously set types rather than replacing them.
         *
         * @param inputTypes input types, one per network input
         * @return this builder
         */
        public GraphBuilder setInputTypes(InputType ... inputTypes) {
            if (inputTypes == null || inputTypes.length == 0) {
                return this;
            }
            for (InputType inputType : inputTypes) {
                networkInputTypes.add(inputType);
            }
            return this;
        }

        /**
         * Replaces the list of network outputs with the given names.
         *
         * @param outputNames names of the output vertices, in order
         * @return this builder
         */
        public GraphBuilder setOutputs(String ... outputNames) {
            networkOutputs.clear();
            for (String outputName : outputNames) {
                networkOutputs.add(outputName);
            }
            return this;
        }

        /**
         * Adds (or replaces) a vertex and records its inputs.
         *
         * <p>If the vertex accepts at most one input but several are given, an extra MergeVertex named
         * {@code vertexName + "-merge"} is added to combine them, and that merge vertex becomes the sole input.
         * Input lists are stored as mutable copies. This keeps them from aliasing the caller's array and lets
         * later edits (e.g. removeVertex) modify them.
         *
         * @param vertexName   unique vertex name
         * @param vertex       vertex configuration
         * @param vertexInputs names of the vertices/inputs feeding this vertex; if null, no input list is recorded
         * @return this builder
         */
        public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String ... vertexInputs) {
            this.vertices.put(vertexName, vertex);
            if (vertex.maxVertexInputs() == 1 && vertexInputs != null && vertexInputs.length > 1) {
                String mergeName = vertexName + "-merge";
                this.addVertex(mergeName, new MergeVertex(), vertexInputs);
                List<String> mergeInput = new ArrayList<>(1);
                mergeInput.add(mergeName);
                this.vertexInputs.put(vertexName, mergeInput);
            } else if (vertexInputs != null) {
                this.vertexInputs.put(vertexName, new ArrayList<>(Arrays.asList(vertexInputs)));
            }
            return this;
        }

        /**
         * Sets the flag forwarded as the first argument of {@code validate(...)} during build().
         *
         * @param allowDisconnected presumably whether disconnected vertices are tolerated
         * @return this builder
         */
        public GraphBuilder allowDisconnected(boolean allowDisconnected) {
            this.allowDisconnected = allowDisconnected;
            return this;
        }

        /**
         * Sets the flag forwarded as the second argument of {@code validate(...)} during build().
         *
         * @param allowNoOutput presumably whether a graph without outputs is tolerated
         * @return this builder
         */
        public GraphBuilder allowNoOutput(boolean allowNoOutput) {
            this.allowNoOutput = allowNoOutput;
            return this;
        }

        /**
         * Computes the activation (output) type of every vertex from the currently configured input types.
         *
         * <p>A temporary configuration is assembled via buildConfig() and validated with {@code validate(true, true)},
         * independently of this builder's allowDisconnected/allowNoOutput flags. Note that buildConfig() shares
         * this builder's vertex objects and attaches the registered pre-processors to them.
         *
         * @return map of vertex name to activation type
         * @throws IllegalArgumentException if no inputs are set, or the number of input types differs from the
         *                                  number of inputs (via Preconditions)
         * @throws RuntimeException         wrapping any failure while building or validating the temporary configuration
         */
        public Map<String, InputType> getLayerActivationTypes() {
            Preconditions.checkArgument(this.networkInputs != null && !this.networkInputs.isEmpty(),
                    "Cannot calculate activation types if no inputs have been set (use addInputs(String...))");
            Preconditions.checkArgument(this.networkInputTypes != null && this.networkInputTypes.size() == this.networkInputs.size(),
                    "Cannot calculate layer activation types if network input types have not been set (use setInputTypes(InputType...))");
            ComputationGraphConfiguration conf;
            try {
                conf = this.buildConfig();
            } catch (Exception e) {
                throw new RuntimeException("Error calculating activation types for layers: error occurred when constructing temporary ComputationGraphConfiguration", e);
            }
            try {
                conf.validate(true, true);
            } catch (Exception e) {
                throw new RuntimeException("Error calculating activation types for layers: validation of temporary ComputationGraphConfiguration failed", e);
            }
            InputType[] inputTypes = this.networkInputTypes.toArray(new InputType[this.networkInputTypes.size()]);
            return conf.getLayerActivationTypes(true, inputTypes);
        }

        /**
         * Assembles an unvalidated ComputationGraphConfiguration from the builder's state.
         *
         * <p>The returned configuration shares (does not copy) this builder's vertex map, vertex-input map and
         * input/output lists. As side effects, registered input pre-processors are attached to their LayerVertex,
         * and the pretrain flag is pushed into every layer whose layer config is a BasePretrainNetwork.
         *
         * @return the assembled configuration
         * @throws IllegalStateException if a pre-processor is registered for a vertex that does not exist or is not
         *                               a LayerVertex
         */
        private ComputationGraphConfiguration buildConfig() {
            // 20 is treated as the unconfigured TBPTT length. NOTE(review): keep in sync with the field initializers.
            if ((this.tbpttBackLength != 20 || this.tbpttFwdLength != 20) && this.backpropType != BackpropType.TruncatedBPTT) {
                log.warn("Truncated backpropagation through time lengths have been configured with values {} and {} but backprop type is set to {}. TBPTT configuration settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT",
                        this.tbpttFwdLength, this.tbpttBackLength, this.backpropType);
            }
            ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
            conf.backprop = this.backprop;
            conf.pretrain = this.pretrain;
            conf.backpropType = this.backpropType;
            conf.tbpttBackLength = this.tbpttBackLength;
            conf.tbpttFwdLength = this.tbpttFwdLength;
            conf.networkInputs = this.networkInputs;
            conf.networkOutputs = this.networkOutputs;
            conf.vertices = this.vertices;
            conf.vertexInputs = this.vertexInputs;
            conf.trainingWorkspaceMode = this.globalConfiguration.trainingWorkspaceMode;
            conf.inferenceWorkspaceMode = this.globalConfiguration.inferenceWorkspaceMode;
            conf.cacheMode = this.globalConfiguration.cacheMode;
            conf.defaultConfiguration = this.globalConfiguration.build();
            conf.getDefaultConfiguration().setPretrain(this.pretrain);

            for (Map.Entry<String, InputPreProcessor> entry : this.inputPreProcessors.entrySet()) {
                String vertexName = entry.getKey();
                GraphVertex gv = this.vertices.get(vertexName);
                if (gv == null) {
                    throw new IllegalStateException("Invalid configuration: InputPreProcessor defined for GraphVertex \"" + vertexName + "\", but no vertex with this name exists");
                }
                if (!(gv instanceof LayerVertex)) {
                    throw new IllegalStateException("Invalid configuration: InputPreProcessor defined for GraphVertex \"" + vertexName + "\", but this vertex is not a LayerVertex");
                }
                ((LayerVertex) gv).setPreProcessor(entry.getValue());
            }

            for (GraphVertex gv : this.vertices.values()) {
                if (!(gv instanceof LayerVertex)) {
                    continue;
                }
                LayerVertex lv = (LayerVertex) gv;
                if (lv.getLayerConf().getLayer() instanceof BasePretrainNetwork) {
                    lv.getLayerConf().setPretrain(this.pretrain);
                }
            }
            return conf;
        }

        /**
         * Builds and validates the graph configuration.
         *
         * <p>Validation uses this builder's allowDisconnected/allowNoOutput flags. If input types were set,
         * pre-processors are then added for them.
         *
         * @return the validated configuration
         */
        public ComputationGraphConfiguration build() {
            ComputationGraphConfiguration conf = this.buildConfig();
            conf.validate(this.allowDisconnected, this.allowNoOutput);
            if (this.networkInputTypes != null && !this.networkInputTypes.isEmpty()) {
                // Size the array from the list being converted. Sizing it by networkInputs would pad the
                // result with nulls whenever fewer input types than inputs were supplied.
                InputType[] inputTypes = this.networkInputTypes.toArray(new InputType[this.networkInputTypes.size()]);
                conf.addPreProcessors(inputTypes);
            }
            return conf;
        }

        /** @return the live (not copied) map of vertex name to vertex configuration */
        public Map<String, GraphVertex> getVertices() {
            return vertices;
        }

        /** @return the live map of vertex name to the names of its inputs */
        public Map<String, List<String>> getVertexInputs() {
            return vertexInputs;
        }

        /** @return the live list of network input names */
        public List<String> getNetworkInputs() {
            return networkInputs;
        }

        /** @return the live list of input types, expected to align with the network inputs */
        public List<InputType> getNetworkInputTypes() {
            return networkInputTypes;
        }

        /** @return the live list of network output names */
        public List<String> getNetworkOutputs() {
            return networkOutputs;
        }

        /** @return the current pretrain flag */
        public boolean isPretrain() {
            return pretrain;
        }

        /** @return the current backprop flag */
        public boolean isBackprop() {
            return backprop;
        }

        /** @return the configured backprop type */
        public BackpropType getBackpropType() {
            return backpropType;
        }

        /** @return the truncated-BPTT forward length */
        public int getTbpttFwdLength() {
            return tbpttFwdLength;
        }

        /** @return the truncated-BPTT backward length */
        public int getTbpttBackLength() {
            return tbpttBackLength;
        }

        /** @return the live map of vertex name to registered input pre-processor */
        public Map<String, InputPreProcessor> getInputPreProcessors() {
            return inputPreProcessors;
        }

        /** @return the template configuration cloned for each added layer */
        public NeuralNetConfiguration.Builder getGlobalConfiguration() {
            return globalConfiguration;
        }

        /** @return the allowDisconnected flag forwarded to validation */
        public boolean isAllowDisconnected() {
            return allowDisconnected;
        }

        /** @return the allowNoOutput flag forwarded to validation */
        public boolean isAllowNoOutput() {
            return allowNoOutput;
        }

        /** Replaces the vertex map (stored by reference, not copied). */
        public void setVertices(Map<String, GraphVertex> vertices) {
            this.vertices = vertices;
        }

        /** Replaces the vertex-inputs map (stored by reference, not copied). */
        public void setVertexInputs(Map<String, List<String>> vertexInputs) {
            this.vertexInputs = vertexInputs;
        }

        /** Replaces the network input name list (stored by reference). */
        public void setNetworkInputs(List<String> networkInputs) {
            this.networkInputs = networkInputs;
        }

        /** Replaces the network input type list (stored by reference). */
        public void setNetworkInputTypes(List<InputType> networkInputTypes) {
            this.networkInputTypes = networkInputTypes;
        }

        /** Replaces the network output name list (stored by reference). */
        public void setNetworkOutputs(List<String> networkOutputs) {
            this.networkOutputs = networkOutputs;
        }

        /** Bean-style setter for the pretrain flag. */
        public void setPretrain(boolean pretrain) {
            this.pretrain = pretrain;
        }

        /** Bean-style setter for the backprop flag. */
        public void setBackprop(boolean backprop) {
            this.backprop = backprop;
        }

        /** Bean-style setter for the backprop type. */
        public void setBackpropType(BackpropType backpropType) {
            this.backpropType = backpropType;
        }

        /** Bean-style setter for the truncated-BPTT forward length. */
        public void setTbpttFwdLength(int tbpttFwdLength) {
            this.tbpttFwdLength = tbpttFwdLength;
        }

        /** Bean-style setter for the truncated-BPTT backward length. */
        public void setTbpttBackLength(int tbpttBackLength) {
            this.tbpttBackLength = tbpttBackLength;
        }

        /** Replaces the pre-processor map (stored by reference, not copied). */
        public void setInputPreProcessors(Map<String, InputPreProcessor> inputPreProcessors) {
            this.inputPreProcessors = inputPreProcessors;
        }

        /** Replaces the template configuration used for subsequently added layers. */
        public void setGlobalConfiguration(NeuralNetConfiguration.Builder globalConfiguration) {
            this.globalConfiguration = globalConfiguration;
        }

        /** Bean-style setter for the allowDisconnected validation flag. */
        public void setAllowDisconnected(boolean allowDisconnected) {
            this.allowDisconnected = allowDisconnected;
        }

        /** Bean-style setter for the allowNoOutput validation flag. */
        public void setAllowNoOutput(boolean allowNoOutput) {
            this.allowNoOutput = allowNoOutput;
        }

        /**
         * Field-by-field equality over all builder state, including the global configuration.
         *
         * @param o object to compare with
         * @return true if {@code o} is a GraphBuilder (that accepts this one via canEqual) with equal fields
         */
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof GraphBuilder)) {
                return false;
            }
            GraphBuilder that = (GraphBuilder) o;
            return that.canEqual(this)
                    && nullSafeEquals(this.getVertices(), that.getVertices())
                    && nullSafeEquals(this.getVertexInputs(), that.getVertexInputs())
                    && nullSafeEquals(this.getNetworkInputs(), that.getNetworkInputs())
                    && nullSafeEquals(this.getNetworkInputTypes(), that.getNetworkInputTypes())
                    && nullSafeEquals(this.getNetworkOutputs(), that.getNetworkOutputs())
                    && this.isPretrain() == that.isPretrain()
                    && this.isBackprop() == that.isBackprop()
                    && nullSafeEquals(this.getBackpropType(), that.getBackpropType())
                    && this.getTbpttFwdLength() == that.getTbpttFwdLength()
                    && this.getTbpttBackLength() == that.getTbpttBackLength()
                    && nullSafeEquals(this.getInputPreProcessors(), that.getInputPreProcessors())
                    && nullSafeEquals(this.getGlobalConfiguration(), that.getGlobalConfiguration())
                    && this.isAllowDisconnected() == that.isAllowDisconnected()
                    && this.isAllowNoOutput() == that.isAllowNoOutput();
        }

        /** Null-tolerant equality: true when both are null, or when {@code a} is non-null and equals {@code b}. */
        private boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /**
         * Symmetry hook for equals(): subclasses may override to refuse equality with plain GraphBuilders.
         *
         * @param other candidate object
         * @return whether {@code other} is a GraphBuilder
         */
        protected boolean canEqual(Object other) {
            boolean isBuilder = other instanceof GraphBuilder;
            return isBuilder;
        }

        /**
         * Hash code consistent with equals(), combining the same fields in the same order.
         *
         * <p>Multiplier 59; null objects contribute 43; booleans contribute 79 (true) or 97 (false).
         *
         * @return the hash code
         */
        public int hashCode() {
            final int prime = 59;
            int h = 1;
            h = h * prime + hashOrDefault(this.getVertices());
            h = h * prime + hashOrDefault(this.getVertexInputs());
            h = h * prime + hashOrDefault(this.getNetworkInputs());
            h = h * prime + hashOrDefault(this.getNetworkInputTypes());
            h = h * prime + hashOrDefault(this.getNetworkOutputs());
            h = h * prime + flagHash(this.isPretrain());
            h = h * prime + flagHash(this.isBackprop());
            h = h * prime + hashOrDefault(this.getBackpropType());
            h = h * prime + this.getTbpttFwdLength();
            h = h * prime + this.getTbpttBackLength();
            h = h * prime + hashOrDefault(this.getInputPreProcessors());
            h = h * prime + hashOrDefault(this.getGlobalConfiguration());
            h = h * prime + flagHash(this.isAllowDisconnected());
            h = h * prime + flagHash(this.isAllowNoOutput());
            return h;
        }

        /** Hash of {@code value}, or 43 when it is null. */
        private int hashOrDefault(Object value) {
            if (value == null) {
                return 43;
            }
            return value.hashCode();
        }

        /** 79 for true, 97 for false. */
        private int flagHash(boolean flag) {
            if (flag) {
                return 79;
            }
            return 97;
        }

        /**
         * Debug representation listing every builder field.
         *
         * @return a string of the form {@code ComputationGraphConfiguration.GraphBuilder(field=value, ...)}
         */
        public String toString() {
            StringBuilder sb = new StringBuilder("ComputationGraphConfiguration.GraphBuilder(");
            sb.append("vertices=").append(this.getVertices());
            sb.append(", vertexInputs=").append(this.getVertexInputs());
            sb.append(", networkInputs=").append(this.getNetworkInputs());
            sb.append(", networkInputTypes=").append(this.getNetworkInputTypes());
            sb.append(", networkOutputs=").append(this.getNetworkOutputs());
            sb.append(", pretrain=").append(this.isPretrain());
            sb.append(", backprop=").append(this.isBackprop());
            sb.append(", backpropType=").append(this.getBackpropType());
            sb.append(", tbpttFwdLength=").append(this.getTbpttFwdLength());
            sb.append(", tbpttBackLength=").append(this.getTbpttBackLength());
            sb.append(", inputPreProcessors=").append(this.getInputPreProcessors());
            sb.append(", globalConfiguration=").append(this.getGlobalConfiguration());
            sb.append(", allowDisconnected=").append(this.isAllowDisconnected());
            sb.append(", allowNoOutput=").append(this.isAllowNoOutput());
            sb.append(")");
            return sb.toString();
        }
    }
}

