/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ComputationGraph
implements Serializable,
Model,
NeuralNetwork {
    // Class-wide SLF4J logger.
    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);
    // Graph configuration supplied to the constructor.
    protected ComputationGraphConfiguration configuration;
    // Set to true at the end of init(INDArray, boolean); a second init call returns immediately while this is true.
    protected boolean initCalled = false;
    // Built in init(INDArray, boolean) when null. Transient: not part of the serialized form.
    protected transient Solver solver;
    // Single row vector holding all parameters; assigned in init(INDArray, boolean) (null when the graph has no parameters).
    protected INDArray flattenedParams;
    // Single row vector holding all gradients; allocated by initGradientsView(). Transient.
    protected transient INDArray flattenedGradients;
    protected Gradient gradient;
    protected double score;
    // NOTE(review): distinct from initCalled; not referenced in the reviewed portion of the file - confirm its purpose.
    private boolean initDone = false;
    // NOTE(review): presumably controls clearing of truncated-BPTT state; not referenced in the reviewed portion - confirm.
    protected boolean clearTbpttState = true;
    // Identifiers used when requesting per-thread workspaces from the Nd4j workspace manager.
    public static final String WORKSPACE_CACHE = "LOOP_CACHE";
    public static final String WORKSPACE_EXTERNAL = "LOOP_EXTERNAL";
    public static final String WORKSPACE_FEED_FORWARD = "LOOP_FF";
    public static final String WORKSPACE_PRETRAIN = "LOOP_PTR";
    public static final String WORKSPACE_TBPTT = "LOOP_TBPTT";
    public static final String WORKSPACE_LSTM = "LOOP_LSTM";
    // Shared (static, mutable) workspace configurations. Note that init(INDArray, boolean) changes the
    // mirroring policy of workspaceConfigurationCache when the configured cache mode is HOST.
    public static final WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();
    public static final WorkspaceConfiguration workspaceConfigurationTBPTT = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();
    public static final WorkspaceConfiguration workspaceConfigurationLSTM = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP).build();
    public static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();
    public static final WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder().overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();
    // Per-thread ETL time, accessed through setLastEtlTime/getLastEtlTime. Transient, so the
    // initializer below does not run when an instance is restored by Java deserialization.
    protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal();
    // All vertices, indexed by vertex number: network inputs first, then the configured vertices (see init).
    protected GraphVertex[] vertices;
    // Vertex name -> vertex; populated in init(INDArray, boolean).
    protected Map<String, GraphVertex> verticesMap;
    // Vertex indices in topological order; computed in init(INDArray, boolean).
    protected int[] topologicalOrder;
    // Layers of all vertices for which hasLayer() is true, collected in init(INDArray, boolean).
    protected org.deeplearning4j.nn.api.Layer[] layers;
    // Number of network inputs / outputs, taken from the configuration in the constructor.
    private int numInputArrays;
    private int numOutputArrays;
    // Current input/label arrays and their masks. All transient.
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;
    // configuration.getDefaultConfiguration(), captured in the constructor.
    private NeuralNetConfiguration defaultConfiguration;
    private Collection<IterationListener> listeners = new ArrayList<IterationListener>();
    private Collection<TrainingListener> trainingListeners = new ArrayList<TrainingListener>();

    public ComputationGraph(ComputationGraphConfiguration configuration) {
        this.configuration = configuration;
        this.numInputArrays = configuration.getNetworkInputs().size();
        this.numOutputArrays = configuration.getNetworkOutputs().size();
        this.inputs = new INDArray[this.numInputArrays];
        this.labels = new INDArray[this.numOutputArrays];
        this.defaultConfiguration = configuration.getDefaultConfiguration();
    }

    /**
     * Records the last ETL time for the calling thread.
     *
     * <p>The backing {@code ThreadLocal} is a transient field whose initializer does not run
     * when an instance is restored through Java deserialization, so it is re-created here
     * on demand instead of failing with a {@code NullPointerException}.
     *
     * @param time the ETL time to store for the current thread
     */
    public void setLastEtlTime(long time) {
        if (this.lastEtlTime == null) {
            this.lastEtlTime = new ThreadLocal<Long>();
        }
        this.lastEtlTime.set(time);
    }

    /**
     * Returns the last ETL time recorded for the calling thread.
     *
     * <p>The backing {@code ThreadLocal} is a transient field and may be {@code null} on an
     * instance restored through Java deserialization; that case is treated the same as
     * "nothing recorded yet".
     *
     * @return the stored time, or 0 if none has been recorded for this thread
     */
    public long getLastEtlTime() {
        ThreadLocal<Long> holder = this.lastEtlTime;
        if (holder == null) {
            return 0L;
        }
        Long time = holder.get();
        return time == null ? 0L : time;
    }

    /**
     * Applies a cache mode to every layer of the graph.
     *
     * @param mode the cache mode to apply; {@code null} is treated as {@link CacheMode#NONE}
     */
    public void setCacheMode(CacheMode mode) {
        CacheMode effectiveMode = mode != null ? mode : CacheMode.NONE;
        org.deeplearning4j.nn.api.Layer[] allLayers = this.layers;
        for (int idx = 0; idx < allLayers.length; ++idx) {
            allLayers[idx].setCacheMode(effectiveMode);
        }
    }

    /**
     * @return the configuration this graph was constructed with
     */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /**
     * @return the length of the layers array, or 0 if the layers array has not been created yet
     */
    public int getNumLayers() {
        if (this.layers == null) {
            return 0;
        }
        return this.layers.length;
    }

    /**
     * Returns a layer by its position in the layers array.
     *
     * @param idx index into the layers array
     * @return the layer stored at that index
     */
    public org.deeplearning4j.nn.api.Layer getLayer(int idx) {
        return layers[idx];
    }

    /**
     * @return the internal layers array (not a copy)
     */
    public org.deeplearning4j.nn.api.Layer[] getLayers() {
        return layers;
    }

    /**
     * Returns the layer held by the vertex with the given name.
     *
     * @param name vertex name to look up in the vertices map
     * @return the result of {@code getLayer()} on that vertex
     */
    public org.deeplearning4j.nn.api.Layer getLayer(String name) {
        GraphVertex namedVertex = this.verticesMap.get(name);
        return namedVertex.getLayer();
    }

    /**
     * @return the internal vertices array (not a copy)
     */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /**
     * Looks up a vertex by name.
     *
     * @param name the vertex name
     * @return the vertex mapped to that name, or {@code null} if the map has no such entry
     */
    public GraphVertex getVertex(String name) {
        return verticesMap.get(name);
    }

    /**
     * @return the number of network inputs, as read from the configuration at construction time
     */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /**
     * @return the number of network outputs, as read from the configuration at construction time
     */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Stores a single network input, allocating the input holder first if it is currently null.
     *
     * @param inputNum index of the input to set
     * @param input    the array to store at that index
     */
    public void setInput(int inputNum, INDArray input) {
        INDArray[] holder = this.inputs;
        if (holder == null) {
            holder = new INDArray[this.numInputArrays];
            this.inputs = holder;
        }
        holder[inputNum] = input;
    }

    /**
     * Replaces all network inputs at once.
     *
     * @param inputs the new input arrays; {@code null} is accepted and stored as-is
     * @throws IllegalArgumentException if a non-null array does not have exactly one entry per network input
     */
    public void setInputs(INDArray ... inputs) {
        boolean acceptable = inputs == null || inputs.length == this.numInputArrays;
        if (acceptable) {
            this.inputs = inputs;
            return;
        }
        throw new IllegalArgumentException("Invalid input array: network has " + this.numInputArrays + " inputs, but array is of length " + inputs.length);
    }

    /**
     * Returns a single network input.
     *
     * @param inputNum index of the input to fetch
     * @return the stored array, or {@code null} if no input holder is currently set
     */
    public INDArray getInput(int inputNum) {
        return this.inputs == null ? null : this.inputs[inputNum];
    }

    /**
     * @return the internal input array holder (not a copy); may be {@code null}
     */
    public INDArray[] getInputs() {
        return inputs;
    }

    /**
     * @return the currently stored input mask arrays (not a copy); may be {@code null}
     */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /**
     * @return the currently stored label mask arrays (not a copy); may be {@code null}
     */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Stores a single label array.
     *
     * <p>The label holder is transient and can also be cleared through {@code setLabels(null)},
     * so it is allocated on demand here - mirroring what {@link #setInput(int, INDArray)} does
     * for inputs - rather than failing with a {@code NullPointerException}.
     *
     * @param labelNum index of the network output the label belongs to
     * @param label    the label array to store at that index
     */
    public void setLabel(int labelNum, INDArray label) {
        if (this.labels == null) {
            this.labels = new INDArray[this.numOutputArrays];
        }
        this.labels[labelNum] = label;
    }

    /**
     * Replaces all label arrays at once.
     *
     * @param labels the new label arrays; {@code null} is accepted and stored as-is
     * @throws IllegalArgumentException if a non-null array does not have exactly one entry per network output
     */
    public void setLabels(INDArray ... labels) {
        boolean acceptable = labels == null || labels.length == this.numOutputArrays;
        if (acceptable) {
            this.labels = labels;
            return;
        }
        throw new IllegalArgumentException("Invalid output array: network has " + this.numOutputArrays + " outputs, but array is of length " + labels.length);
    }

    /**
     * Hands a gradients accumulator to the solver's optimizer, initializing the graph first
     * if {@link #init()} has not been called yet.
     *
     * @param accumulator the accumulator to install on the optimizer
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        boolean needsInit = !this.initCalled;
        if (needsInit) {
            init();
        }
        Solver activeSolver = this.solver;
        activeSolver.getOptimizer().setGradientsAccumulator(accumulator);
    }

    /**
     * Initializes the graph with freshly created parameters; equivalent to
     * {@code init(null, false)}.
     */
    @Override
    public void init() {
        init(null, false);
    }

    /**
     * Builds the runtime graph from the configuration: creates the vertices, allocates (or adopts)
     * the flattened parameter vector, hands each vertex its view into that vector, wires the
     * input/output indices between vertices and creates the solver. Does nothing if the graph
     * has already been initialized.
     *
     * @param parameters           existing parameters to use, or {@code null} to allocate and initialize new ones;
     *                             when non-null it must be a row vector whose length equals the total parameter count
     * @param cloneParametersArray if {@code true}, a duplicate of {@code parameters} is stored instead of the
     *                             array itself (only relevant when {@code parameters} is non-null)
     * @throws IllegalArgumentException if {@code parameters} is not a row vector or has the wrong length
     * @throws IllegalStateException    if a configured vertex instantiates to {@code null}, or if a vertex is
     *                                  missing from the output list of one of its inputs
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        String vertexName;
        boolean initializeParams;
        int i;
        // Idempotent: a second call is a no-op.
        if (this.initCalled) {
            return;
        }
        OneTimeLogger.info((Logger)log, (String)"Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}]", (Object[])new Object[]{this.configuration.getTrainingWorkspaceMode(), this.configuration.getInferenceWorkspaceMode()});
        // Note: this mutates the shared static cache workspace configuration, not a per-instance copy.
        if (this.configuration.getCacheMode() == CacheMode.HOST) {
            workspaceConfigurationCache.setPolicyMirroring(MirroringPolicy.HOST_ONLY);
        }
        this.topologicalOrder = this.topologicalSortOrder();
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        // Vertex numbering: network inputs occupy indices [0, #inputs), configured vertices follow
        // in the iteration order of the configuration's vertex map.
        this.vertices = new GraphVertex[networkInputNames.size() + this.configuration.getVertices().size()];
        // Vertex name -> vertex index.
        HashMap<String, Integer> allNamesReverse = new HashMap<String, Integer>();
        int vertexNumber = 0;
        for (String name : networkInputNames) {
            InputVertex gv = new InputVertex(this, name, vertexNumber, null);
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }
        // Count parameters per vertex (indexed by vertex number) and in total. Input vertices have none.
        int numParams = 0;
        int[] numParamsForVertex = new int[this.topologicalOrder.length];
        for (i = 0; i < this.configuration.getNetworkInputs().size(); ++i) {
            numParamsForVertex[i] = 0;
        }
        // i continues from the number of inputs, matching the vertex numbering above.
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
            numParamsForVertex[i] = n.numParams(true);
            numParams += numParamsForVertex[i];
            ++i;
        }
        // Either adopt the supplied parameters (after validation) or allocate a new vector to be initialized.
        if (parameters != null) {
            if (!parameters.isRowVector()) {
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            }
            if (parameters.length() != numParams) {
                throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length " + parameters.length());
            }
            this.flattenedParams = cloneParametersArray ? parameters.dup() : parameters;
            initializeParams = false;
        } else if (numParams > 0) {
            this.flattenedParams = Nd4j.create((int)1, (int)numParams);
            initializeParams = true;
        } else {
            this.flattenedParams = null;
            initializeParams = false;
        }
        // Seed the RNG from the configuration only when new parameters will be initialized.
        if (initializeParams) {
            Nd4j.getRandom().setSeed(this.conf().getSeed());
        }
        // Carve the flattened vector into per-vertex views. Offsets advance in topological order,
        // while the views are stored by vertex index.
        INDArray[] paramsViewForVertex = new INDArray[this.topologicalOrder.length];
        int paramOffsetSoFar = 0;
        // i is incremented below but never read in this loop.
        i = 0;
        for (int vertexIdx : this.topologicalOrder) {
            int nParamsThisVertex = numParamsForVertex[vertexIdx];
            if (nParamsThisVertex != 0) {
                paramsViewForVertex[vertexIdx] = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)paramOffsetSoFar, (int)(paramOffsetSoFar + nParamsThisVertex))});
            }
            ++i;
            paramOffsetSoFar += nParamsThisVertex;
        }
        // Instantiate the configured vertices, collecting layers and their variable names
        // (each prefixed with "<vertexName>_") along the way.
        int numLayers = 0;
        ArrayList<org.deeplearning4j.nn.api.Layer> tempLayerList = new ArrayList<org.deeplearning4j.nn.api.Layer>();
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            String name;
            org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
            GraphVertex gv = n.instantiate(this, name = nodeEntry.getKey(), vertexNumber, paramsViewForVertex[vertexNumber], initializeParams);
            if (gv == null) {
                throw new IllegalStateException("Encountered null layer/vertex during initialization for layer \"" + name + "\": " + n.getClass().getSimpleName() + " initialization returned null layer/vertex?");
            }
            if (gv.hasLayer()) {
                ++numLayers;
                org.deeplearning4j.nn.api.Layer l = gv.getLayer();
                tempLayerList.add(l);
                List<String> layerVariables = l.conf().variables();
                if (layerVariables != null) {
                    for (String s : layerVariables) {
                        variables.add(gv.getVertexName() + "_" + s);
                    }
                }
            }
            allNamesReverse.put(name, vertexNumber);
            this.vertices[vertexNumber++] = gv;
        }
        this.layers = tempLayerList.toArray(new org.deeplearning4j.nn.api.Layer[numLayers]);
        // Name -> vertex lookup.
        this.verticesMap = new HashMap<String, GraphVertex>();
        for (GraphVertex gv : this.vertices) {
            this.verticesMap.put(gv.getVertexName(), gv);
        }
        // Invert the "vertex -> its inputs" map: for each vertex name, the names of the vertices that consume it.
        HashMap<String, ArrayList<String>> verticesOutputTo = new HashMap<String, ArrayList<String>>();
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) continue;
            for (String s : vertexInputNames) {
                ArrayList<String> list = (ArrayList<String>)verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<String>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName);
            }
        }
        // For every vertex with inputs: record (input vertex index, which output slot of that input feeds this vertex).
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            int vertexIndex = gv.getVertexIndex();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null) continue;
            VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
            for (int j = 0; j < vertexInputNames.size(); ++j) {
                String inName = vertexInputNames.get(j);
                int inputVertexIndex = (Integer)allNamesReverse.get(inName);
                GraphVertex inputVertex = this.vertices[inputVertexIndex];
                List inputVertexOutputsTo = (List)verticesOutputTo.get(inName);
                int outputNumberOfInput = inputVertexOutputsTo.indexOf(vertexName);
                if (outputNumberOfInput == -1) {
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs for vertex " + inputVertex + "; error in graph structure?");
                }
                inputIndices[j] = new VertexIndices(inputVertexIndex, outputNumberOfInput);
            }
            gv.setInputVertices(inputIndices);
        }
        // For every vertex with consumers: record (consumer vertex index, which input slot of the consumer it feeds).
        for (GraphVertex gv : this.vertices) {
            vertexName = gv.getVertexName();
            List thisVertexOutputsTo = (List)verticesOutputTo.get(vertexName);
            if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty()) continue;
            VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
            int j = 0;
            for (String s : thisVertexOutputsTo) {
                List<String> nextVertexInputNames = vertexInputs.get(s);
                int outputVertexInputNumber = nextVertexInputNames.indexOf(vertexName);
                int outputVertexIndex = (Integer)allNamesReverse.get(s);
                outputIndices[j++] = new VertexIndices(outputVertexIndex, outputVertexInputNumber);
            }
            gv.setOutputVertices(outputIndices);
        }
        // Flag the vertices named as network outputs.
        for (String s : this.configuration.getNetworkOutputs()) {
            GraphVertex gv = this.verticesMap.get(s);
            gv.setOutputVertex(true);
        }
        // Create the solver outside of any active workspace scope.
        if (this.solver == null) {
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }
        this.synchronizeIterEpochCounts();
        this.initCalled = true;
    }

    /**
     * Allocates the flattened gradient vector and gives every vertex that has parameters a
     * view into it. Runs outside of any active workspace scope and initializes the graph
     * first if that has not happened yet.
     *
     * <p>Parameter counts are indexed by vertex number (network inputs first, with zero
     * parameters), while view offsets advance in topological order.
     */
    public void initGradientsView() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (!this.initCalled) {
                this.init();
            }

            // New int arrays are zero-filled, so the leading slots for the input vertices need no explicit write.
            int[] paramCountByVertex = new int[this.topologicalOrder.length];
            int slot = this.configuration.getNetworkInputs().size();
            int totalParams = 0;
            for (org.deeplearning4j.nn.conf.graph.GraphVertex confVertex : this.configuration.getVertices().values()) {
                int count = confVertex.numParams(true);
                paramCountByVertex[slot++] = count;
                totalParams += count;
            }

            if (totalParams > 0) {
                this.flattenedGradients = Nd4j.create(1, totalParams);
            }

            int offset = 0;
            for (int vertexIdx : this.topologicalOrder) {
                int count = paramCountByVertex[vertexIdx];
                if (count == 0) {
                    continue;
                }
                INDArray view = this.flattenedGradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count));
                this.vertices[vertexIdx].setBackpropGradientsViewArray(view);
                offset += count;
            }
        }
    }

    /**
     * Pretrains the network from a {@link DataSetIterator} by adapting it to a
     * {@link MultiDataSetIterator}.
     *
     * @param iter training data
     * @throws UnsupportedOperationException if the network does not have exactly one input array
     */
    public void pretrain(DataSetIterator iter) {
        if (this.numInputArrays == 1) {
            this.pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter));
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains every eligible layer in turn. Does nothing unless pretraining is enabled in
     * the configuration. A vertex is eligible when it has a layer, that layer is not an
     * {@link IOutputLayer}, and the layer reports itself as a pretrain layer.
     *
     * @param iter training data, passed to {@link #pretrainLayer(String, MultiDataSetIterator)} for each eligible vertex
     */
    public void pretrain(MultiDataSetIterator iter) {
        if (!this.configuration.isPretrain()) {
            return;
        }
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        // Vertices are visited by vertex index, not via the topological ordering.
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            GraphVertex candidate = this.vertices[i];
            if (!candidate.hasLayer()) {
                continue;
            }
            if (candidate.getLayer() instanceof IOutputLayer) {
                continue;
            }
            if (!candidate.getLayer().isPretrainLayer()) {
                continue;
            }
            this.pretrainLayer(candidate.getVertexName(), iter);
        }
    }

    /**
     * Pretrains a single layer from a {@link DataSetIterator} by adapting it to a
     * {@link MultiDataSetIterator}.
     *
     * @param layerName       name of the vertex whose layer should be pretrained
     * @param dataSetIterator training data
     * @throws UnsupportedOperationException if the network does not have exactly one input array
     */
    public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) {
        if (this.numInputArrays == 1) {
            this.pretrainLayer(layerName, ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator));
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains the layer of a single vertex. For every minibatch, a partial forward pass is run
     * through only those vertices that (directly or indirectly) feed the target vertex, and the
     * target layer is then fitted on its first input.
     *
     * <p>Does nothing if pretraining is disabled in the configuration or if the named vertex has
     * no layer. Workspace scopes are managed with try-with-resources, so an exception raised
     * during the forward pass or the fit propagates to the caller after the scopes are closed.
     *
     * <p>NOTE(review): the target's {@code getVertexIndex()} is used as a position in
     * {@code topologicalOrder}; this assumes the two coincide for the target - confirm.
     *
     * @param layerName name of the vertex whose layer should be pretrained
     * @param iter      training data; reset first if it is exhausted and supports reset
     * @throws IllegalStateException if no vertex with the given name exists
     */
    public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
        if (!this.configuration.isPretrain()) {
            return;
        }
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (!this.verticesMap.containsKey(layerName)) {
            throw new IllegalStateException("Invalid vertex name: " + layerName);
        }
        if (!this.verticesMap.get(layerName).hasLayer()) {
            return;
        }
        int layerIndex = this.verticesMap.get(layerName).getVertexIndex();

        // Walk the topological order backwards from the target, keeping only vertices whose
        // output is consumed by something already kept. The result is the minimal forward-pass
        // order, ending with the target vertex itself.
        LinkedList<Integer> partialTopoSort = new LinkedList<Integer>();
        Set<Integer> seenSoFar = new HashSet<Integer>();
        partialTopoSort.add(this.topologicalOrder[layerIndex]);
        seenSoFar.add(this.topologicalOrder[layerIndex]);
        for (int j = layerIndex - 1; j >= 0; --j) {
            VertexIndices[] outputsTo = this.vertices[this.topologicalOrder[j]].getOutputVertices();
            boolean needed = false;
            for (VertexIndices vi : outputsTo) {
                if (seenSoFar.contains(vi.getVertexIndex())) {
                    needed = true;
                    break;
                }
            }
            if (needed) {
                partialTopoSort.addFirst(this.topologicalOrder[j]);
                seenSoFar.add(this.topologicalOrder[j]);
            }
        }
        int[] fwdPassOrder = new int[partialTopoSort.size()];
        int k = 0;
        for (Integer g : partialTopoSort) {
            fwdPassOrder[k++] = g;
        }

        GraphVertex gv = this.vertices[fwdPassOrder[fwdPassOrder.length - 1]];
        org.deeplearning4j.nn.api.Layer layer = gv.getLayer();
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }

        // Workspaces are held as MemoryWorkspace: the workspace manager does not hand out DummyWorkspace instances.
        WorkspaceMode trainingMode = this.configuration.getTrainingWorkspaceMode();
        boolean workspacesDisabled = trainingMode == WorkspaceMode.NONE;
        MemoryWorkspace workspace = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        MemoryWorkspace cache = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache, WORKSPACE_CACHE);
        MemoryWorkspace wsFF;
        MemoryWorkspace wsPTR;
        switch (trainingMode) {
            case NONE: {
                wsFF = new DummyWorkspace();
                wsPTR = new DummyWorkspace();
                break;
            }
            case SINGLE: {
                wsFF = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL);
                wsPTR = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL);
                break;
            }
            case SEPARATE: {
                wsFF = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, WORKSPACE_FEED_FORWARD);
                wsPTR = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, WORKSPACE_PRETRAIN);
                break;
            }
            default: {
                throw new RuntimeException("Unsupported training workspace mode: " + trainingMode);
            }
        }

        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.api.MultiDataSet multiDataSet = (org.nd4j.linalg.dataset.api.MultiDataSet)iter.next();
            // Scopes are entered cache -> external -> pretrain and closed in reverse order.
            try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                 MemoryWorkspace ws = workspace.notifyScopeEntered();
                 MemoryWorkspace wP = wsPTR.notifyScopeEntered()) {
                this.setInputs(multiDataSet.getFeatures());

                // Forward pass through everything except the target vertex (the last entry).
                for (int j = 0; j < fwdPassOrder.length - 1; ++j) {
                    try (MemoryWorkspace wF = wsFF.notifyScopeEntered()) {
                        GraphVertex current = this.vertices[fwdPassOrder[j]];
                        if (current.isInputVertex()) {
                            // Each consumer gets its own copy of the network input, leveraged to the pretrain workspace.
                            VertexIndices[] inputsTo = current.getOutputVertices();
                            INDArray input = this.inputs[current.getVertexIndex()];
                            for (VertexIndices v : inputsTo) {
                                int vIdx = v.getVertexIndex();
                                int vIdxInputNum = v.getVertexEdgeNumber();
                                this.vertices[vIdx].setInput(vIdxInputNum, input.dup().leverageTo(WORKSPACE_PRETRAIN));
                            }
                        } else {
                            INDArray out = current.doForward(true);
                            VertexIndices[] outputsTo = current.getOutputVertices();
                            if (outputsTo != null) {
                                for (VertexIndices v : outputsTo) {
                                    int vIdx = v.getVertexIndex();
                                    int inputNum = v.getVertexEdgeNumber();
                                    this.vertices[vIdx].setInput(inputNum, out);
                                }
                            }
                        }
                    }
                }

                // All required activations are in place: fit the target layer on its first input.
                layer.fit(gv.getInputs()[0]);
                layer.conf().setPretrain(false);
            }
        }
    }

    /**
     * Fits the network on a single {@link DataSet}. Only supported for graphs with exactly one
     * input and one output array.
     *
     * <p>If the data set carries mask arrays, the features mask and labels mask are each wrapped
     * in a single-element array (or passed as {@code null} when absent) and handed to the
     * four-argument fit overload; the layer mask arrays are cleared again afterwards. Layer
     * states are cleared in every case.
     *
     * @param dataSet the data to fit on
     * @throws UnsupportedOperationException if the network has more than one input or output array
     */
    @Override
    public void fit(DataSet dataSet) {
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSet");
        }
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            // Both masks are built the same way so that a present features mask is actually passed on.
            INDArray featuresMask = dataSet.getFeaturesMaskArray();
            INDArray labelsMask = dataSet.getLabelsMaskArray();
            INDArray[] fMask = featuresMask != null ? new INDArray[]{featuresMask} : null;
            INDArray[] lMask = labelsMask != null ? new INDArray[]{labelsMask} : null;
            this.fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()}, fMask, lMask);
        } else {
            this.fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()});
        }
        if (hasMaskArrays) {
            this.clearLayerMaskArrays();
        }
        this.clearLayersStates();
    }

    /**
     * Fits this graph on every {@link DataSet} produced by the given iterator and
     * then increments the epoch count.
     *
     * <p>Only graphs with exactly one input array and one output array can be
     * trained this way. When the iterator reports {@code asyncSupported()}, it is
     * wrapped in an {@link AsyncDataSetIterator} which is shut down at the end of
     * this call. Training itself only happens when the configuration has backprop
     * enabled; epoch start/end listeners are notified regardless.
     *
     * @param iterator source of training data
     * @throws UnsupportedOperationException if the graph does not have exactly one
     *         input array and one output array
     */
    @Override
    public void fit(DataSetIterator iterator) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSetIterator");
        }

        boolean destructable = false;
        DataSetIterator dataSetIterator;
        if (iterator.asyncSupported()) {
            // NOTE(review): Math.min caps the second constructor argument at 2, while the
            // MultiDataSetIterator overload uses Math.max for the same argument - confirm intended.
            dataSetIterator = new AsyncDataSetIterator(iterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else {
            dataSetIterator = iterator;
        }

        // Rewind an exhausted iterator when it supports being reset.
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        if (!this.trainingListeners.isEmpty()) {
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochStart(this);
            }
        }

        // Declared as MemoryWorkspace: either a DummyWorkspace (workspace mode NONE) or the
        // thread-local workspace returned by the workspace manager.
        boolean workspacesDisabled = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
        MemoryWorkspace workspace = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        MemoryWorkspace cache = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache, WORKSPACE_CACHE);

        if (this.configuration.isBackprop()) {
            this.update(TaskUtils.buildTask((DataSetIterator)dataSetIterator));
            long time1 = System.currentTimeMillis();
            while (dataSetIterator.hasNext()) {
                DataSet next = (DataSet)dataSetIterator.next();
                // Time spent waiting for this batch since the previous iteration finished.
                long time2 = System.currentTimeMillis();
                this.lastEtlTime.set(time2 - time1);

                // A batch without features or labels ends training for this call.
                if (next.getFeatures() == null || next.getLabels() == null) {
                    break;
                }

                boolean hasMaskArrays = next.hasMaskArrays();
                if (hasMaskArrays) {
                    INDArray[] fMask = next.getFeaturesMaskArray() != null ? new INDArray[]{next.getFeaturesMaskArray()} : null;
                    INDArray[] lMask = next.getLabelsMaskArray() != null ? new INDArray[]{next.getLabelsMaskArray()} : null;
                    this.setLayerMaskArrays(fMask, lMask);
                }

                if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    INDArray[] features = new INDArray[]{next.getFeatures()};
                    INDArray[] labels = new INDArray[]{next.getLabels()};
                    // As in the original: with masks present, each mask is wrapped even if it is null.
                    INDArray[] featureMasks = hasMaskArrays ? new INDArray[]{next.getFeaturesMaskArray()} : null;
                    INDArray[] labelMasks = hasMaskArrays ? new INDArray[]{next.getLabelsMaskArray()} : null;
                    this.doTruncatedBPTT(features, labels, featureMasks, labelMasks);
                } else {
                    this.setInput(0, next.getFeatures());
                    this.setLabel(0, next.getLabels());
                    if (this.solver == null) {
                        // The solver is built while scoped out of any workspace.
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
                        }
                    }
                    try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                         MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                        this.solver.optimize();
                    }
                }

                if (hasMaskArrays) {
                    this.clearLayerMaskArrays();
                }
                time1 = System.currentTimeMillis();
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }

        if (!this.trainingListeners.isEmpty()) {
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochEnd(this);
            }
        }
        this.clearLayersStates();
        if (destructable) {
            ((AsyncDataSetIterator)dataSetIterator).shutdown();
        }
        this.incrementEpochCount();
    }

    /**
     * Fits this graph on a single multi-dataset by delegating to the array-based
     * {@code fit} overload, then clears the layer mask arrays if the dataset
     * reported having masks.
     *
     * @param multiDataSet features, labels and optional masks to train on
     */
    @Override
    public void fit(org.nd4j.linalg.dataset.api.MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        INDArray[] labels = multiDataSet.getLabels();
        INDArray[] featureMasks = multiDataSet.getFeaturesMaskArrays();
        INDArray[] labelMasks = multiDataSet.getLabelsMaskArrays();
        fit(features, labels, featureMasks, labelMasks);

        if (!multiDataSet.hasMaskArrays()) {
            return;
        }
        clearLayerMaskArrays();
    }

    /**
     * Fits this graph on every multi-dataset produced by the given iterator and
     * then increments the epoch count.
     *
     * <p>When the iterator reports {@code asyncSupported()}, it is wrapped in an
     * {@link AsyncMultiDataSetIterator} which is shut down at the end of this call.
     * Training only happens when the configuration has backprop enabled.
     *
     * @param multi source of training data
     */
    @Override
    public void fit(MultiDataSetIterator multi) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }

        boolean destructable = false;
        MultiDataSetIterator multiDataSetIterator;
        if (multi.asyncSupported()) {
            multiDataSetIterator = new AsyncMultiDataSetIterator(multi, Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else {
            multiDataSetIterator = multi;
        }

        // Declared as MemoryWorkspace: either a DummyWorkspace (workspace mode NONE) or the
        // thread-local workspace returned by the workspace manager.
        boolean workspacesDisabled = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
        MemoryWorkspace workspace = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        MemoryWorkspace cache = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache, WORKSPACE_CACHE);

        if (this.configuration.isBackprop()) {
            long time1 = System.currentTimeMillis();
            while (multiDataSetIterator.hasNext()) {
                org.nd4j.linalg.dataset.api.MultiDataSet next = (org.nd4j.linalg.dataset.api.MultiDataSet)multiDataSetIterator.next();
                // Time spent waiting for this batch since the previous iteration finished.
                long time2 = System.currentTimeMillis();
                this.lastEtlTime.set(time2 - time1);

                // A batch without features or labels ends training for this call.
                if (next.getFeatures() == null || next.getLabels() == null) {
                    break;
                }

                if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    this.doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                } else {
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (hasMaskArrays) {
                        this.setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                    }
                    this.setInputs(next.getFeatures());
                    this.setLabels(next.getLabels());
                    if (this.solver == null) {
                        // The solver is built while scoped out of any workspace.
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
                        }
                    }
                    try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                         MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                        this.solver.optimize();
                    }
                    if (hasMaskArrays) {
                        this.clearLayerMaskArrays();
                    }
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
                time1 = System.currentTimeMillis();
            }
        }

        this.clearLayersStates();
        if (destructable) {
            ((AsyncMultiDataSetIterator)multiDataSetIterator).shutdown();
        }
        this.incrementEpochCount();
    }

    /**
     * Replaces every attached array held by the multi-dataset with the result of
     * calling {@code migrate()} on it. Features are processed first, then feature
     * masks, labels and label masks; null arrays and null elements are skipped.
     *
     * @param ds multi-dataset whose arrays are updated in place
     */
    protected void migrate(org.nd4j.linalg.dataset.api.MultiDataSet ds) {
        migrateAttachedArrays(ds.getFeatures());
        migrateAttachedArrays(ds.getFeaturesMaskArrays());
        migrateAttachedArrays(ds.getLabels());
        migrateAttachedArrays(ds.getLabelsMaskArrays());
    }

    /** Migrates, in place, each non-null attached element of {@code arrays}; a null array is ignored. */
    private static void migrateAttachedArrays(INDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (int idx = 0; idx < arrays.length; ++idx) {
            INDArray element = arrays[idx];
            if (element != null && element.isAttached()) {
                arrays[idx] = element.migrate();
            }
        }
    }

    /**
     * Replaces each attached array of the dataset (features, labels, features
     * mask, labels mask - in that order) with the result of calling
     * {@code migrate()} on it. Null or non-attached arrays are left untouched.
     *
     * @param ds dataset whose arrays are updated in place
     */
    protected void migrate(DataSet ds) {
        INDArray features = ds.getFeatures();
        if (features != null && features.isAttached()) {
            ds.setFeatures(features.migrate());
        }

        INDArray labels = ds.getLabels();
        if (labels != null && labels.isAttached()) {
            ds.setLabels(labels.migrate());
        }

        INDArray featuresMask = ds.getFeaturesMaskArray();
        if (featuresMask != null && featuresMask.isAttached()) {
            ds.setFeaturesMaskArray(featuresMask.migrate());
        }

        INDArray labelsMask = ds.getLabelsMaskArray();
        if (labelsMask != null && labelsMask.isAttached()) {
            ds.setLabelsMaskArray(labelsMask.migrate());
        }
    }

    /**
     * Fits this graph on the given inputs and labels without any mask arrays.
     *
     * @param inputs network input arrays
     * @param labels network label arrays
     */
    public void fit(INDArray[] inputs, INDArray[] labels) {
        final INDArray[] noFeatureMasks = null;
        final INDArray[] noLabelMasks = null;
        fit(inputs, labels, noFeatureMasks, noLabelMasks);
    }

    /**
     * Fits this graph on the given inputs and labels, with optional mask arrays.
     *
     * <p>Returns immediately when the network has no parameters. Otherwise the
     * inputs, labels and masks are set on the graph, pretraining runs if the
     * configuration enables it, and then backprop (standard or truncated BPTT)
     * runs if the configuration enables it. Layer mask arrays are cleared
     * afterwards when any mask array was supplied, and layer states are cleared.
     *
     * @param inputs            network input arrays
     * @param labels            network label arrays
     * @param featureMaskArrays mask arrays for the inputs; may be null
     * @param labelMaskArrays   mask arrays for the labels; may be null
     */
    public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        if (this.numParams() == 0) {
            return;
        }
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        this.setInputs(inputs);
        this.setLabels(labels);
        this.setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
        this.update(TaskUtils.buildTask((INDArray[])inputs, (INDArray[])labels));

        // Declared as MemoryWorkspace: either a DummyWorkspace (workspace mode NONE) or the
        // thread-local workspace returned by the workspace manager.
        boolean workspacesDisabled = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
        MemoryWorkspace workspace = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        MemoryWorkspace cache = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache, WORKSPACE_CACHE);

        if (this.configuration.isPretrain()) {
            SingletonMultiDataSetIterator iter = new SingletonMultiDataSetIterator((org.nd4j.linalg.dataset.api.MultiDataSet)new MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays));
            this.pretrain((MultiDataSetIterator)iter);
        }

        if (this.configuration.isBackprop()) {
            if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                this.doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
            } else {
                if (this.solver == null) {
                    // The solver is built while scoped out of any workspace.
                    try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                        this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                    }
                }
                try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
                     MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                    this.solver.optimize();
                }
            }
        }

        if (featureMaskArrays != null || labelMaskArrays != null) {
            this.clearLayerMaskArrays();
        }
        this.clearLayersStates();
    }

    /**
     * Computes a topological ordering of the graph's vertices.
     *
     * <p>Vertices are numbered with the network inputs first (in configuration
     * order) followed by the configured vertices (in the iteration order of the
     * configuration's vertex map). The ordering is built by repeatedly emitting a
     * vertex with no remaining incoming edges and removing its outgoing edges.
     * If a previously computed order is stored in {@code topologicalOrder} it is
     * returned as-is; this method does not itself store the result.
     *
     * @return vertex indices in topological order
     * @throws IllegalStateException if a vertex references an unknown input name,
     *         or if the graph contains a cycle
     */
    public int[] topologicalSortOrder() {
        if (this.topologicalOrder != null) {
            return this.topologicalOrder;
        }

        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = this.configuration.getVertices();
        List<String> networkInputNames = this.configuration.getNetworkInputs();
        int numVertices = networkInputNames.size() + nodeMap.size();
        int[] out = new int[numVertices];
        int outCounter = 0;

        // Assign indices: network inputs first, then the configured vertices.
        Map<Integer, String> indexToName = new HashMap<Integer, String>();
        Map<String, Integer> nameToIndex = new HashMap<String, Integer>();
        int i = 0;
        for (String inputName : networkInputNames) {
            indexToName.put(i, inputName);
            nameToIndex.put(inputName, i);
            ++i;
        }
        for (String vertexName : nodeMap.keySet()) {
            indexToName.put(i, vertexName);
            nameToIndex.put(vertexName, i);
            ++i;
        }

        // inputEdges: vertex -> vertices it still waits on (null means no inputs).
        // outputEdges: vertex -> vertices that consume its output.
        Map<Integer, Set<Integer>> inputEdges = new HashMap<Integer, Set<Integer>>();
        Map<Integer, Set<Integer>> outputEdges = new HashMap<Integer, Set<Integer>>();
        for (String inputName : networkInputNames) {
            inputEdges.put(nameToIndex.get(inputName), null);
        }
        for (String vertexName : nodeMap.keySet()) {
            int idx = nameToIndex.get(vertexName);
            List<String> inputsToThisVertex = this.configuration.getVertexInputs().get(vertexName);
            if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
                inputEdges.put(idx, null);
                continue;
            }
            Set<Integer> inputSet = new HashSet<Integer>();
            for (String s : inputsToThisVertex) {
                Integer inputIdx = nameToIndex.get(s);
                if (inputIdx == null) {
                    // Previously this surfaced later as a misleading "cycle detected" error,
                    // because the null entry could never be removed from the input set.
                    throw new IllegalStateException("Invalid configuration: vertex \"" + vertexName + "\" has input \"" + s + "\" that is neither a network input nor a vertex");
                }
                inputSet.add(inputIdx);
                Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
                if (outputSetForInputIdx == null) {
                    outputSetForInputIdx = new HashSet<Integer>();
                    outputEdges.put(inputIdx, outputSetForInputIdx);
                }
                outputSetForInputIdx.add(idx);
            }
            inputEdges.put(idx, inputSet);
        }

        // Seed the work list with every vertex that has no incoming edges.
        LinkedList<Integer> noIncomingEdges = new LinkedList<Integer>();
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> inputsFrom = entry.getValue();
            if (inputsFrom == null || inputsFrom.isEmpty()) {
                noIncomingEdges.add(entry.getKey());
            }
        }

        while (!noIncomingEdges.isEmpty()) {
            Integer next = noIncomingEdges.removeFirst();
            out[outCounter++] = next;
            Set<Integer> consumers = outputEdges.get(next);
            if (consumers == null) {
                continue;
            }
            for (Integer v : consumers) {
                Set<Integer> remainingInputs = inputEdges.get(v);
                remainingInputs.remove(next);
                if (remainingInputs.isEmpty()) {
                    noIncomingEdges.add(v);
                }
            }
        }

        // Any vertex that still has unresolved inputs must be part of a cycle.
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> remainingInputs = entry.getValue();
            if (remainingInputs == null || remainingInputs.isEmpty()) {
                continue;
            }
            throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + indexToName.get(entry.getKey()) + "\")");
        }
        return out;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * WARNING - void declaration
     */
    @Override
    public void computeGradientAndScore() {
        this.synchronizeIterEpochCounts();
        MemoryWorkspace wsExternal = null;
        boolean shouldCloseWorkspace = false;
        if (this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE && !(wsExternal = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL)).isScopeActive()) {
            wsExternal.notifyScopeEntered();
            shouldCloseWorkspace = true;
        }
        try {
            void var7_20;
            Throwable throwable;
            MemoryWorkspace workspace;
            Map<String, INDArray> activations;
            if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                activations = this.rnnActivateUsingStoredState(this.inputs, true, true);
                if (!this.trainingListeners.isEmpty()) {
                    workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        for (TrainingListener trainingListener : this.trainingListeners) {
                            trainingListener.onForwardPass((Model)this, activations);
                        }
                    }
                    catch (Throwable throwable2) {
                        throwable = throwable2;
                        throw throwable2;
                    }
                    finally {
                        if (workspace != null) {
                            if (throwable != null) {
                                try {
                                    workspace.close();
                                }
                                catch (Throwable throwable3) {
                                    throwable.addSuppressed(throwable3);
                                }
                            } else {
                                workspace.close();
                            }
                        }
                    }
                }
                this.calcBackpropGradients(true, new INDArray[0]);
            } else {
                activations = this.feedForward(true, true, false, false);
                if (!this.trainingListeners.isEmpty()) {
                    workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        for (TrainingListener trainingListener : this.trainingListeners) {
                            trainingListener.onForwardPass((Model)this, activations);
                        }
                    }
                    catch (Throwable throwable4) {
                        throwable = throwable4;
                        throw throwable4;
                    }
                    finally {
                        if (workspace != null) {
                            if (throwable != null) {
                                try {
                                    workspace.close();
                                }
                                catch (Throwable throwable5) {
                                    throwable.addSuppressed(throwable5);
                                }
                            } else {
                                workspace.close();
                            }
                        }
                    }
                }
                this.calcBackpropGradients(false, new INDArray[0]);
            }
            double l1 = this.calcL1();
            double l2 = this.calcL2();
            this.score = 0.0;
            for (String s : this.configuration.getNetworkOutputs()) {
                GraphVertex gv = this.verticesMap.get(s);
                this.score += ((IOutputLayer)gv.getLayer()).computeScore(l1, l2, true);
                l1 = 0.0;
                l2 = 0.0;
            }
            if (!this.trainingListeners.isEmpty()) {
                try (MemoryWorkspace memoryWorkspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                    for (TrainingListener tl : this.trainingListeners) {
                        tl.onBackwardPass(this);
                    }
                }
            }
            boolean bl = false;
            while (var7_20 < this.numOutputArrays) {
                this.getOutputLayer((int)var7_20).clearNoiseWeightParams();
                ++var7_20;
            }
        }
        finally {
            if (shouldCloseWorkspace) {
                wsExternal.notifyScopeLeft();
            }
        }
    }

    /**
     * Sets the single network input and runs a forward pass, passing
     * {@code layerTillIndex} through to the underlying forward-pass overload.
     *
     * @param input          the one network input
     * @param layerTillIndex layer index forwarded to the delegate
     * @param train          training-mode flag forwarded to the delegate
     * @return map of activations returned by the delegate
     * @throws UnsupportedOperationException if the graph does not expect exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, int layerTillIndex, boolean train) {
        if (this.numInputArrays == 1) {
            setInput(0, input);
            return feedForward(train, layerTillIndex);
        }
        throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
    }

    /**
     * Sets the network inputs and runs a forward pass, passing
     * {@code layerTillIndex} through to the underlying forward-pass overload.
     * Output layers are not excluded and non-layer vertex activations are not
     * included.
     *
     * @param input          network inputs
     * @param layerTillIndex layer index forwarded to the delegate
     * @param train          training-mode flag forwarded to the delegate
     * @param clearInputs    forwarded as the delegate's fourth boolean argument
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train, boolean clearInputs) {
        setInputs(input);
        final boolean excludeOutputLayers = false;
        final boolean includeNonLayerVertexActivations = false;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, clearInputs, layerTillIndex);
    }

    /**
     * Sets the network inputs and runs a forward pass, passing
     * {@code layerTillIndex} through to the four-argument overload with
     * {@code clearInputs = true}.
     *
     * <p>Previously this method dropped {@code layerTillIndex} by delegating to
     * {@code feedForward(input, train, true)}, which has no layer-index argument.
     *
     * @param input          network inputs
     * @param layerTillIndex layer index forwarded to the delegate
     * @param train          training-mode flag forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train) {
        return this.feedForward(input, layerTillIndex, train, true);
    }

    /**
     * Runs a forward pass on the currently set inputs, passing
     * {@code layerTillIndex} through to the underlying forward-pass overload.
     *
     * @param train          training-mode flag forwarded to the delegate
     * @param layerTillIndex layer index forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(boolean train, int layerTillIndex) {
        final boolean excludeOutputLayers = false;
        final boolean includeNonLayerVertexActivations = false;
        final boolean publicApi = true;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, publicApi, layerTillIndex);
    }

    /**
     * Sets the single network input and runs a forward pass.
     *
     * @param input the one network input
     * @param train training-mode flag forwarded to the delegate
     * @return map of activations returned by the delegate
     * @throws UnsupportedOperationException if the graph does not expect exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (this.numInputArrays == 1) {
            setInput(0, input);
            return feedForward(train);
        }
        throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
    }

    /**
     * Runs a forward pass on the given inputs with {@code clearInputs = true}.
     *
     * @param input network inputs
     * @param train training-mode flag forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        final boolean clearInputs = true;
        return feedForward(input, train, clearInputs);
    }

    /**
     * Sets the network inputs and runs a forward pass. When {@code clearInputs}
     * is false, {@code migrateInput()} is additionally called on every layer
     * after the pass.
     *
     * @param input       network inputs
     * @param train       training-mode flag forwarded to the delegate
     * @param clearInputs forwarded as the delegate's fourth boolean argument
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train, boolean clearInputs) {
        setInputs(input);
        Map<String, INDArray> activations = feedForward(train, false, false, clearInputs);
        if (clearInputs) {
            return activations;
        }
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            layer.migrateInput();
        }
        return activations;
    }

    /**
     * Runs a forward pass on the currently set inputs with {@code train = false}.
     *
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward() {
        final boolean train = false;
        return feedForward(train);
    }

    /**
     * Runs a forward pass on the currently set inputs without excluding output
     * layers and without including non-layer vertex activations.
     *
     * @param train training-mode flag forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(boolean train) {
        final boolean excludeOutputLayers = false;
        final boolean includeNonLayerVertexActivations = false;
        final boolean publicApi = true;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, publicApi);
    }

    /**
     * Runs a forward pass on the currently set inputs without including
     * non-layer vertex activations.
     *
     * @param train               training-mode flag forwarded to the delegate
     * @param excludeOutputLayers forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers) {
        final boolean includeNonLayerVertexActivations = false;
        final boolean publicApi = true;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, publicApi);
    }

    /**
     * Runs a forward pass on the currently set inputs with the delegate's
     * {@code publicApi} flag set to true.
     *
     * @param train                            training-mode flag forwarded to the delegate
     * @param excludeOutputLayers              forwarded to the delegate
     * @param includeNonLayerVertexActivations forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers, boolean includeNonLayerVertexActivations) {
        final boolean publicApi = true;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, publicApi);
    }

    /**
     * Runs a forward pass on the currently set inputs, passing {@code -1} as the
     * layer index to the five-argument overload.
     *
     * @param train                            training-mode flag forwarded to the delegate
     * @param excludeOutputLayers              forwarded to the delegate
     * @param includeNonLayerVertexActivations forwarded to the delegate
     * @param publicApi                        forwarded to the delegate
     * @return map of activations returned by the delegate
     */
    protected Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers, boolean includeNonLayerVertexActivations, boolean publicApi) {
        final int layerFeedForwardIdx = -1;
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, publicApi, layerFeedForwardIdx);
    }

    protected Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers, boolean includeNonLayerVertexActivations, boolean publicApi, int layerFeedForwardIdx) {
        DummyWorkspace workspace;
        HashMap<String, INDArray> layerActivations = new HashMap<String, INDArray>();
        WorkspaceMode wsm = this.configuration.getTrainingWorkspaceMode();
        switch (wsm) {
            case NONE: {
                workspace = new DummyWorkspace();
                break;
            }
            case SINGLE: {
                workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
                break;
            }
            case SEPARATE: {
                workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, WORKSPACE_FEED_FORWARD);
                break;
            }
            default: {
                throw new RuntimeException("Unknown workspace mode: " + (Object)((Object)wsm));
            }
        }
        boolean wseOpenSingle = wsm == WorkspaceMode.SINGLE && Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL);
        for (int i = 0; i < this.topologicalOrder.length; ++i) {
            GraphVertex current = this.vertices[this.topologicalOrder[i]];
            try (MemoryWorkspace ws = workspace.notifyScopeEntered();){
                MemoryWorkspace scopeTo;
                MemoryWorkspace scopeOut;
                Throwable throwable;
                MemoryWorkspace wsB;
                int vIdx;
                if (current.isInputVertex()) {
                    VertexIndices[] inputsTo = current.getOutputVertices();
                    INDArray input = this.inputs[current.getVertexIndex()].leverageOrDetach(WORKSPACE_EXTERNAL);
                    layerActivations.put(current.getVertexName(), input);
                    for (VertexIndices v : inputsTo) {
                        vIdx = v.getVertexIndex();
                        int vIdxInputNum = v.getVertexEdgeNumber();
                        if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL) && Nd4j.getMemoryManager().getCurrentWorkspace() != Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL)) {
                            wsB = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL).notifyScopeBorrowed();
                            throwable = null;
                            try {
                                this.vertices[vIdx].setInput(vIdxInputNum, input);
                            }
                            catch (Throwable throwable2) {
                                throwable = throwable2;
                                throw throwable2;
                            }
                            finally {
                                if (wsB != null) {
                                    if (throwable != null) {
                                        try {
                                            wsB.close();
                                        }
                                        catch (Throwable throwable3) {
                                            throwable.addSuppressed(throwable3);
                                        }
                                    } else {
                                        wsB.close();
                                    }
                                }
                            }
                        } else {
                            this.vertices[vIdx].setInput(vIdxInputNum, input);
                        }
                        if (publicApi || wsm == WorkspaceMode.SINGLE && !wseOpenSingle) {
                            scopeOut = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                            throwable = null;
                            try {
                                this.vertices[vIdx].migrateInput();
                                continue;
                            }
                            catch (Throwable throwable4) {
                                throwable = throwable4;
                                throw throwable4;
                            }
                            finally {
                                if (scopeOut != null) {
                                    if (throwable != null) {
                                        try {
                                            scopeOut.close();
                                        }
                                        catch (Throwable throwable5) {
                                            throwable.addSuppressed(throwable5);
                                        }
                                    } else {
                                        scopeOut.close();
                                    }
                                }
                            }
                        }
                        if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL)) {
                            scopeTo = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL).notifyScopeBorrowed();
                            throwable = null;
                            try {
                                this.vertices[vIdx].migrateInput();
                                continue;
                            }
                            catch (Throwable throwable6) {
                                throwable = throwable6;
                                throw throwable6;
                            }
                            finally {
                                if (scopeTo != null) {
                                    if (throwable != null) {
                                        try {
                                            scopeTo.close();
                                        }
                                        catch (Throwable throwable7) {
                                            throwable.addSuppressed(throwable7);
                                        }
                                    } else {
                                        scopeTo.close();
                                    }
                                }
                            }
                        }
                        scopeTo = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
                        throwable = null;
                        try {
                            this.vertices[vIdx].migrateInput();
                        }
                        catch (Throwable throwable8) {
                            throwable = throwable8;
                            throw throwable8;
                        }
                        finally {
                            if (scopeTo != null) {
                                if (throwable != null) {
                                    try {
                                        scopeTo.close();
                                    }
                                    catch (Throwable throwable9) {
                                        throwable.addSuppressed(throwable9);
                                    }
                                } else {
                                    scopeTo.close();
                                }
                            }
                        }
                    }
                } else {
                    VertexIndices[] outputsTo;
                    if (excludeOutputLayers && current.isOutputVertex() && current.hasLayer() && current.getLayer() instanceof IOutputLayer) continue;
                    INDArray out = publicApi ? current.doForward(train).detach() : (wsm == WorkspaceMode.SINGLE && !wseOpenSingle ? current.doForward(train).detach() : current.doForward(train).leverageOrDetach(WORKSPACE_EXTERNAL));
                    if (includeNonLayerVertexActivations || current.hasLayer() || current.isOutputVertex()) {
                        layerActivations.put(current.getVertexName(), out);
                    }
                    if ((outputsTo = current.getOutputVertices()) != null) {
                        for (VertexIndices v : outputsTo) {
                            vIdx = v.getVertexIndex();
                            int inputNum = v.getVertexEdgeNumber();
                            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL) && Nd4j.getMemoryManager().getCurrentWorkspace() != Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL)) {
                                wsB = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL).notifyScopeBorrowed();
                                throwable = null;
                                try {
                                    this.vertices[vIdx].setInput(inputNum, out);
                                }
                                catch (Throwable throwable10) {
                                    throwable = throwable10;
                                    throw throwable10;
                                }
                                finally {
                                    if (wsB != null) {
                                        if (throwable != null) {
                                            try {
                                                wsB.close();
                                            }
                                            catch (Throwable throwable11) {
                                                throwable.addSuppressed(throwable11);
                                            }
                                        } else {
                                            wsB.close();
                                        }
                                    }
                                }
                            } else {
                                this.vertices[vIdx].setInput(inputNum, out);
                            }
                            if (publicApi || wsm == WorkspaceMode.SINGLE && !wseOpenSingle) {
                                scopeOut = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                                throwable = null;
                                try {
                                    this.vertices[vIdx].migrateInput();
                                    continue;
                                }
                                catch (Throwable throwable12) {
                                    throwable = throwable12;
                                    throw throwable12;
                                }
                                finally {
                                    if (scopeOut != null) {
                                        if (throwable != null) {
                                            try {
                                                scopeOut.close();
                                            }
                                            catch (Throwable throwable13) {
                                                throwable.addSuppressed(throwable13);
                                            }
                                        } else {
                                            scopeOut.close();
                                        }
                                    }
                                }
                            }
                            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL)) {
                                scopeTo = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL).notifyScopeBorrowed();
                                throwable = null;
                                try {
                                    this.vertices[vIdx].migrateInput();
                                    continue;
                                }
                                catch (Throwable throwable14) {
                                    throwable = throwable14;
                                    throw throwable14;
                                }
                                finally {
                                    if (scopeTo != null) {
                                        if (throwable != null) {
                                            try {
                                                scopeTo.close();
                                            }
                                            catch (Throwable throwable15) {
                                                throwable.addSuppressed(throwable15);
                                            }
                                        } else {
                                            scopeTo.close();
                                        }
                                    }
                                }
                            }
                            scopeTo = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
                            throwable = null;
                            try {
                                this.vertices[vIdx].migrateInput();
                            }
                            catch (Throwable throwable16) {
                                throwable = throwable16;
                                throw throwable16;
                            }
                            finally {
                                if (scopeTo != null) {
                                    if (throwable != null) {
                                        try {
                                            scopeTo.close();
                                        }
                                        catch (Throwable throwable17) {
                                            throwable.addSuppressed(throwable17);
                                        }
                                    } else {
                                        scopeTo.close();
                                    }
                                }
                            }
                        }
                    }
                }
            }
            if (layerFeedForwardIdx > 0 && current.getVertexIndex() == layerFeedForwardIdx) break;
        }
        if (!train && wsm == WorkspaceMode.SEPARATE) {
            Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_FEED_FORWARD).initializeWorkspace();
        }
        if (publicApi) {
            this.clearLayersStates();
        }
        return layerActivations;
    }

    /**
     * Computes the network outputs for the given inputs with {@code train == false}, using the
     * input and label mask arrays currently stored on this graph.
     *
     * @param input the network input arrays
     * @return whatever the four-argument {@code output} overload returns for these inputs
     */
    public INDArray[] output(INDArray ... input) {
        final boolean train = false;
        return output(train, input, inputMaskArrays, labelMaskArrays);
    }

    /**
     * Convenience form of {@code outputSingle(boolean, INDArray...)} with {@code train == false}.
     *
     * @param input the network input arrays
     * @return the single output array of the network
     */
    public INDArray outputSingle(INDArray ... input) {
        final boolean train = false;
        return outputSingle(train, input);
    }

    /**
     * Computes the network outputs for the given inputs, using the input and label mask arrays
     * currently stored on this graph.
     *
     * @param train forwarded unchanged to the four-argument {@code output} overload
     * @param input the network input arrays
     * @return whatever the four-argument {@code output} overload returns for these inputs
     */
    public INDArray[] output(boolean train, INDArray ... input) {
        INDArray[] currentInputMasks = inputMaskArrays;
        INDArray[] currentLabelMasks = labelMaskArrays;
        return output(train, input, currentInputMasks, currentLabelMasks);
    }

    /**
     * Computes the network outputs for the given inputs and input masks, without label masks.
     *
     * @param train forwarded unchanged to the four-argument {@code output} overload
     * @param input the network input arrays; must not be null
     * @param inputMasks the input mask arrays, forwarded unchanged
     * @return whatever the four-argument {@code output} overload returns for these inputs
     * @throws NullPointerException if {@code input} is null
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks) {
        if (input == null) {
            throw new NullPointerException("input");
        }
        // Typed local so the call still selects the (boolean, INDArray[], INDArray[], INDArray[]) overload.
        INDArray[] noLabelMasks = null;
        return output(train, input, inputMasks, noLabelMasks);
    }

    /*
     * Loose catch block
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks) {
        if (input == null) {
            throw new NullPointerException("input");
        }
        this.setLayerMaskArrays(inputMasks, labelMasks);
        WorkspaceMode cMode = this.configuration.getTrainingWorkspaceMode();
        this.configuration.setTrainingWorkspaceMode(this.configuration.getInferenceWorkspaceMode());
        DummyWorkspace workspace = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        try {
            try (MemoryWorkspace wsE = workspace.notifyScopeEntered();){
                INDArray[] tmp = this.silentOutput(train, input);
                for (int x = 0; x < tmp.length; ++x) {
                    tmp[x] = tmp[x].detach();
                }
                this.configuration.setTrainingWorkspaceMode(cMode);
                INDArray[] iNDArrayArray = tmp;
                return iNDArrayArray;
            }
            {
                catch (Throwable throwable) {
                    throw throwable;
                }
            }
        }
        finally {
            this.clearLayersStates();
            this.clearLayerMaskArrays();
        }
    }

    /**
     * Sets the given inputs, runs a forward pass and collects the activations of the configured
     * network outputs, in the order the configuration lists them. No workspace handling,
     * detaching or state clearing is done here.
     *
     * @param train forwarded to {@code feedForward}
     * @param input the network input arrays
     * @return one activation array per configured network output
     */
    protected INDArray[] silentOutput(boolean train, INDArray ... input) {
        setInputs(input);
        Map<String, INDArray> activations = feedForward(train, false, false, false);
        INDArray[] collected = new INDArray[numOutputArrays];
        int slot = 0;
        for (String outputName : configuration.getNetworkOutputs()) {
            collected[slot] = activations.get(outputName);
            slot++;
        }
        return collected;
    }

    /**
     * Convenience form of {@code outputSingle(boolean, boolean, INDArray...)} with
     * {@code clearInputs == true}.
     *
     * @param train forwarded unchanged
     * @param input the network input arrays
     * @return the single output array of the network
     */
    public INDArray outputSingle(boolean train, INDArray ... input) {
        final boolean clearInputs = true;
        return outputSingle(train, clearInputs, input);
    }

    /**
     * Computes the output of a graph that has exactly one output array.
     *
     * @param train forwarded to {@code output(boolean, boolean, INDArray...)}
     * @param clearInputs forwarded to {@code output(boolean, boolean, INDArray...)}
     * @param input the network input arrays
     * @return the first (and only) output array
     * @throws IllegalStateException if the graph does not have exactly one output array
     */
    public INDArray outputSingle(boolean train, boolean clearInputs, INDArray ... input) {
        if (numOutputArrays == 1) {
            INDArray[] allOutputs = output(train, clearInputs, input);
            return allOutputs[0];
        }
        throw new IllegalStateException("Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: " + numOutputArrays);
    }

    /**
     * Computes the network outputs for the given inputs.
     *
     * <p>When {@code clearInputs} is true, the forward pass runs with the configuration's inference
     * workspace mode temporarily installed as the training workspace mode; the outputs are detached,
     * the previous mode is restored and layer states are cleared, also when the forward pass throws
     * (matching the four-argument {@code output} overload).
     *
     * <p>When {@code clearInputs} is false, the outputs are detached and {@code migrateInput()} is
     * called on every layer and every vertex; layer states are left in place.
     *
     * @param train forwarded to the forward pass
     * @param clearInputs whether to clear layer states after computing the outputs
     * @param input the network input arrays
     * @return the detached output arrays, one per configured network output
     */
    public INDArray[] output(boolean train, boolean clearInputs, INDArray ... input) {
        this.setInputs(input);
        if (clearInputs) {
            WorkspaceMode cMode = this.configuration.getTrainingWorkspaceMode();
            try {
                this.configuration.setTrainingWorkspaceMode(this.configuration.getInferenceWorkspaceMode());
                // Declared as MemoryWorkspace: only one branch yields a DummyWorkspace.
                MemoryWorkspace workspace = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
                try (MemoryWorkspace wsE = workspace.notifyScopeEntered()) {
                    INDArray[] tmp = this.silentOutput(train, input);
                    for (int x = 0; x < tmp.length; ++x) {
                        tmp[x] = tmp[x].detach();
                    }
                    return tmp;
                }
            }
            finally {
                // Restore the caller's mode even on failure so an exception cannot leave the
                // configuration stuck in inference mode.
                this.configuration.setTrainingWorkspaceMode(cMode);
                this.clearLayersStates();
            }
        }
        Map<String, INDArray> activations = this.feedForward(train, false, false, false);
        INDArray[] outputs = new INDArray[this.numOutputArrays];
        int i = 0;
        for (String s : this.configuration.getNetworkOutputs()) {
            outputs[i++] = activations.get(s).detach();
        }
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            layer.migrateInput();
        }
        // Iterate as GraphVertex: java.io.Serializable has no migrateInput().
        for (GraphVertex vertex : this.vertices) {
            vertex.migrateInput();
        }
        return outputs;
    }

    /**
     * Runs backprop using externally supplied epsilons and returns the resulting gradient.
     * Truncated BPTT is used when the configuration's backprop type is {@code TruncatedBPTT}.
     *
     * @param epsilons one epsilon array per network output
     * @return the gradient stored on this graph after the backward pass
     * @throws IllegalArgumentException if {@code epsilons} is null or its length differs from the
     *         number of output arrays
     */
    public Gradient backpropGradient(INDArray ... epsilons) {
        boolean onePerOutput = epsilons != null && epsilons.length == numOutputArrays;
        if (!onePerOutput) {
            throw new IllegalArgumentException("Invalid input: must have epsilons length equal to number of output arrays");
        }
        boolean useTbptt = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        calcBackpropGradients(useTbptt, epsilons);
        return gradient;
    }

    /**
     * Runs the backward pass over the graph and stores the result in {@code this.gradient}.
     *
     * <p>Vertices are visited in reverse topological order, each inside a scope of the workspace
     * selected by the training workspace mode. Input vertices are skipped and the pass stops at the
     * first vertex whose layer is a {@code FrozenLayer}. The epsilons returned by a vertex's
     * {@code doBackward} are handed to its input vertices; when an input vertex has already received
     * an epsilon, the new one is added to it.
     *
     * <p>Per-variable gradients are collected under the key {@code vertexName + "_" + variableName}.
     * After the loop every vertex is cleared and, when both {@code truncatedBPTT} and
     * {@code clearTbpttState} are set, the stored RNN state is cleared.
     *
     * @param truncatedBPTT forwarded to each vertex's {@code doBackward}
     * @param externalEpsilons epsilons for output vertices whose layer is not an
     *        {@code IOutputLayer}, indexed by network output number
     * @throws DL4JException if an output vertex is not an {@code IOutputLayer}, no external epsilons
     *         were supplied, and labels are set for that output
     */
    protected void calcBackpropGradients(boolean truncatedBPTT, INDArray ... externalEpsilons) {
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        WorkspaceMode wsm = this.configuration.getTrainingWorkspaceMode();
        // Declared as MemoryWorkspace: only the NONE branch yields a DummyWorkspace, the others
        // are whatever workspace the manager hands out for the current thread.
        MemoryWorkspace workspace;
        switch (wsm) {
            case NONE: {
                workspace = new DummyWorkspace();
                break;
            }
            case SINGLE: {
                workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
                break;
            }
            case SEPARATE: {
                workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, WORKSPACE_FEED_FORWARD);
                break;
            }
            default: {
                throw new RuntimeException("Unsupported workspace mode: " + wsm);
            }
        }
        LinkedList<Triple> gradients = new LinkedList<Triple>();
        boolean wsExternalActive = false;
        if (wsm == WorkspaceMode.SINGLE) {
            wsExternalActive = Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(WORKSPACE_EXTERNAL);
        }
        // setVertexEpsilon[v] is true once vertex v has had an epsilon assigned in this pass.
        boolean[] setVertexEpsilon = new boolean[this.topologicalOrder.length];
        for (int i = this.topologicalOrder.length - 1; i >= 0; --i) {
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                GraphVertex current = this.vertices[this.topologicalOrder[i]];
                if (current.isInputVertex()) {
                    continue;
                }
                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer) {
                    break;
                }
                if (current.isOutputVertex()) {
                    int thisOutputNumber = this.configuration.getNetworkOutputs().indexOf(current.getVertexName());
                    if (current.getLayer() instanceof IOutputLayer) {
                        IOutputLayer outputLayer = (IOutputLayer)current.getLayer();
                        INDArray currLabels = this.labels[thisOutputNumber];
                        outputLayer.setLabels(currLabels);
                    } else {
                        if ((externalEpsilons == null || externalEpsilons.length == 0) && this.labels[thisOutputNumber] != null) {
                            throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type " + current.getLayer().getClass().getSimpleName() + " is set as network output (but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with a labels array. ");
                        }
                        // NOTE(review): when no external epsilons were supplied and no labels are set for
                        // this output, the indexing below fails with a NullPointerException or
                        // ArrayIndexOutOfBoundsException rather than a descriptive error.
                        current.setEpsilon(externalEpsilons[thisOutputNumber]);
                        setVertexEpsilon[this.topologicalOrder[i]] = true;
                    }
                }
                Pair<Gradient, INDArray[]> pair = current.doBackward(truncatedBPTT);
                INDArray[] epsilons = (INDArray[])pair.getSecond();
                // Move the epsilons out of the per-vertex workspace scope before handing them on.
                for (int x = 0; x < epsilons.length; ++x) {
                    if (epsilons[x] == null) {
                        continue;
                    }
                    if (wsm == WorkspaceMode.SEPARATE) {
                        epsilons[x] = epsilons[x].leverageOrDetach(WORKSPACE_EXTERNAL);
                        continue;
                    }
                    if (wsm != WorkspaceMode.SINGLE) {
                        continue;
                    }
                    epsilons[x] = wsExternalActive ? epsilons[x].leverageTo(WORKSPACE_EXTERNAL) : epsilons[x].detach();
                }
                VertexIndices[] inputVertices = current.getInputVertices();
                if (inputVertices != null) {
                    int j = 0;
                    for (VertexIndices v : inputVertices) {
                        GraphVertex gv = this.vertices[v.getVertexIndex()];
                        if (setVertexEpsilon[gv.getVertexIndex()]) {
                            // This input vertex already holds an epsilon: accumulate instead of overwrite.
                            INDArray currentEps = gv.getEpsilon().leverageOrDetach(WORKSPACE_EXTERNAL);
                            if (wsm == WorkspaceMode.NONE) {
                                gv.setEpsilon(currentEps.add(epsilons[j++]));
                            } else if (wsm == WorkspaceMode.SINGLE && !wsExternalActive) {
                                try (MemoryWorkspace wsOut = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                                    gv.setEpsilon(currentEps.add(epsilons[j++]));
                                }
                            } else {
                                try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_EXTERNAL).notifyScopeBorrowed()) {
                                    gv.setEpsilon(currentEps.add(epsilons[j++]));
                                }
                            }
                        } else {
                            gv.setEpsilon(epsilons[j++]);
                        }
                        setVertexEpsilon[gv.getVertexIndex()] = true;
                    }
                }
                if (pair.getFirst() == null) {
                    continue;
                }
                Gradient g = (Gradient)pair.getFirst();
                Map<String, INDArray> map = g.gradientForVariable();
                // Two addFirst passes: the first reverses this vertex's entries, the second puts them
                // back in map order at the head of the overall list.
                LinkedList<Triple> tempList = new LinkedList<Triple>();
                for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                    String origName = entry.getKey();
                    String newName = current.getVertexName() + "_" + origName;
                    tempList.addFirst(new Triple((Object)newName, (Object)entry.getValue(), (Object)g.flatteningOrderForVariable(origName)));
                }
                for (Triple t : tempList) {
                    gradients.addFirst(t);
                }
            }
        }
        DefaultGradient gradient = new DefaultGradient(this.flattenedGradients);
        for (Triple t : gradients) {
            gradient.setGradientFor((String)t.getFirst(), (INDArray)t.getSecond(), (Character)t.getThird());
        }
        if (this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.SEPARATE) {
            Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WORKSPACE_FEED_FORWARD).initializeWorkspace();
        }
        this.gradient = gradient;
        if (truncatedBPTT && this.clearTbpttState) {
            this.rnnClearPreviousState();
        }
        for (GraphVertex gv : this.vertices) {
            gv.clear();
        }
    }

    /**
     * Creates a new graph from a clone of this graph's configuration, initialised with a copy of
     * this graph's parameters. When a solver exists and the updater has a state view array, a copy
     * of that state is installed on the new graph's updater. Vertices whose layer is a
     * {@code FrozenLayer} here are marked frozen on the copy as well.
     *
     * <p>Note that the listener collection is assigned by reference, not copied, so both graphs
     * refer to the same collection object.
     *
     * @return the new graph
     */
    public ComputationGraph clone() {
        ComputationGraph copy = new ComputationGraph(this.configuration.clone());
        copy.init(params().dup(), false);
        if (this.solver != null) {
            INDArray updaterState = getUpdater().getStateViewArray();
            if (updaterState != null) {
                copy.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        copy.listeners = this.listeners;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            String layerName = vertex.getVertexName();
            if (getLayer(layerName) instanceof FrozenLayer) {
                copy.getVertex(layerName).setLayerAsFrozen();
            }
        }
        return copy;
    }

    /**
     * Sums {@code calcL2(true)} over all layers of the graph, in array order.
     *
     * @return the summed L2 term
     */
    public double calcL2() {
        org.deeplearning4j.nn.api.Layer[] allLayers = layers;
        double total = 0.0;
        for (int k = 0; k < allLayers.length; k++) {
            total += allLayers[k].calcL2(true);
        }
        return total;
    }

    /**
     * Sums {@code calcL1(true)} over all layers of the graph, in array order.
     *
     * @return the summed L1 term
     */
    public double calcL1() {
        org.deeplearning4j.nn.api.Layer[] allLayers = layers;
        double total = 0.0;
        for (int k = 0; k < allLayers.length; k++) {
            total += allLayers[k].calcL1(true);
        }
        return total;
    }

    /**
     * Replaces the listener collection of this graph and propagates it to every layer and, when
     * present, the solver. The graph is initialised first if its layers have not been created yet.
     * The training-listener list is rebuilt from those elements that are {@code TrainingListener}s.
     *
     * @param listeners the new listeners; may be null, in which case the training-listener list is
     *        simply emptied
     */
    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (layers == null) {
            init();
        }
        org.deeplearning4j.nn.api.Layer[] allLayers = layers;
        for (int k = 0; k < allLayers.length; k++) {
            allLayers[k].setListeners(listeners);
        }
        if (solver != null) {
            solver.setListeners(listeners);
        }
        trainingListeners.clear();
        if (listeners == null) {
            return;
        }
        for (IterationListener candidate : listeners) {
            if (candidate instanceof TrainingListener) {
                trainingListeners.add((TrainingListener)candidate);
            }
        }
    }

    /**
     * Varargs form of {@code setListeners(Collection)}. Null elements are dropped; a null or empty
     * array results in an empty listener list.
     *
     * @param listeners the new listeners
     */
    @Override
    public void setListeners(IterationListener ... listeners) {
        Collection<IterationListener> nonNull = new ArrayList<IterationListener>();
        if (listeners != null) {
            for (IterationListener candidate : listeners) {
                if (candidate != null) {
                    nonNull.add(candidate);
                }
            }
        }
        setListeners(nonNull);
    }

    /**
     * Adds listeners to the existing listener collection. When no collection exists yet this is
     * equivalent to {@code setListeners(listeners)}. Otherwise each non-null listener is appended,
     * {@code TrainingListener}s are additionally registered in the training-listener list, and the
     * solver (when present) is handed the updated collection.
     *
     * <p>A null array or null elements are ignored, matching the varargs {@code setListeners}.
     *
     * @param listeners the listeners to add
     */
    @Override
    public void addListeners(IterationListener ... listeners) {
        if (this.listeners == null) {
            this.setListeners(listeners);
            return;
        }
        if (listeners == null) {
            // Nothing to add; previously this dereferenced the null array.
            return;
        }
        for (IterationListener listener : listeners) {
            if (listener == null) {
                continue;
            }
            this.listeners.add(listener);
            if (listener instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener)listener);
            }
        }
        if (this.solver != null) {
            this.solver.setListeners(this.listeners);
        }
    }

    /**
     * Returns the listener collection held by this graph (the same object, not a copy).
     *
     * @return the current listeners, or null if none have been set
     */
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * Returns the updater held by the solver's optimizer. When no solver exists yet, one is built
     * from {@code conf()}, the current listeners and this graph, and a new
     * {@code ComputationGraphUpdater} for this graph is installed on its optimizer.
     *
     * @return the optimizer's computation graph updater
     */
    public ComputationGraphUpdater getUpdater() {
        if (solver == null) {
            solver = new Solver.Builder()
                    .configure(conf())
                    .listeners(getListeners())
                    .model(this)
                    .build();
            solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        return solver.getOptimizer().getComputationGraphUpdater();
    }

    /**
     * Installs the given updater on the solver's optimizer, building the solver first (from
     * {@code conf()}, the current listeners and this graph) when none exists yet.
     *
     * @param updater the updater to install
     */
    public void setUpdater(ComputationGraphUpdater updater) {
        if (solver == null) {
            solver = new Solver.Builder()
                    .configure(conf())
                    .listeners(getListeners())
                    .model(this)
                    .build();
        }
        solver.getOptimizer().setUpdaterComputationGraph(updater);
    }

    /**
     * Returns the layer behind the network output with the given index, looked up by the name the
     * configuration lists at that position.
     *
     * @param outputLayerIdx zero-based index into the configured network outputs
     * @return the layer registered under that output's name
     * @throws IllegalArgumentException if the index is negative or not smaller than the number of
     *         network outputs
     */
    public org.deeplearning4j.nn.api.Layer getOutputLayer(int outputLayerIdx) {
        // Negative indices are rejected here too, rather than surfacing as an index error from the list.
        if (outputLayerIdx < 0 || outputLayerIdx >= this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx + ", total number of network outputs = " + this.numOutputArrays);
        }
        return this.getLayer(this.configuration.getNetworkOutputs().get(outputLayerIdx));
    }

    /**
     * Returns the network parameters.
     *
     * @param backwardOnly when true the stored flattened parameter array is returned as is; when
     *        false a new array is assembled by flattening (order {@code 'f'}) the non-null
     *        {@code params()} of every layer vertex in topological order
     * @return the parameter array
     */
    public INDArray params(boolean backwardOnly) {
        if (backwardOnly) {
            return flattenedParams;
        }
        List<INDArray> perLayer = new ArrayList<INDArray>(layers.length);
        for (int vertexIdx : topologicalOrder) {
            GraphVertex vertex = vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            INDArray layerParams = vertex.getLayer().params();
            if (layerParams != null) {
                perLayer.add(layerParams);
            }
        }
        return Nd4j.toFlattened('f', perLayer);
    }

    /**
     * Scores the given data set with {@code training == false}.
     *
     * @param dataSet the data to score
     * @return the score
     */
    public double score(DataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Scores a single-input, single-output data set by converting it to a multi data set and
     * delegating to the {@code MultiDataSet} overload.
     *
     * @param dataSet the data to score
     * @param training forwarded unchanged
     * @return the score
     * @throws UnsupportedOperationException if the graph does not have exactly one input array and
     *         one output array
     */
    public double score(DataSet dataSet, boolean training) {
        boolean singleInSingleOut = numInputArrays == 1 && numOutputArrays == 1;
        if (!singleInSingleOut) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        org.nd4j.linalg.dataset.api.MultiDataSet converted = ComputationGraphUtil.toMultiDataSet(dataSet);
        return score(converted, training);
    }

    /**
     * Scores the given multi data set with {@code training == false}.
     *
     * @param dataSet the data to score
     * @return the score
     */
    public double score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Computes the score of the network on the given data.
     *
     * <p>The features are fed forward inside a workspace scope, the labels are set, and the scores
     * of all configured network outputs are summed. The L1 and L2 terms are computed once and passed
     * only to the first output's {@code computeScore}; later outputs receive zero for both.
     *
     * <p>If any configured output is not an {@code IOutputLayer}, a warning is logged and 0.0 is
     * returned. Mask arrays taken from the data set are cleared again and layer states are cleared
     * on every exit path, consistent with {@code scoreExamples}.
     *
     * @param dataSet the features, labels and optional mask arrays to score
     * @param training forwarded to the forward pass and to {@code computeScore}
     * @return the summed score, or 0.0 when a network output is not an output layer
     */
    public double score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, boolean training) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            this.setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }
        double score = 0.0;
        try {
            // Declared as MemoryWorkspace: only one branch yields a DummyWorkspace.
            MemoryWorkspace workspace = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                    ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                this.setInputs(dataSet.getFeatures());
                this.feedForward(training, false, false, false);
                INDArray[] labels = dataSet.getLabels();
                this.setLabels(labels);
                double l1 = this.calcL1();
                double l2 = this.calcL2();
                int i = 0;
                for (String s : this.configuration.getNetworkOutputs()) {
                    GraphVertex gv = this.verticesMap.get(s);
                    org.deeplearning4j.nn.api.Layer outLayer = gv.getLayer();
                    // instanceof is false for null, so no separate null check is needed.
                    if (!(outLayer instanceof IOutputLayer)) {
                        log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                        return 0.0;
                    }
                    IOutputLayer ol = (IOutputLayer)outLayer;
                    ol.setLabels(labels[i++]);
                    score += ((LayerVertex)gv).computeScore(l1, l2, training);
                    // Regularization terms are counted once, with the first output only.
                    l1 = 0.0;
                    l2 = 0.0;
                }
            }
            return score;
        }
        finally {
            // Runs after the workspace scope has closed. Clearing the masks keeps this data set's
            // masks from leaking into later calls that read the stored mask arrays.
            if (hasMaskArrays) {
                this.clearLayerMaskArrays();
            }
            this.clearLayersStates();
        }
    }

    /**
     * Per-example scoring for a single-input, single-output data set; converts the data to a multi
     * data set and delegates to the {@code MultiDataSet} overload.
     *
     * @param data the data to score
     * @param addRegularizationTerms forwarded unchanged
     * @return the per-example scores
     * @throws UnsupportedOperationException if the graph does not have exactly one input array and
     *         one output array
     */
    public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
        boolean singleInSingleOut = numInputArrays == 1 && numOutputArrays == 1;
        if (!singleInSingleOut) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        org.nd4j.linalg.dataset.api.MultiDataSet converted = ComputationGraphUtil.toMultiDataSet(data);
        return scoreExamples(converted, addRegularizationTerms);
    }

    /**
     * Computes per-example scores on the given data.
     *
     * <p>The features are fed forward with {@code train == false}, the labels are set, and the
     * per-example scores of all configured network outputs are accumulated in place into the array
     * returned by the first output. The L1 and L2 terms (zero unless
     * {@code addRegularizationTerms} is set) are passed only to the first output.
     *
     * <p>Mask arrays taken from the data set and layer states are cleared on every exit path,
     * including when an exception is thrown.
     *
     * @param data the features, labels and optional mask arrays to score
     * @param addRegularizationTerms whether to include the L1 and L2 terms
     * @return the accumulated per-example scores, or null when no network outputs are configured
     * @throws UnsupportedOperationException if a configured output is not an {@code IOutputLayer}
     */
    public INDArray scoreExamples(org.nd4j.linalg.dataset.api.MultiDataSet data, boolean addRegularizationTerms) {
        boolean hasMaskArray = data.hasMaskArrays();
        if (hasMaskArray) {
            this.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        }
        try {
            this.setInputs(data.getFeatures());
            this.feedForward(false, true, false, false);
            this.setLabels(data.getLabels());
            INDArray out = null;
            double l1 = addRegularizationTerms ? this.calcL1() : 0.0;
            double l2 = addRegularizationTerms ? this.calcL2() : 0.0;
            int i = 0;
            for (String s : this.configuration.getNetworkOutputs()) {
                GraphVertex gv = this.verticesMap.get(s);
                org.deeplearning4j.nn.api.Layer outLayer = gv.getLayer();
                // instanceof is false for null, so no separate null check is needed.
                if (!(outLayer instanceof IOutputLayer)) {
                    throw new UnsupportedOperationException("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                }
                IOutputLayer ol = (IOutputLayer)outLayer;
                ol.setLabels(this.labels[i++]);
                INDArray scoreCurrLayer = ((LayerVertex)gv).computeScoreForExamples(l1, l2);
                if (out == null) {
                    out = scoreCurrLayer;
                } else {
                    out.addi(scoreCurrLayer);
                }
                // Regularization terms are counted once, with the first output only.
                l1 = 0.0;
                l2 = 0.0;
            }
            return out;
        }
        finally {
            // Cleanup previously ran only on success, leaving masks and activations behind on failure.
            if (hasMaskArray) {
                this.clearLayerMaskArrays();
            }
            this.clearLayersStates();
        }
    }

    /**
     * Fits the network on the inputs, labels and mask arrays currently stored on this graph.
     */
    @Override
    public void fit() {
        fit(inputs, labels, inputMaskArrays, labelMaskArrays);
    }

    /**
     * Not supported on this class.
     *
     * @param gradient unused
     * @param paramType unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Applies the given gradient to the network.
     *
     * <p>Each key of the gradient's variable map is expected to have the form
     * {@code layerName + "_" + paramName}. The key is split at its first underscore: the part before
     * it names the layer, the whole remainder is the parameter type. Every entry is recorded in this
     * graph's own gradient map and passed to the named layer's {@code update}; finally the flattened
     * gradient array is installed as the backprop gradients view.
     *
     * @param gradient the gradient to apply
     * @throws IllegalArgumentException if the flattened gradient's length differs from
     *         {@code numParams(true)}
     * @throws IllegalStateException if a key contains no underscore
     */
    @Override
    public void update(Gradient gradient) {
        int expectedLength = this.numParams(true);
        if (gradient.gradient().length() != expectedLength) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + expectedLength);
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray val = entry.getValue();
            int idx = key.indexOf('_');
            if (idx == -1) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, idx);
            // Take everything after the separator: split("_")[1] dropped any part of the parameter
            // name following a second underscore and threw for a key ending in "_".
            String paramType = key.substring(idx + 1);
            this.gradient.gradientForVariable().put(key, val);
            this.getLayer(layerName).update(val, paramType);
        }
        this.setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * Reports a one-time {@code Event.STANDALONE} heartbeat event for this model. Subsequent calls
     * do nothing once {@code initDone} has been set.
     *
     * @param task ignored; the reported task is always derived from this model via
     *        {@code ModelSerializer.taskByModel}
     */
    private void update(Task task) {
        if (initDone) {
            return;
        }
        initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        Task modelTask = ModelSerializer.taskByModel(this);
        Environment env = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, env, modelTask);
    }

    /**
     * Returns the score value currently stored on this graph.
     *
     * @return the stored score
     */
    @Override
    public double score() {
        return score;
    }

    /**
     * Overwrites the score value stored on this graph.
     *
     * @param score the new score
     */
    public void setScore(double score) {
        this.score = score;
    }

    /**
     * Not supported on this class.
     *
     * @param accum unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void accumulateScore(double accum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the network parameters; equivalent to {@code params(true)}.
     *
     * @return the parameter array
     */
    @Override
    public INDArray params() {
        final boolean backwardOnly = true;
        return params(backwardOnly);
    }

    /**
     * Returns the updater's state view array.
     *
     * @return the state view array of {@code getUpdater()}, or null when there is no updater
     */
    @Override
    public INDArray updaterState() {
        ComputationGraphUpdater updater = getUpdater();
        if (updater == null) {
            return null;
        }
        return updater.getUpdaterStateViewArray();
    }

    /**
     * Returns the total number of parameters; equivalent to {@code numParams(true)}.
     *
     * @return the parameter count
     */
    @Override
    public int numParams() {
        final boolean backwards = true;
        return numParams(backwards);
    }

    /**
     * Sums {@code numParams(backwards)} over all layers of the graph.
     *
     * @param backwards forwarded unchanged to every layer
     * @return the total parameter count
     */
    @Override
    public int numParams(boolean backwards) {
        org.deeplearning4j.nn.api.Layer[] allLayers = layers;
        int total = 0;
        for (int k = 0; k < allLayers.length; k++) {
            total += allLayers[k].numParams(backwards);
        }
        return total;
    }

    /**
     * Sets the network parameters from a flattened array.
     *
     * <p>Nothing happens when the argument is the stored flattened array itself. When a stored
     * array of the same length exists, the values are copied into it. Otherwise the argument is
     * sliced (row 0, consecutive column intervals) and handed to each layer vertex with a positive
     * parameter count, in topological order.
     *
     * @param params the flattened parameters
     */
    @Override
    public void setParams(INDArray params) {
        if (params == flattenedParams) {
            return;
        }
        if (flattenedParams != null && flattenedParams.length() == params.length()) {
            flattenedParams.assign(params);
            return;
        }
        int offset = 0;
        for (int vertexIdx : topologicalOrder) {
            GraphVertex vertex = vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
            int range = layer.numParams();
            if (range <= 0) {
                continue;
            }
            INDArrayIndex[] slice = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, range + offset)};
            layer.setParams(params.get(slice));
            offset += range;
        }
    }

    /**
     * Not supported by this class.
     *
     * @throws RuntimeException always
     */
    @Override
    public void setParamsViewArray(INDArray gradient) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * @return the flattened gradients array held by this network (may be {@code null} if it has
     *         not been set up yet)
     */
    @Override
    public INDArray getGradientsViewArray() {
        return flattenedGradients;
    }

    /**
     * Hands each parameterised layer its slice of a flattened gradient array, walking the
     * vertices in topological order.
     *
     * @param gradient flattened gradient array covering all layers
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradient) {
        int offset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
            int count = layer.numParams();
            if (count <= 0) {
                continue;
            }
            INDArray slice = gradient.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)offset, (int)(offset + count))});
            layer.setBackpropGradientsViewArray(slice);
            offset += count;
        }
    }

    /**
     * Not supported: this overload always fails.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * @return the gradient object currently stored on this network
     */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * @return a pair of {@link #gradient()} and {@link #score()}
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = this.gradient();
        Double currentScore = this.score();
        return new Pair<Gradient, Double>(currentGradient, currentScore);
    }

    /**
     * @return size of dimension 0 of the first input array
     */
    @Override
    public int batchSize() {
        INDArray firstInput = this.inputs[0];
        return firstInput.size(0);
    }

    /**
     * @return the network's default {@link NeuralNetConfiguration}
     */
    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the single network input.
     *
     * @return the first input array, or {@code null} if no inputs are currently set
     * @throws UnsupportedOperationException if the graph does not have exactly one input array
     */
    @Override
    public INDArray input() {
        if (this.numInputArrays != 1) {
            throw new UnsupportedOperationException("Cannot return single input: ComputationGraph  has multiple inputs");
        }
        if (this.inputs == null) {
            return null;
        }
        return this.inputs[0];
    }

    /**
     * Intentionally a no-op: this class performs no input validation here.
     */
    @Override
    public void validateInput() {
        // nothing to validate
    }

    /**
     * @return the optimizer owned by this network's solver
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        return solver.getOptimizer();
    }

    /**
     * Looks up one parameter array by its combined key.
     * <p>
     * The key is split at the first underscore: the text before it is the layer name and the text
     * after it is the parameter type passed to that layer.
     *
     * @param paramName key of the form {@code layerName_paramType}
     * @return the parameter array reported by the layer
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override
    public INDArray getParam(String paramName) {
        int separator = paramName.indexOf('_');
        if (separator == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + paramName + "\"");
        }
        String layerName = paramName.substring(0, separator);
        String paramType = paramName.substring(separator + 1);
        org.deeplearning4j.nn.api.Layer layer = this.getLayer(layerName);
        return layer.getParam(paramType);
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Shorthand for {@code paramTable(false)}.
     */
    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /**
     * Collects the parameter tables of all layers into one insertion-ordered map. Each key is the
     * owning layer's name, an underscore, and the layer-local parameter key.
     *
     * @param backpropParamsOnly forwarded unchanged to each layer's {@code paramTable}
     * @return combined parameter map
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> combined = new LinkedHashMap<String, INDArray>();
        for (int i = 0; i < this.layers.length; ++i) {
            org.deeplearning4j.nn.api.Layer layer = this.layers[i];
            Map<String, INDArray> layerParams = layer.paramTable(backpropParamsOnly);
            for (Map.Entry<String, INDArray> param : layerParams.entrySet()) {
                String qualifiedKey = layer.conf().getLayer().getLayerName() + "_" + param.getKey();
                combined.put(qualifiedKey, param.getValue());
            }
        }
        return combined;
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Sets one parameter array by its combined key.
     * <p>
     * The key is split at the first underscore: the text before it is the layer name and the text
     * after it is the parameter type passed to that layer.
     *
     * @param key key of the form {@code layerName_paramType}
     * @param val new parameter values
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override
    public void setParam(String key, INDArray val) {
        int separator = key.indexOf('_');
        if (separator == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        }
        String layerName = key.substring(0, separator);
        String paramType = key.substring(separator + 1);
        org.deeplearning4j.nn.api.Layer layer = this.getLayer(layerName);
        layer.setParam(paramType, val);
    }

    /**
     * Drops the references to the current inputs, labels and both sets of mask arrays.
     */
    @Override
    public void clear() {
        // Data arrays first, then their masks.
        inputs = null;
        labels = null;
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    /**
     * Asks every layer to apply its constraints.
     *
     * @param iteration forwarded unchanged to each layer
     * @param epoch     forwarded unchanged to each layer
     */
    @Override
    public void applyConstraints(int iteration, int epoch) {
        for (int i = 0; i < this.layers.length; ++i) {
            this.layers[i].applyConstraints(iteration, epoch);
        }
    }

    /**
     * Runs a forward pass in vertex topological order, using the {@code rnnTimeStep} method of
     * recurrent layers (and of nested {@link MultiLayerNetwork}s) in place of a plain forward
     * pass.
     * <p>
     * When every input is rank 2, any rank-3 output whose third dimension has size 1 is reduced
     * to a rank-2 array before being returned. The stored inputs reference is cleared on normal
     * return.
     *
     * @param inputs one array per network input, indexed by input vertex index
     * @return one array per network output, in the order of the configured network outputs
     */
    public INDArray[] rnnTimeStep(INDArray ... inputs) {
        this.inputs = inputs;
        boolean inputIs2d = true;
        for (INDArray array : inputs) {
            if (array.rank() != 2) {
                inputIs2d = false;
                break;
            }
        }
        INDArray[] outputs = new INDArray[this.numOutputArrays];
        for (int currVertexIdx : this.topologicalOrder) {
            GraphVertex current = this.vertices[currVertexIdx];
            if (current.isInputVertex()) {
                // Input vertices just hand a copy of their array to each consumer.
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];
                for (VertexIndices target : inputsTo) {
                    int targetIdx = target.getVertexIndex();
                    int targetEdge = target.getVertexEdgeNumber();
                    this.vertices[targetIdx].setInput(targetEdge, input.dup());
                }
                continue;
            }
            INDArray out;
            if (current.hasLayer()) {
                org.deeplearning4j.nn.api.Layer layer = current.getLayer();
                if (layer instanceof RecurrentLayer) {
                    out = ((RecurrentLayer)layer).rnnTimeStep(current.getInputs()[0]);
                } else if (layer instanceof MultiLayerNetwork) {
                    out = ((MultiLayerNetwork)layer).rnnTimeStep(current.getInputs()[0]);
                } else {
                    out = current.doForward(false);
                }
            } else {
                out = current.doForward(false);
            }
            if (current.isOutputVertex()) {
                int outputIdx = this.configuration.getNetworkOutputs().indexOf(current.getVertexName());
                outputs[outputIdx] = out;
            }
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo == null) {
                continue;
            }
            for (VertexIndices target : outputsTo) {
                int targetIdx = target.getVertexIndex();
                int targetEdge = target.getVertexEdgeNumber();
                this.vertices[targetIdx].setInput(targetEdge, out);
            }
        }
        if (inputIs2d) {
            for (int i = 0; i < outputs.length; ++i) {
                if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
                    outputs[i] = outputs[i].tensorAlongDimension(0, new int[]{1, 0});
                }
            }
        }
        this.inputs = null;
        return outputs;
    }

    /**
     * Index-based variant: resolves the layer's name and delegates to
     * {@link #rnnGetPreviousState(String)}.
     *
     * @param layer index into the network's layer array
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        String layerName = this.layers[layer].conf().getLayer().getLayerName();
        return this.rnnGetPreviousState(layerName);
    }

    /**
     * Returns the stored RNN state of the named layer.
     *
     * @param layerName name of a vertex in the graph
     * @return the layer's previous state, or {@code null} if the vertex has no layer or the layer
     *         is not a {@link RecurrentLayer}
     */
    public Map<String, INDArray> rnnGetPreviousState(String layerName) {
        org.deeplearning4j.nn.api.Layer candidate = this.verticesMap.get(layerName).getLayer();
        // instanceof is false for null, so this also covers vertices without a layer.
        if (candidate instanceof RecurrentLayer) {
            return ((RecurrentLayer)candidate).rnnGetPreviousState();
        }
        return null;
    }

    /**
     * Collects the stored RNN state of every recurrent layer.
     *
     * @return map from layer name to that layer's previous state; non-recurrent layers are omitted
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> statesByLayer = new HashMap<String, Map<String, INDArray>>();
        for (int i = 0; i < this.layers.length; ++i) {
            org.deeplearning4j.nn.api.Layer layer = this.layers[i];
            if (layer instanceof RecurrentLayer) {
                String layerName = layer.conf().getLayer().getLayerName();
                statesByLayer.put(layerName, ((RecurrentLayer)layer).rnnGetPreviousState());
            }
        }
        return statesByLayer;
    }

    /**
     * Index-based variant: resolves the layer's name and delegates to
     * {@link #rnnSetPreviousState(String, Map)}.
     *
     * @param layer index into the network's layer array
     * @param state state to store
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        String layerName = this.layers[layer].conf().getLayer().getLayerName();
        this.rnnSetPreviousState(layerName, state);
    }

    /**
     * Stores an RNN state on the named recurrent layer.
     *
     * @param layerName name of a vertex in the graph
     * @param state     state to store
     * @throws UnsupportedOperationException if the vertex has no layer or the layer is not a
     *                                       {@link RecurrentLayer}
     */
    public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
        org.deeplearning4j.nn.api.Layer candidate = this.verticesMap.get(layerName).getLayer();
        // instanceof is false for null, so this also rejects vertices without a layer.
        if (!(candidate instanceof RecurrentLayer)) {
            throw new UnsupportedOperationException("Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
        }
        RecurrentLayer recurrent = (RecurrentLayer)candidate;
        recurrent.rnnSetPreviousState(state);
    }

    /**
     * Applies {@link #rnnSetPreviousState(String, Map)} to every entry of the given map.
     *
     * @param previousStates map from layer name to the state to store on that layer
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
        for (Map.Entry<String, Map<String, INDArray>> layerState : previousStates.entrySet()) {
            String layerName = layerState.getKey();
            Map<String, INDArray> state = layerState.getValue();
            this.rnnSetPreviousState(layerName, state);
        }
    }

    /**
     * Clears the stored RNN state of every recurrent layer and every nested
     * {@link MultiLayerNetwork}. Does nothing if the layer array is {@code null}.
     */
    public void rnnClearPreviousState() {
        if (this.layers == null) {
            return;
        }
        for (int i = 0; i < this.layers.length; ++i) {
            org.deeplearning4j.nn.api.Layer layer = this.layers[i];
            if (layer instanceof RecurrentLayer) {
                ((RecurrentLayer)layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork)layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Fits the network on one minibatch using truncated backpropagation through time: the time
     * axis is cut into segments of the configured TBPTT forward length and the solver is run once
     * per segment, carrying the recurrent state from one segment to the next.
     * <p>
     * If the rank-3 input/label arrays disagree on their time-series length (size of dimension
     * 2), a warning is logged and the method returns without training.
     *
     * @param inputs       network inputs
     * @param labels       network labels
     * @param featureMasks feature mask arrays, may be {@code null}
     * @param labelMasks   label mask arrays, may be {@code null}
     */
    protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
        // Make sure the flattened gradient view exists before any optimisation step.
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        // Determine the common time-series length from every rank-3 input and label array.
        int timeSeriesLength = -1;
        for (INDArray in : inputs) {
            if (in.rank() != 3) continue;
            if (timeSeriesLength == -1) {
                timeSeriesLength = in.size(2);
                continue;
            }
            if (timeSeriesLength == in.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        for (INDArray out : labels) {
            if (out.rank() != 3) continue;
            if (timeSeriesLength == -1) {
                timeSeriesLength = out.size(2);
                continue;
            }
            if (timeSeriesLength == out.size(2)) continue;
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
        // Number of segments = ceil(timeSeriesLength / fwdLen).
        int fwdLen = this.configuration.getTbpttFwdLength();
        int nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0) {
            ++nSubsets;
        }
        // Start from a clean recurrent state.
        this.rnnClearPreviousState();
        workspaceConfigurationExternal.setCyclesBeforeInitialization(0);
        workspaceConfigurationExternal.setPolicyLearning(LearningPolicy.OVER_TIME);
        // With workspaces disabled (mode NONE) dummy workspaces are used instead of real ones.
        DummyWorkspace workspaceT = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationTBPTT, WORKSPACE_TBPTT);
        DummyWorkspace workspace = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
        // The TBPTT workspace scope spans all segments; the external one is re-entered per segment.
        try (MemoryWorkspace wsT = workspaceT.notifyScopeEntered();){
            for (int i = 0; i < nSubsets; ++i) {
                try (MemoryWorkspace wsE = workspace.notifyScopeEntered();){
                    int startTimeIdx = i * fwdLen;
                    int endTimeIdx = startTimeIdx + fwdLen;
                    // The last segment may be shorter than fwdLen.
                    if (endTimeIdx > timeSeriesLength) {
                        endTimeIdx = timeSeriesLength;
                    }
                    // list = [inputs, labels, featureMasks, labelMasks] restricted to this time window.
                    List<INDArray[]> list = this.getSubsetsForTbptt(startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);
                    this.setInputs(list.get(0));
                    this.setLabels(list.get(1));
                    this.setLayerMaskArrays(list.get(2), list.get(3));
                    // Build the solver lazily, outside any workspace scope.
                    if (this.solver == null) {
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                            this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                        }
                    }
                    this.solver.optimize();
                    // Carry the end-of-segment state into the next segment.
                    this.rnnUpdateStateWithTBPTTState();
                    continue;
                }
            }
        }
        if (this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE) {
            workspace.initializeWorkspace();
            workspaceT.initializeWorkspace();
        }
        // Leave no recurrent state or masks behind once the minibatch is done.
        this.rnnClearPreviousState();
        if (featureMasks != null || labelMasks != null) {
            this.clearLayerMaskArrays();
        }
    }

    /**
     * Restricts inputs, labels and masks to the time window {@code [startTimeIdx, endTimeIdx)}.
     * <p>
     * Rank-3 input/label arrays are sliced along dimension 2; arrays of any other rank are passed
     * through unchanged. Non-null mask arrays are sliced along dimension 1; null masks stay null.
     *
     * @param startTimeIdx first time step of the window (inclusive)
     * @param endTimeIdx   end of the window (exclusive)
     * @param inputs       network inputs
     * @param labels       network labels
     * @param featureMasks feature masks, or {@code null}
     * @param labelMasks   label masks, or {@code null}
     * @return a fixed-size list {@code [inputs, labels, featureMasks, labelMasks]} of the sliced
     *         arrays; the last two elements are {@code null} when the corresponding argument was
     */
    private List<INDArray[]> getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
        INDArray[] newInputs = new INDArray[inputs.length];
        // Sized from labels, not inputs: a graph may have a different number of outputs than inputs.
        INDArray[] newLabels = new INDArray[labels.length];
        INDArray[] newFeatureMasks = featureMasks != null ? new INDArray[featureMasks.length] : null;
        INDArray[] newLabelMasks = labelMasks != null ? new INDArray[labelMasks.length] : null;
        for (int j = 0; j < inputs.length; ++j) {
            newInputs[j] = inputs[j].rank() != 3 ? inputs[j] : inputs[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
        }
        for (int j = 0; j < labels.length; ++j) {
            newLabels[j] = labels[j].rank() != 3 ? labels[j] : labels[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
        }
        if (featureMasks != null) {
            for (int j = 0; j < featureMasks.length; ++j) {
                if (featureMasks[j] == null) continue;
                newFeatureMasks[j] = featureMasks[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
            }
        }
        if (labelMasks != null) {
            for (int j = 0; j < labelMasks.length; ++j) {
                if (labelMasks[j] == null) continue;
                newLabelMasks[j] = labelMasks[j].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((int)startTimeIdx, (int)endTimeIdx)});
            }
        }
        // Arrays.asList (unlike List.of) tolerates the null mask entries.
        return Arrays.asList(newInputs, newLabels, newFeatureMasks, newLabelMasks);
    }

    /**
     * Runs a forward pass in vertex topological order, calling
     * {@code rnnActivateUsingStoredState} on recurrent layers and nested
     * {@link MultiLayerNetwork}s and a plain forward pass on everything else.
     *
     * @param inputs            one array per network input, indexed by input vertex index
     * @param training          forwarded to every layer/vertex forward call
     * @param storeLastForTBPTT forwarded to the recurrent layers
     * @return activations keyed by vertex name, for input vertices and vertices that have a layer
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training, boolean storeLastForTBPTT) {
        Map<String, INDArray> activations = new HashMap<String, INDArray>();
        for (int currVertexIdx : this.topologicalOrder) {
            GraphVertex current = this.vertices[currVertexIdx];
            if (current.isInputVertex()) {
                // Record the raw input and hand a copy to each consumer.
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];
                activations.put(current.getVertexName(), input);
                for (VertexIndices target : inputsTo) {
                    int targetIdx = target.getVertexIndex();
                    int targetEdge = target.getVertexEdgeNumber();
                    this.vertices[targetIdx].setInput(targetEdge, input.dup());
                }
                continue;
            }
            INDArray out;
            if (!current.hasLayer()) {
                out = current.doForward(training);
            } else {
                org.deeplearning4j.nn.api.Layer layer = current.getLayer();
                if (layer instanceof RecurrentLayer) {
                    out = ((RecurrentLayer)layer).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
                } else if (layer instanceof MultiLayerNetwork) {
                    // A nested network returns all of its activations; only the last one is kept.
                    List<INDArray> nested = ((MultiLayerNetwork)layer).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
                    out = nested.get(nested.size() - 1);
                } else {
                    out = current.doForward(training);
                }
                activations.put(current.getVertexName(), out);
            }
            VertexIndices[] outputsTo = current.getOutputVertices();
            if (outputsTo == null) {
                continue;
            }
            for (VertexIndices target : outputsTo) {
                int targetIdx = target.getVertexIndex();
                int targetEdge = target.getVertexEdgeNumber();
                this.vertices[targetIdx].setInput(targetEdge, out);
            }
        }
        return activations;
    }

    /**
     * Replaces the network's feature and label mask arrays.
     * <p>
     * Existing masks are cleared first. Feature masks are then pushed through the graph in
     * topological order via each vertex's {@code feedForwardMaskArrays}; label masks are set
     * directly on the layer of the corresponding output vertex.
     *
     * @param featureMaskArrays one (possibly null) mask per network input, or {@code null}
     * @param labelMaskArrays   one (possibly null) mask per network output, or {@code null}
     * @throws IllegalArgumentException if a non-null argument does not have one entry per network
     *                                  input / output respectively
     */
    public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        this.clearLayerMaskArrays();
        this.inputMaskArrays = featureMaskArrays;
        this.labelMaskArrays = labelMaskArrays;
        if (featureMaskArrays != null) {
            if (featureMaskArrays.length != this.numInputArrays) {
                throw new IllegalArgumentException("Invalid number of feature mask arrays");
            }
            // Minibatch size is read from the non-null masks (the last one wins); -1 if all are null.
            int minibatchSize = -1;
            for (INDArray i : featureMaskArrays) {
                if (i == null) continue;
                minibatchSize = i.size(0);
            }
            // vertex index -> Pair(mask, MaskState) produced by that vertex
            HashMap<Integer, Object> map = new HashMap<Integer, Object>();
            for (int i = 0; i < this.topologicalOrder.length; ++i) {
                GraphVertex current = this.vertices[this.topologicalOrder[i]];
                if (current.isInputVertex()) {
                    // Input vertices contribute their feature mask as-is, marked Active.
                    INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                    map.put(current.getVertexIndex(), new Pair((Object)fMask, (Object)MaskState.Active));
                    continue;
                }
                // Gather the masks/states already computed for this vertex's inputs.
                VertexIndices[] inputVertices = current.getInputVertices();
                INDArray[] inputMasks = null;
                MaskState maskState = null;
                for (int j = 0; j < inputVertices.length; ++j) {
                    Pair p = (Pair)map.get(inputVertices[j].getVertexIndex());
                    if (p == null) continue;
                    if (inputMasks == null) {
                        inputMasks = new INDArray[inputVertices.length];
                    }
                    inputMasks[j] = (INDArray)p.getFirst();
                    // Once a state that is neither null nor Passthrough has been seen, keep it.
                    if (maskState != null && maskState != MaskState.Passthrough) continue;
                    maskState = (MaskState)((Object)p.getSecond());
                }
                Pair<INDArray, MaskState> outPair = current.feedForwardMaskArrays(inputMasks, maskState, minibatchSize);
                map.put(this.topologicalOrder[i], outPair);
            }
        }
        if (labelMaskArrays != null) {
            if (labelMaskArrays.length != this.numOutputArrays) {
                throw new IllegalArgumentException("Invalid number of label mask arrays");
            }
            // The i-th label mask belongs to the i-th configured network output.
            for (int i = 0; i < labelMaskArrays.length; ++i) {
                if (labelMaskArrays[i] == null) continue;
                String outputName = this.configuration.getNetworkOutputs().get(i);
                GraphVertex v = this.verticesMap.get(outputName);
                org.deeplearning4j.nn.api.Layer ol = v.getLayer();
                ol.setMaskArray(labelMaskArrays[i]);
            }
        }
    }

    /**
     * Removes the mask array from every layer and drops the stored feature and label masks.
     */
    public void clearLayerMaskArrays() {
        for (int i = 0; i < this.layers.length; ++i) {
            this.layers[i].setMaskArray(null);
        }
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    /**
     * Copies each recurrent layer's TBPTT state into its "previous state", and asks nested
     * {@link MultiLayerNetwork}s to do the same for their own layers.
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                RecurrentLayer recurrent = (RecurrentLayer)layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork)layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Evaluates the network on the given data, passing {@code null} as the label list.
     *
     * @param iterator data to evaluate on
     */
    public Evaluation evaluate(DataSetIterator iterator) {
        return evaluate(iterator, null);
    }

    /**
     * Evaluates the network on the given data, passing {@code null} as the label list.
     *
     * @param iterator data to evaluate on
     */
    public Evaluation evaluate(MultiDataSetIterator iterator) {
        return evaluate(iterator, null);
    }

    /**
     * Evaluates the network with {@code topN} fixed to 1.
     *
     * @param iterator   data to evaluate on
     * @param labelsList label names, may be {@code null}
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluates the network with {@code topN} fixed to 1.
     *
     * @param iterator   data to evaluate on
     * @param labelsList label names, may be {@code null}
     */
    public Evaluation evaluate(MultiDataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluates the network with a new {@link Evaluation} built from the given label names and
     * {@code topN}.
     *
     * @param iterator   data to evaluate on
     * @param labelsList label names; when {@code null} the iterator's own labels are used
     * @param topN       passed to the {@link Evaluation} constructor
     * @return the populated evaluation
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        List<String> labelNames = labelsList != null ? labelsList : iterator.getLabels();
        Evaluation[] results = this.doEvaluation(iterator, new Evaluation[]{new Evaluation(labelNames, topN)});
        return results[0];
    }

    /**
     * Evaluates the network with a new {@link Evaluation} built from the given label names and
     * {@code topN}.
     *
     * @param iterator   data to evaluate on
     * @param labelsList label names, passed to the {@link Evaluation} constructor as-is
     * @param topN       passed to the {@link Evaluation} constructor
     * @return the populated evaluation
     */
    public Evaluation evaluate(MultiDataSetIterator iterator, List<String> labelsList, int topN) {
        Evaluation[] results = this.doEvaluation(iterator, new Evaluation[]{new Evaluation(labelsList, topN)});
        return results[0];
    }

    /**
     * Runs a regression evaluation, passing {@code null} as the column names.
     *
     * @param iterator data to evaluate on
     */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Runs a regression evaluation, passing {@code null} as the column names.
     *
     * @param iterator data to evaluate on
     */
    public RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Runs a regression evaluation with a new {@link RegressionEvaluation}.
     *
     * @param iterator    data to evaluate on
     * @param columnNames passed to the {@link RegressionEvaluation} constructor
     * @return the populated evaluation
     */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator, List<String> columnNames) {
        RegressionEvaluation[] results = this.doEvaluation(iterator, new RegressionEvaluation[]{new RegressionEvaluation(columnNames)});
        return results[0];
    }

    /**
     * Runs a regression evaluation with a new {@link RegressionEvaluation}.
     *
     * @param iterator    data to evaluate on
     * @param columnNames passed to the {@link RegressionEvaluation} constructor
     * @return the populated evaluation
     */
    public RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator, List<String> columnNames) {
        RegressionEvaluation[] results = this.doEvaluation(iterator, new RegressionEvaluation[]{new RegressionEvaluation(columnNames)});
        return results[0];
    }

    /**
     * Runs a ROC evaluation with {@code rocThresholdSteps} set to 0.
     *
     * @param iterator data to evaluate on
     */
    public ROC evaluateROC(DataSetIterator iterator) {
        return evaluateROC(iterator, 0);
    }

    /**
     * Runs a ROC evaluation with a new {@link ROC}.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps passed to the {@link ROC} constructor
     * @return the populated evaluation
     */
    public ROC evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        ROC[] results = this.doEvaluation(iterator, new ROC[]{new ROC(rocThresholdSteps)});
        return results[0];
    }

    /**
     * Runs a ROC evaluation with {@code rocThresholdSteps} set to 0.
     *
     * @param iterator data to evaluate on
     */
    public ROC evaluateROC(MultiDataSetIterator iterator) {
        return evaluateROC(iterator, 0);
    }

    /**
     * Runs a ROC evaluation with a new {@link ROC}.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps passed to the {@link ROC} constructor
     * @return the populated evaluation
     */
    public ROC evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
        ROC[] results = this.doEvaluation(iterator, new ROC[]{new ROC(rocThresholdSteps)});
        return results[0];
    }

    /**
     * Runs a multi-class ROC evaluation with {@code rocThresholdSteps} set to 0.
     *
     * @param iterator data to evaluate on
     */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator) {
        return evaluateROCMultiClass(iterator, 0);
    }

    /**
     * Runs a multi-class ROC evaluation with a new {@link ROCMultiClass}.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps passed to the {@link ROCMultiClass} constructor
     * @return the populated evaluation
     */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        ROCMultiClass[] results = this.doEvaluation(iterator, new ROCMultiClass[]{new ROCMultiClass(rocThresholdSteps)});
        return results[0];
    }

    /**
     * Runs a multi-class ROC evaluation with a new {@link ROCMultiClass}.
     *
     * @param iterator          data to evaluate on
     * @param rocThresholdSteps passed to the {@link ROCMultiClass} constructor
     * @return the populated evaluation
     */
    public ROCMultiClass evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
        ROCMultiClass[] results = this.doEvaluation(iterator, new ROCMultiClass[]{new ROCMultiClass(rocThresholdSteps)});
        return results[0];
    }

    /**
     * Wraps the {@code DataSetIterator} in a {@link MultiDataSetIteratorAdapter} and delegates to
     * the {@code MultiDataSetIterator} overload.
     *
     * @param iterator    data to evaluate on
     * @param evaluations evaluation objects to populate
     * @return the same evaluation objects, populated
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T ... evaluations) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        return this.doEvaluation(adapted, evaluations);
    }

    /**
     * Feeds every minibatch of the iterator through the network and passes labels, first network
     * output and label mask to each supplied evaluation object.
     * <p>
     * When the configured backprop type is truncated BPTT, each minibatch is evaluated in time
     * segments of the configured TBPTT forward length using {@link #rnnTimeStep(INDArray...)};
     * otherwise the whole minibatch is evaluated in one pass. The two paths are mutually
     * exclusive. Iteration stops early at the first minibatch whose features or labels are null.
     * <p>
     * The training workspace mode is temporarily replaced by the inference workspace mode; it is
     * restored, and any async wrapper created here is shut down, even if evaluation fails.
     *
     * @param iterator    data to evaluate on; reset first if it supports reset and is exhausted
     * @param evaluations evaluation objects to populate
     * @return the {@code evaluations} array that was passed in
     * @throws IllegalStateException if the network has no layers, its first output layer is not an
     *                               {@link IOutputLayer}, it has more than one output array, or
     *                               TBPTT is configured but no rank-3 feature array is present
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T ... evaluations) {
        if (this.layers == null || !(this.getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        if (this.getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot evaluate a model using this method with > 1 output arrays");
        }
        if (iterator.resetSupported() && !iterator.hasNext()) {
            iterator.reset();
        }
        MultiDataSetIterator iter = iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;
        WorkspaceMode cMode = this.configuration.getTrainingWorkspaceMode();
        this.configuration.setTrainingWorkspaceMode(this.configuration.getInferenceWorkspaceMode());
        try {
            boolean workspacesDisabled = this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE;
            MemoryWorkspace workspace = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, WORKSPACE_EXTERNAL);
            MemoryWorkspace workspaceT = workspacesDisabled ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationTBPTT, WORKSPACE_TBPTT);
            boolean useRnnSegments = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.api.MultiDataSet next = (org.nd4j.linalg.dataset.api.MultiDataSet)iter.next();
                if (next.getFeatures() == null || next.getLabels() == null) {
                    break;
                }
                try (MemoryWorkspace wsB = workspace.notifyScopeEntered()) {
                    if (!useRnnSegments) {
                        // Standard case: evaluate the whole minibatch in a single forward pass.
                        INDArray[] features = next.getFeatures();
                        INDArray[] featuresMasks = next.getFeaturesMaskArrays();
                        INDArray labels = next.getLabels(0);
                        INDArray[] labelMasks = next.getLabelsMaskArrays();
                        INDArray labelMask = next.getLabelsMaskArray(0);
                        this.setLayerMaskArrays(featuresMasks, labelMasks);
                        INDArray[] out = this.silentOutput(false, features);
                        try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                            for (T evaluation : evaluations) {
                                evaluation.eval(labels, out[0], labelMask);
                            }
                        }
                    } else {
                        // TBPTT case: evaluate in time segments, carrying RNN state between them.
                        this.rnnClearPreviousState();
                        int fwdLen = this.configuration.getTbpttFwdLength();
                        int tsLength = -1;
                        int nF = next.getFeatures().length;
                        for (int i = 0; i < nF; ++i) {
                            if (next.getFeatures(i).rank() != 3) continue;
                            tsLength = next.getFeatures(i).size(2);
                        }
                        if (tsLength < 0) {
                            throw new IllegalStateException("Invalid configuration: detected TBPTT backprop type without time series features");
                        }
                        int nSubsets = tsLength / fwdLen;
                        if (tsLength % fwdLen != 0) {
                            ++nSubsets;
                        }
                        for (int i = 0; i < nSubsets; ++i) {
                            int startTimeIdx = i * fwdLen;
                            int endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);
                            // subset = [inputs, labels, featureMasks, labelMasks] for this window
                            List<INDArray[]> subset = this.getSubsetsForTbptt(startTimeIdx, endTimeIdx, next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                            try (MemoryWorkspace wsT = workspaceT.notifyScopeEntered()) {
                                this.setLayerMaskArrays(subset.get(2), subset.get(3));
                                INDArray[] outSub = this.rnnTimeStep(subset.get(0));
                                INDArray maskSub = subset.get(3) == null ? null : subset.get(3)[0];
                                try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                                    for (T evaluation : evaluations) {
                                        evaluation.eval(subset.get(1)[0], outSub[0], maskSub);
                                    }
                                }
                            }
                        }
                    }
                    // Clear per-minibatch layer state before leaving the workspace scope.
                    this.clearLayersStates();
                }
            }
        } finally {
            // Only shut down the async wrapper created above, never the caller's own iterator.
            if (iter != iterator) {
                ((AsyncMultiDataSetIterator)iter).shutdown();
            }
            this.configuration.setTrainingWorkspaceMode(cMode);
        }
        return evaluations;
    }

    /**
     * Builds the summary string without input-type information by passing a {@code null}
     * {@code InputType} array to {@link #summary(InputType...)}.
     */
    public String summary() {
        return this.summary((InputType[])null);
    }

    public String summary(InputType ... inputTypes) {
        String ret = "\n";
        ret = ret + StringUtils.repeat((String)"=", (int)250);
        ret = ret + "\n";
        if (inputTypes != null) {
            if (inputTypes.length != this.configuration.getNetworkInputs().size()) {
                throw new IllegalArgumentException("The number of inputTypes should match the size of the inputs in the computation graph");
            }
            ret = ret + String.format("%-40s%-10s%-12s%-40s%-30s%-75s%-75s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs", "InputShape", "OutputShape");
        } else {
            ret = ret + String.format("%-40s%-10s%-12s%-40s%-30s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs");
        }
        ret = ret + StringUtils.repeat((String)"=", (int)250);
        ret = ret + "\n";
        int frozenParams = 0;
        HashMap<String, InputType> vertexOutputs = new HashMap<String, InputType>();
        int currLayerIdx = -1;
        for (int currVertexIdx : this.topologicalOrder) {
            GraphVertex currentVertex = this.vertices[currVertexIdx];
            String currentVertexName = currentVertex.getVertexName();
            String[] classNameArr = currentVertex.getClass().toString().split("\\.");
            String className = classNameArr[classNameArr.length - 1];
            String connections = "-";
            String inShape = "-";
            String outShape = "-";
            String paramCount = "-";
            String in = "-";
            String out = "-";
            String paramShape = "-";
            if (currentVertex.isInputVertex()) {
                if (inputTypes != null) {
                    vertexOutputs.put(currentVertexName, inputTypes[this.configuration.getNetworkInputs().indexOf(currentVertexName)]);
                }
            } else {
                VertexIndices[] inputVertices;
                connections = this.configuration.getVertexInputs().get(currentVertexName).toString();
                ArrayList<InputType> inputTypeList = new ArrayList<InputType>();
                if (currentVertex.hasLayer()) {
                    org.deeplearning4j.nn.api.Layer currentLayer = ((LayerVertex)currentVertex).getLayer();
                    classNameArr = currentLayer.getClass().getName().split("\\.");
                    className = classNameArr[classNameArr.length - 1];
                    paramCount = String.valueOf(currentLayer.numParams());
                    if (currentLayer.numParams() > 0) {
                        paramShape = "";
                        in = String.valueOf(((FeedForwardLayer)currentLayer.conf().getLayer()).getNIn());
                        out = String.valueOf(((FeedForwardLayer)currentLayer.conf().getLayer()).getNOut());
                        List<String> paraNames = currentLayer.conf().variables();
                        for (String aP : paraNames) {
                            String paramS = ArrayUtils.toString((Object)currentLayer.paramTable().get(aP).shape());
                            paramShape = paramShape + aP + ":" + paramS + ", ";
                        }
                        paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
                    }
                    if (currentLayer instanceof FrozenLayer) {
                        frozenParams += currentLayer.numParams();
                        classNameArr = ((FrozenLayer)currentLayer).getInsideLayer().getClass().getName().split("\\.");
                        className = "Frozen " + classNameArr[classNameArr.length - 1];
                    }
                    if (inputTypes != null) {
                        String inputVertexName = this.vertices[currentVertex.getInputVertices()[0].getVertexIndex()].getVertexName();
                        InputType currentInType = (InputType)vertexOutputs.get(inputVertexName);
                        inShape = currentInType.toString();
                        inputTypeList.add(currentInType);
                        InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex)this.configuration.getVertices().get(currentVertexName)).getPreProcessor();
                        if (layerVertexPreProcesor != null) {
                            inShape = inShape + "-->" + layerVertexPreProcesor.getOutputType(currentInType);
                        }
                    }
                    ++currLayerIdx;
                } else if (inputTypes != null && (inputVertices = currentVertex.getInputVertices()) != null) {
                    for (int i = 0; i < inputVertices.length; ++i) {
                        GraphVertex thisInputVertex = this.vertices[inputVertices[i].getVertexIndex()];
                        inputTypeList.add((InputType)vertexOutputs.get(thisInputVertex.getVertexName()));
                    }
                }
                if (inputTypes != null) {
                    InputType currentVertexOutputType = this.configuration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
                    outShape = currentVertexOutputType.toString();
                    vertexOutputs.put(currentVertexName, currentVertexOutputType);
                }
            }
            ret = inputTypes != null ? ret + String.format("%-40s%-10s%-12s%-40s%-30s%-75s%-75s", currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections, inShape, outShape) : ret + String.format("%-40s%-10s%-12s%-40s%-30s", currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections);
            ret = ret + "\n";
        }
        ret = ret + StringUtils.repeat((String)"-", (int)250);
        ret = ret + String.format("\n%30s %d", "Total Parameters: ", this.params().length());
        ret = ret + String.format("\n%30s %d", "Trainable Parameters: ", this.params().length() - frozenParams);
        ret = ret + String.format("\n%30s %d", "Frozen Parameters: ", frozenParams);
        ret = ret + "\n";
        ret = ret + StringUtils.repeat((String)"=", (int)250);
        ret = ret + "\n";
        return ret;
    }

    /**
     * Clears the per-layer and per-vertex state of this graph.
     *
     * <p>For every entry of {@code layers} this calls {@code clear()} followed by
     * {@code clearNoiseWeightParams()}; afterwards every entry of {@code vertices}
     * has {@code clearVertex()} invoked on it.
     */
    protected void clearLayersStates() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }
        // The loop variable must be typed as GraphVertex: java.io.Serializable declares
        // no clearVertex() method, so iterating as Serializable cannot compile.
        for (GraphVertex vertex : this.vertices) {
            vertex.clearVertex();
        }
    }

    /**
     * Advances the epoch counter stored on this graph's configuration by one.
     */
    public void incrementEpochCount() {
        int nextEpoch = this.configuration.getEpochCount() + 1;
        this.configuration.setEpochCount(nextEpoch);
    }

    /**
     * Copies the iteration and epoch counters from the graph configuration onto
     * every layer in {@code layers}.
     */
    protected void synchronizeIterEpochCounts() {
        final int iteration = this.getConfiguration().getIterationCount();
        final int epoch = this.getConfiguration().getEpochCount();
        // Snapshot the array reference once, as an enhanced-for loop would.
        final org.deeplearning4j.nn.api.Layer[] allLayers = this.layers;
        for (int idx = 0; idx < allLayers.length; ++idx) {
            org.deeplearning4j.nn.api.Layer current = allLayers[idx];
            current.setIterationCount(iteration);
            current.setEpochCount(epoch);
        }
    }

    /**
     * Saves this graph to the given file, including the updater.
     *
     * <p>Equivalent to {@code save(f, true)}.
     *
     * @param f destination file
     * @throws IOException propagated from {@link #save(File, boolean)}
     */
    public void save(File f) throws IOException {
        final boolean includeUpdater = true;
        save(f, includeUpdater);
    }

    /**
     * Saves this graph to the given file by delegating to
     * {@code ModelSerializer.writeModel}.
     *
     * @param f           destination file
     * @param saveUpdater flag forwarded unchanged to {@code ModelSerializer.writeModel}
     * @throws IOException propagated from {@code ModelSerializer.writeModel}
     */
    public void save(File f, boolean saveUpdater) throws IOException {
        // Pass the graph typed as Model so the same writeModel overload is selected.
        final Model self = this;
        ModelSerializer.writeModel(self, f, saveUpdater);
    }

    /**
     * Restores a {@code ComputationGraph} from a file by delegating to
     * {@code ModelSerializer.restoreComputationGraph}.
     *
     * @param f           file to read the graph from
     * @param loadUpdater flag forwarded unchanged to {@code ModelSerializer.restoreComputationGraph}
     * @return the graph returned by {@code ModelSerializer.restoreComputationGraph}
     * @throws IOException propagated from {@code ModelSerializer.restoreComputationGraph}
     */
    public static ComputationGraph load(File f, boolean loadUpdater) throws IOException {
        ComputationGraph restored = ModelSerializer.restoreComputationGraph(f, loadUpdater);
        return restored;
    }

    /**
     * Sets a fixed learning rate for this graph by delegating to
     * {@code NetworkUtils.setLearningRate(ComputationGraph, double)}.
     *
     * @param newLr the learning rate value handed to {@code NetworkUtils}
     */
    public void setLearningRate(double newLr) {
        final ComputationGraph self = this;
        NetworkUtils.setLearningRate(self, newLr);
    }

    /**
     * Sets a learning rate schedule for this graph by delegating to
     * {@code NetworkUtils.setLearningRate(ComputationGraph, ISchedule)}.
     *
     * @param newLr the schedule handed to {@code NetworkUtils}
     */
    public void setLearningRate(ISchedule newLr) {
        final ComputationGraph self = this;
        NetworkUtils.setLearningRate(self, newLr);
    }

    /**
     * Sets a fixed learning rate for one named layer by delegating to
     * {@code NetworkUtils.setLearningRate(ComputationGraph, String, double)}.
     *
     * @param layerName name of the layer, forwarded unchanged
     * @param newLr     the learning rate value handed to {@code NetworkUtils}
     */
    public void setLearningRate(String layerName, double newLr) {
        final ComputationGraph self = this;
        NetworkUtils.setLearningRate(self, layerName, newLr);
    }

    /**
     * Sets a learning rate schedule for one named layer by delegating to
     * {@code NetworkUtils.setLearningRate(ComputationGraph, String, ISchedule)}.
     *
     * @param layerName name of the layer, forwarded unchanged
     * @param newLr     the schedule handed to {@code NetworkUtils}
     */
    public void setLearningRate(String layerName, ISchedule newLr) {
        final ComputationGraph self = this;
        NetworkUtils.setLearningRate(self, layerName, newLr);
    }

    /**
     * Returns the size of the layer at the given index, resolved through
     * {@link #layerSize(String)} using the layer's configured name.
     *
     * @param layer index into {@code layers}; must satisfy {@code 0 <= layer < layers.length}
     * @return the value returned by {@link #layerSize(String)} for that layer's name
     * @throws IllegalArgumentException if {@code layer} is negative or not a valid index
     */
    public int layerSize(int layer) {
        // Use >= for the upper bound: layer == layers.length is already out of range and
        // would otherwise surface as an ArrayIndexOutOfBoundsException on the array access below.
        if (layer < 0 || layer >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        return this.layerSize(this.layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Returns the {@code nOut} value of the named layer's configuration.
     *
     * @param layerName name used to look the layer up via {@code getLayer}
     * @return {@code getNOut()} of the layer configuration when it is a
     *         {@code FeedForwardLayer}; {@code 0} when the configuration is
     *         {@code null} or of any other type
     * @throws IllegalArgumentException if {@code getLayer(layerName)} returns {@code null}
     */
    public int layerSize(String layerName) {
        org.deeplearning4j.nn.api.Layer found = this.getLayer(layerName);
        if (found == null) {
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
        }
        Layer layerConf = found.conf().getLayer();
        // instanceof is false for null, so a missing configuration also yields 0.
        if (layerConf instanceof FeedForwardLayer) {
            return ((FeedForwardLayer)layerConf).getNOut();
        }
        return 0;
    }

    /**
     * Compares this graph with another object.
     *
     * <p>Two graphs are considered equal when their parameters, their
     * configurations and their updaters are all equal according to the
     * respective {@code equals} implementations.
     *
     * @param obj the object to compare against; may be {@code null}
     * @return {@code true} only if {@code obj} is a {@code ComputationGraph} whose
     *         parameters, configuration and updater all compare equal to this graph's
     */
    public boolean equals(Object obj) {
        // instanceof rejects null as well as unrelated types.
        if (!(obj instanceof ComputationGraph)) {
            return false;
        }
        ComputationGraph other = (ComputationGraph)obj;
        // All three comparisons are evaluated up front, in this order, before combining.
        boolean sameParams = other.params().equals(this.params());
        boolean sameConf = this.getConfiguration().equals(other.getConfiguration());
        boolean sameUpdater = this.getUpdater().equals(other.getUpdater());
        return sameParams && sameConf && sameUpdater;
    }

    /**
     * Returns the {@code flattenedGradients} field of this graph.
     *
     * @return the current value of {@code flattenedGradients} (returned as-is, not copied)
     */
    public INDArray getFlattenedGradients() {
        INDArray gradients = this.flattenedGradients;
        return gradients;
    }

    /**
     * Overwrites the {@code initDone} flag of this graph.
     *
     * @param initDone the new value stored in the {@code initDone} field
     */
    public void setInitDone(boolean initDone) {
        final boolean flag = initDone;
        this.initDone = flag;
    }
}

