/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.multilayer;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.FwdPassType;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultiLayerNetwork
implements Serializable,
Classifier,
org.deeplearning4j.nn.api.Layer,
NeuralNetwork {
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);

    /** Instantiated layers in network order; allocated and populated by {@code init}. */
    protected org.deeplearning4j.nn.api.Layer[] layers;
    /** Layers keyed by their configured layer name, in the order they were instantiated. */
    protected LinkedHashMap<String, org.deeplearning4j.nn.api.Layer> layerMap = new LinkedHashMap<>();
    /** Current network input (features). */
    protected INDArray input;
    /** Current labels. */
    protected INDArray labels;
    /** Set by {@code init} once the layers have been instantiated; later {@code init} calls return early. */
    protected boolean initCalled = false;
    /** Listeners passed to every layer when it is instantiated in {@code init}. */
    protected Collection<TrainingListener> trainingListeners = new ArrayList<TrainingListener>();
    /** Network-level configuration; a clone of the first layer's configuration when built from a configuration. */
    protected NeuralNetConfiguration defaultConfiguration;
    /** Complete per-layer configuration of this network. */
    protected MultiLayerConfiguration layerWiseConfigurations;
    protected Gradient gradient;
    protected double score;
    protected boolean initDone = false;
    /** All parameters in a single row vector; each layer is given a view into it by {@code init}. */
    protected INDArray flattenedParams;
    /** All gradients in a single 'f'-ordered row vector; each layer's backprop gradient array is a view into it. */
    protected transient INDArray flattenedGradients;
    protected boolean clearTbpttState = true;
    /**
     * Per-thread time of the last ETL step. This field is transient, and Java deserialization does not
     * re-run field initializers, so it is {@code null} on a deserialized instance. Accessors must tolerate that.
     */
    protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
    /** Feature mask array that the {@code feedForward} methods pass to the forward pass. */
    protected INDArray mask;
    /** NOTE(review): presumably this network's own index when it is used as a layer of another model; confirm. */
    protected int layerIndex;
    /** Optimizer driver; transient, and created in {@code init} when null. */
    protected transient Solver solver;

    // Workspace names used by the feed-forward / pretraining paths.
    protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
    protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
    protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
    protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
    protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";

    /** Per-layer working-memory workspace configuration; depends on the layer count, so it is built in the constructor. */
    protected final WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
    /** Configuration for the workspace that holds all layer activations (FIRST_LOOP learning, 5% overallocation). */
    protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
            .initialSize(0L)
            .overallocationLimit(0.05)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .policyReset(ResetPolicy.BLOCK_LEFT)
            .policySpill(SpillPolicy.REALLOCATE)
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .build();
    /** Configuration for the alternating layer activation workspaces; built in the constructor. */
    protected final WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
    /** Configuration for the RNN loop working-memory workspace (FIRST_LOOP learning, 5% overallocation). */
    protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
            .initialSize(0L)
            .overallocationLimit(0.05)
            .policyReset(ResetPolicy.BLOCK_LEFT)
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .policySpill(SpillPolicy.REALLOCATE)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .build();

    /**
     * Creates an uninitialized network for the given configuration. Call {@code init()} before use.
     *
     * @param conf the per-layer configuration; the configuration of layer 0 is cloned to become the
     *             network-level default configuration
     */
    public MultiLayerNetwork(MultiLayerConfiguration conf) {
        this.layerWiseConfigurations = conf;
        this.defaultConfiguration = conf.getConf(0).clone();
        int layerCount = this.layerWiseConfigurations.getConfs().size();
        int preProcessorCount = this.layerWiseConfigurations.getInputPreProcessors().size();
        // Working memory waits 2 x (layers + preprocessors) cycles before initialization;
        // the activation workspaces wait one cycle per layer.
        this.WS_LAYER_WORKING_MEM_CONFIG = overTimeLayerWorkspaceConfig(2 * (layerCount + preProcessorCount));
        this.WS_LAYER_ACT_X_CONFIG = overTimeLayerWorkspaceConfig(layerCount);
    }

    /** Builds an OVER_TIME-learning workspace configuration (2% overallocation) with the given warm-up cycle count. */
    private static WorkspaceConfiguration overTimeLayerWorkspaceConfig(int cyclesBeforeInit) {
        return WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.02)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(cyclesBeforeInit)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
    }

    /**
     * Applies a cache mode to every layer of the network.
     *
     * @param mode the cache mode; {@code null} is treated as {@link CacheMode#NONE}
     */
    @Override
    public void setCacheMode(CacheMode mode) {
        CacheMode effective = (mode != null) ? mode : CacheMode.NONE;
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setCacheMode(effective);
        }
    }

    /**
     * Records the time of the most recent ETL (data loading) step for the calling thread.
     *
     * @param time the ETL time to record
     */
    public void setLastEtlTime(long time) {
        if (this.lastEtlTime == null) {
            // lastEtlTime is transient: after Java deserialization the field initializer has not run.
            // NOTE: this lazy creation is not synchronized; a value recorded concurrently on first use
            // after deserialization may be lost.
            this.lastEtlTime = new ThreadLocal<Long>();
        }
        this.lastEtlTime.set(time);
    }

    /**
     * Returns the ETL time last recorded by the calling thread.
     *
     * @return the recorded time, or 0 if nothing was recorded on this thread (or the instance was
     *         deserialized and the transient thread-local has not been recreated yet)
     */
    public long getLastEtlTime() {
        if (this.lastEtlTime == null) {
            return 0L;
        }
        Long time = this.lastEtlTime.get();
        return time == null ? 0L : time;
    }

    /**
     * Creates a network from a JSON configuration, initializes it, and then assigns the given parameters.
     *
     * @param conf   the network configuration as JSON
     * @param params the parameters to set after initialization
     */
    public MultiLayerNetwork(String conf, INDArray params) {
        this(MultiLayerConfiguration.fromJson(conf));
        init();
        setParameters(params);
    }

    /**
     * Creates a network from a configuration, initializes it, and then assigns the given parameters.
     *
     * @param conf   the network configuration
     * @param params the parameters to set after initialization
     */
    public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) {
        this(conf);
        init();
        setParameters(params);
    }

    /**
     * Fills in any missing configuration state with defaults. This covers an empty layer-wise
     * configuration, a layer array sized from {@code getnLayers()}, and a default network configuration.
     * The layer-wise configuration is resolved first because the layer count may depend on it.
     */
    protected void intializeConfigurations() {
        if (layerWiseConfigurations == null) {
            layerWiseConfigurations = new MultiLayerConfiguration.Builder().build();
        }
        if (layers == null) {
            int layerCount = getnLayers();
            layers = new org.deeplearning4j.nn.api.Layer[layerCount];
        }
        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
    }

    /**
     * Performs layerwise unsupervised pretraining of every layer in order, using the given iterator.
     * This does nothing when pretraining is disabled in the configuration. The gradient view is
     * created first if it does not exist yet.
     *
     * @param iter training data; passed to {@link #pretrainLayer(int, DataSetIterator)} for each layer
     */
    public void pretrain(DataSetIterator iter) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (layerWiseConfigurations.isPretrain()) {
            for (int layer = 0; layer < getnLayers(); layer++) {
                pretrainLayer(layer, iter);
            }
        }
    }

    /**
     * Pretrains a single layer for one pass over the iterator, then increments that layer's epoch count.
     * Returns early if pretraining is disabled or the layer is not a pretrainable layer. An exhausted
     * iterator is reset first when it supports reset.
     *
     * @param layerIdx index of the layer to pretrain
     * @param iter     source of feature batches
     * @throws IllegalArgumentException if {@code layerIdx} is not less than the number of layers
     */
    public void pretrainLayer(int layerIdx, DataSetIterator iter) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (!layerWiseConfigurations.isPretrain()) {
            return;
        }
        if (layerIdx >= layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")");
        }
        org.deeplearning4j.nn.api.Layer target = layers[layerIdx];
        if (!target.isPretrainLayer()) {
            return;
        }
        boolean exhausted = !iter.hasNext();
        if (exhausted && iter.resetSupported()) {
            iter.reset();
        }
        log.info("Starting unsupervised training on layer " + layerIdx);
        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();
            INDArray features = batch.getFeatureMatrix();
            this.input = features;
            pretrainLayer(layerIdx, features);
        }
        NeuralNetConfiguration layerConf = getLayer(layerIdx).conf();
        layerConf.setEpochCount(layerConf.getEpochCount() + 1);
    }

    /**
     * Pretrains a single layer on one batch of features.
     *
     * <p>The features are first run forward, in inference mode, through layers {@code 0..layerIdx-1}.
     * Any input preprocessor for {@code layerIdx} is then applied, and the target layer is fit on the result.
     * The layer's pretrain flag is enabled only for the duration of the fit and is always restored,
     * even if the fit throws.
     *
     * @param layerIdx index of the layer to pretrain
     * @param features input features for the whole network
     * @throws IllegalArgumentException if {@code layerIdx} is not less than the number of layers
     */
    public void pretrainLayer(int layerIdx, INDArray features) {
        this.setInput(features);
        this.setLayerMaskArrays(null, null);
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (!this.layerWiseConfigurations.isPretrain()) {
            return;
        }
        if (layerIdx >= this.layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + this.layers.length + ")");
        }
        LayerWorkspaceMgr workspaceMgr = this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                ? LayerWorkspaceMgr.noWorkspaces()
                : LayerWorkspaceMgr.builder()
                        .defaultWorkspace(WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                        .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                        .build();
        org.deeplearning4j.nn.api.Layer layer = this.layers[layerIdx];
        if (!layer.isPretrainLayer()) {
            return;
        }
        layer.conf().setPretrain(true);
        try {
            // Feed forward to the layer *preceding* layerIdx. The previous code used the unrelated
            // field this.layerIndex here, so layers > 0 were trained on the wrong activations.
            INDArray outputOfPrevLayer = layerIdx == 0
                    ? this.input
                    : this.outputOfLayerDetached(false, FwdPassType.STANDARD, layerIdx - 1, features, null, null);
            try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                if (this.layerWiseConfigurations.getInputPreProcess(layerIdx) != null) {
                    outputOfPrevLayer = this.layerWiseConfigurations.getInputPreProcess(layerIdx)
                            .preProcess(outputOfPrevLayer, this.input.size(0), LayerWorkspaceMgr.noWorkspaces());
                }
                layer.fit(outputOfPrevLayer, workspaceMgr);
            }
        } finally {
            // Always restore the flag so a failed fit does not leave the layer in pretrain mode.
            layer.conf().setPretrain(false);
        }
    }

    /**
     * @return the size of dimension 0 of the current input (the minibatch size)
     */
    @Override
    public int batchSize() {
        INDArray features = this.input;
        return features.size(0);
    }

    /**
     * @return the network-level (default) configuration
     */
    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    /**
     * Not supported: the network configuration is fixed at construction time.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the current network input, possibly {@code null}
     */
    @Override
    public INDArray input() {
        return input;
    }

    /**
     * No-op: this implementation performs no input validation.
     */
    @Override
    public void validateInput() {
        // intentionally empty
    }

    /**
     * @return the optimizer held by this network's solver; the solver must already exist (it is created in {@code init})
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        return solver.getOptimizer();
    }

    /**
     * Looks up a single parameter by its network-level key of the form {@code <layerIndex>_<paramKey>}.
     *
     * @param param the prefixed parameter key, e.g. {@code "0_W"}
     * @return the parameter array of the addressed layer
     * @throws IllegalStateException if the key contains no '_' separator
     * @throws NumberFormatException if the prefix before the first '_' is not an integer
     */
    @Override
    public INDArray getParam(String param) {
        int separator = param.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + param + "\"");
        }
        int layerIdx = Integer.parseInt(param.substring(0, separator));
        String layerParamKey = param.substring(separator + 1);
        org.deeplearning4j.nn.api.Layer owner = this.layers[layerIdx];
        return owner.getParam(layerParamKey);
    }

    /**
     * Not supported at the network level; parameters are initialized by {@code init}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return all parameters (not only backprop parameters), keyed as {@code <layerIndex>_<paramKey>}
     */
    @Override
    public Map<String, INDArray> paramTable() {
        final boolean backpropOnly = false;
        return this.paramTable(backpropOnly);
    }

    /**
     * Collects the parameter tables of all layers into one ordered map.
     *
     * @param backpropParamsOnly passed to each layer's {@code paramTable}
     * @return map keyed {@code <layerIndex>_<paramKey>}, in layer order then per-layer order
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> merged = new LinkedHashMap<String, INDArray>();
        int layerNumber = 0;
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            String prefix = layerNumber + "_";
            for (Map.Entry<String, INDArray> e : layer.paramTable(backpropParamsOnly).entrySet()) {
                merged.put(prefix + e.getKey(), e.getValue());
            }
            layerNumber++;
        }
        return merged;
    }

    /**
     * Copies values from {@code paramTable} into this network's existing parameter arrays.
     * All keys and shapes are validated before any value is assigned, so a mismatch leaves the network unchanged.
     *
     * @param paramTable map keyed {@code <layerIndex>_<paramKey>}; must have exactly this network's keys
     * @throws IllegalArgumentException if the key sets differ or any parameter's shape differs
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        Map<String, INDArray> current = this.paramTable();
        if (!current.keySet().equals(paramTable.keySet())) {
            throw new IllegalArgumentException("Cannot set param table: parameter keys do not match.\nCurrent: " + current.keySet() + "\nTo set: " + paramTable.keySet());
        }
        // Validation pass: nothing is written until every shape has been checked.
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            INDArray existing = current.get(e.getKey());
            INDArray replacement = e.getValue();
            if (!Arrays.equals(existing.shape(), replacement.shape())) {
                throw new IllegalArgumentException("Cannot set parameter table: parameter \"" + e.getKey() + "\" shapes do not match. Current = " + Arrays.toString(existing.shape()) + ", to set = " + Arrays.toString(replacement.shape()));
            }
        }
        // Assignment pass: copy values in place.
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            current.get(e.getKey()).assign(e.getValue());
        }
    }

    /**
     * Sets a single parameter addressed by its network-level key of the form {@code <layerIndex>_<paramKey>}.
     *
     * @param key the prefixed parameter key, e.g. {@code "1_b"}
     * @param val the value passed to the owning layer's {@code setParam}
     * @throws IllegalStateException if the key contains no '_' separator
     * @throws NumberFormatException if the prefix before the first '_' is not an integer
     */
    @Override
    public void setParam(String key, INDArray val) {
        int separator = key.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        }
        int layerIdx = Integer.parseInt(key.substring(0, separator));
        String layerParamKey = key.substring(separator + 1);
        org.deeplearning4j.nn.api.Layer owner = this.layers[layerIdx];
        owner.setParam(layerParamKey, val);
    }

    /**
     * @return the per-layer configuration of this network
     */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /**
     * Replaces the per-layer configuration. Existing layers are not rebuilt by this call.
     *
     * @param layerWiseConfigurations the new configuration
     */
    public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) {
        MultiLayerConfiguration replacement = layerWiseConfigurations;
        this.layerWiseConfigurations = replacement;
    }

    /**
     * Initializes the network with freshly created parameters (no parameter array supplied, no cloning).
     */
    @Override
    public void init() {
        final INDArray noParams = null;
        this.init(noParams, false);
    }

    /**
     * Initializes the network: instantiates layers, wires each layer to a view of one flattened
     * parameter row vector, rebuilds the variable list, and creates the solver if absent.
     *
     * <p>Layer creation only happens when {@code layers} is null or its first entry is null; returns
     * immediately if {@code initCalled} is already set.
     *
     * @param parameters           optional pre-existing parameters; must be a row vector (or scalar) whose
     *                             length equals the total parameter count. If {@code null}, a new
     *                             [1, paramLength] array is created and parameters are initialized.
     * @param cloneParametersArray if true, {@code parameters} is duplicated rather than used directly
     * @throws IllegalStateException    if the configuration has fewer than one layer, or a layer
     *                                  instantiates to {@code null}
     * @throws IllegalArgumentException if {@code parameters} is not a row vector or has the wrong length
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        if (this.layerWiseConfigurations == null || this.layers == null) {
            this.intializeConfigurations();
        }
        if (this.initCalled) {
            return;
        }
        // Missing workspace/cache modes default to NONE.
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == null) {
            this.layerWiseConfigurations.setTrainingWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.layerWiseConfigurations.getInferenceWorkspaceMode() == null) {
            this.layerWiseConfigurations.setInferenceWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.layerWiseConfigurations.getCacheMode() == null) {
            this.layerWiseConfigurations.setCacheMode(CacheMode.NONE);
        }
        OneTimeLogger.info((Logger)log, (String)"Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", (Object[])new Object[]{this.layerWiseConfigurations.getTrainingWorkspaceMode(), this.layerWiseConfigurations.getInferenceWorkspaceMode(), this.layerWiseConfigurations.getCacheMode()});
        int nLayers = this.getnLayers();
        if (nLayers < 1) {
            throw new IllegalStateException("Unable to create network: number of layers is less than 1");
        }
        if (this.layers == null || this.layers[0] == null) {
            boolean initializeParams;
            if (this.layers == null) {
                this.layers = new org.deeplearning4j.nn.api.Layer[nLayers];
            }
            // Total parameter count, per layer as reported by each layer's initializer.
            int paramLength = 0;
            int[] nParamsPerLayer = new int[nLayers];
            for (int i = 0; i < nLayers; ++i) {
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(i);
                nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
                paramLength += nParamsPerLayer[i];
            }
            // Choose the backing parameter array: caller-supplied (validated), newly allocated, or none.
            if (parameters != null) {
                if (!parameters.isRowVectorOrScalar()) {
                    throw new IllegalArgumentException("Invalid parameters: should be a row vector");
                }
                if (parameters.length() != paramLength) {
                    throw new IllegalArgumentException("Invalid parameters: expected length " + paramLength + ", got length " + parameters.length());
                }
                this.flattenedParams = cloneParametersArray ? parameters.dup() : parameters;
                initializeParams = false;
            } else if (paramLength > 0) {
                this.flattenedParams = Nd4j.create((int)1, (int)paramLength);
                initializeParams = true;
            } else {
                this.flattenedParams = null;
                initializeParams = false;
            }
            // Seed the global RNG from the default configuration only when parameters will be initialized here.
            if (initializeParams) {
                Nd4j.getRandom().setSeed(this.getDefaultConfiguration().getSeed());
            }
            // Instantiate each layer with a contiguous view [paramCountSoFar, paramCountSoFar + n) of flattenedParams.
            int paramCountSoFar = 0;
            for (int i = 0; i < nLayers; ++i) {
                INDArray paramsView = nParamsPerLayer[i] > 0 ? this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)paramCountSoFar, (int)(paramCountSoFar + nParamsPerLayer[i]))}) : null;
                paramCountSoFar += nParamsPerLayer[i];
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(i);
                this.layers[i] = conf.getLayer().instantiate(conf, this.trainingListeners, i, paramsView, initializeParams);
                this.layerMap.put(conf.getLayer().getLayerName(), this.layers[i]);
            }
            // Only set when the layers were (re)created in this call.
            this.initCalled = true;
        }
        // Rebuild the network-level variable list as "<layerIndex>_<variable>" for each layer variable.
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (int i = 0; i < this.layers.length; ++i) {
            if (this.layers[i] == null) {
                throw new IllegalStateException("Encountered null layer during initialization for layer " + i + ": " + this.layerWiseConfigurations.getConf(i).getLayer().getClass().getSimpleName() + " initialization returned null layer?");
            }
            for (String s : this.layers[i].conf().variables()) {
                variables.add(i + "_" + s);
            }
        }
        // The solver/optimizer is created outside of any open workspace.
        if (this.solver == null) {
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }
        this.synchronizeIterEpochCounts();
    }

    /**
     * Installs a gradients accumulator on this network's optimizer, initializing the network first if needed.
     *
     * @param accumulator the accumulator passed to the optimizer
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        if (!isInitCalled()) {
            init();
        }
        ConvexOptimizer optimizer = solver.getOptimizer();
        optimizer.setGradientsAccumulator(accumulator);
    }

    /**
     * @return whether {@code init} has already instantiated this network's layers
     */
    public boolean isInitCalled() {
        return initCalled;
    }

    /**
     * Allocates one zero-filled, 'f'-ordered [1, totalParams] gradient row vector. Each layer that has
     * parameters gets a contiguous view of it as its backprop gradient array.
     * Allocation happens outside of any open workspace, and the network is initialized first if it has
     * no layers. When the network has no parameters, nothing is allocated.
     */
    public void initGradientsView() {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (layers == null) {
                init();
            }
            int layerCount = layers.length;
            int[] paramsPerLayer = new int[layerCount];
            int totalParams = 0;
            for (int idx = 0; idx < layerCount; idx++) {
                NeuralNetConfiguration layerConf = layerWiseConfigurations.getConf(idx);
                int count = layerConf.getLayer().initializer().numParams(layerConf);
                paramsPerLayer[idx] = count;
                totalParams += count;
            }
            if (totalParams > 0) {
                flattenedGradients = Nd4j.zeros(new int[] {1, totalParams}, 'f');
            }
            int offset = 0;
            for (int idx = 0; idx < layerCount; idx++) {
                int count = paramsPerLayer[idx];
                if (count != 0) {
                    INDArray view = flattenedGradients.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)});
                    layers[idx].setBackpropGradientsViewArray(view);
                    offset += count;
                }
            }
        }
    }

    /**
     * Runs inference on the given input.
     *
     * @deprecated delegates to {@code output(INDArray)}; call that directly
     */
    @Deprecated
    public INDArray activate(INDArray input) {
        INDArray result = output(input);
        return result;
    }

    /**
     * Computes the activations of layer {@code curr} from the previous layer's output. The layer's
     * input preprocessor, if one is configured, is applied first.
     *
     * @param curr     index of the layer to activate
     * @param input    output of the previous layer (or the network input for layer 0)
     * @param training whether to activate in training mode
     * @param mgr      workspace manager passed to the preprocessor and the layer
     * @return the layer's activations
     */
    public INDArray activationFromPrevLayer(int curr, INDArray input, boolean training, LayerWorkspaceMgr mgr) {
        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(curr);
        INDArray layerInput = (preProcessor == null)
                ? input
                : preProcessor.preProcess(input, getInputMiniBatchSize(), mgr);
        return layers[curr].activate(layerInput, training, mgr);
    }

    /**
     * Runs inference through layers {@code from..to} (inclusive) without workspaces.
     *
     * @param from  first layer to activate; must satisfy {@code 0 <= from < numLayers} and {@code from < to}
     * @param to    last layer to activate; must satisfy {@code 1 <= to < numLayers}
     * @param input input to layer {@code from}
     * @return activations of layer {@code to}
     * @throws IllegalStateException if the input is null or either bound is invalid
     */
    public INDArray activateSelectedLayers(int from, int to, INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform activation; no input found");
        }
        boolean fromValid = from >= 0 && from < this.layers.length && from < to;
        if (!fromValid) {
            throw new IllegalStateException("Unable to perform activation; FROM is out of layer space");
        }
        boolean toValid = to >= 1 && to < this.layers.length;
        if (!toValid) {
            throw new IllegalStateException("Unable to perform activation; TO is out of layer space");
        }
        LayerWorkspaceMgr noWorkspaces = LayerWorkspaceMgr.noWorkspaces();
        INDArray current = input;
        int layerIdx = from;
        while (layerIdx <= to) {
            current = this.activationFromPrevLayer(layerIdx, current, false, noWorkspaces);
            layerIdx++;
        }
        return current;
    }

    /**
     * Sets the network input, then feeds it forward through all layers.
     *
     * @param input the network input
     * @param train whether to run in training mode
     * @return the input followed by each layer's activations
     */
    public List<INDArray> feedForward(INDArray input, boolean train) {
        setInput(input);
        return feedForward(train);
    }

    /**
     * Feeds the current input (with the current feature mask) through all layers and clears layer inputs afterwards.
     *
     * @param train whether to run in training mode
     * @return the input followed by each layer's activations
     */
    public List<INDArray> feedForward(boolean train) {
        int lastLayer = this.layers.length - 1;
        return this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, lastLayer, this.input, this.mask, null, true);
    }

    /**
     * Feeds the current input (with the current feature mask) through all layers.
     *
     * @param train       whether to run in training mode
     * @param clearInputs whether each layer is cleared after its activation is computed
     * @return the input followed by each layer's activations
     */
    public List<INDArray> feedForward(boolean train, boolean clearInputs) {
        int lastLayer = this.layers.length - 1;
        return this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, lastLayer, this.input, this.mask, null, clearInputs);
    }

    /**
     * Runs inference from the given input up to and including layer {@code layerNum}.
     *
     * @param layerNum index of the last layer to activate
     * @param input    the network input
     * @return the input followed by the activations of layers {@code 0..layerNum}
     */
    public List<INDArray> feedForwardToLayer(int layerNum, INDArray input) {
        final boolean inference = false;
        return this.ffToLayerActivationsDetached(inference, FwdPassType.STANDARD, false, layerNum, input, this.mask, null, true);
    }

    /**
     * Feeds the given input forward up to the layer at position {@code layerNum}. The stop index used
     * is that layer's own {@code getIndex()}.
     *
     * @param layerNum position of the target layer in the layer array
     * @param input    the network input
     * @param train    whether to run in training mode
     * @return the input followed by the computed activations
     */
    public List<INDArray> feedForwardToLayer(int layerNum, INDArray input, boolean train) {
        org.deeplearning4j.nn.api.Layer target = this.layers[layerNum];
        int stopIdx = target.getIndex();
        return this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, stopIdx, input, this.mask, null, true);
    }

    /**
     * Feeds the current input forward up to and including layer {@code layerNum}.
     *
     * @param layerNum index of the last layer to activate
     * @param train    whether to run in training mode
     * @return the input followed by the activations of layers {@code 0..layerNum}
     */
    public List<INDArray> feedForwardToLayer(int layerNum, boolean train) {
        INDArray currentInput = this.input;
        return this.ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerNum, currentInput, this.mask, null, true);
    }

    /**
     * Checks that {@code array} lives in the workspace the manager expects for {@code arrayType}.
     * On failure, rethrows with a message identifying the offending layer or preprocessor.
     *
     * @param mgr            workspace manager performing the check
     * @param array          array to validate
     * @param arrayType      expected array type
     * @param layerIdx       index of the layer (or of the layer whose preprocessor produced the array)
     * @param isPreprocessor whether the array came from the input preprocessor of {@code layerIdx}
     * @param op             operation name used as the message prefix
     * @throws IllegalStateException wrapping the {@link ND4JWorkspaceException} when validation fails
     */
    protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, int layerIdx, boolean isPreprocessor, String op) {
        try {
            // NOTE(review): the last argument is true for every layer except layer 0; presumably a detached
            // array is only acceptable for the first layer. Confirm against LayerWorkspaceMgr.validateArrayLocation.
            mgr.validateArrayLocation(arrayType, array, false, layerIdx > 0);
        } catch (ND4JWorkspaceException e) {
            String layerName = this.layers[layerIdx].conf().getLayer().getLayerName();
            String clazz = isPreprocessor
                    ? this.layerWiseConfigurations.getInputPreProcess(layerIdx).getClass().getName()
                    : this.layers[layerIdx].getClass().getName();
            // Both labels end in a space so the index is separated: "preprocessor 3" / "layer 3".
            throw new IllegalStateException(op + ": array (" + arrayType + ") workspace validation failed ("
                    + (isPreprocessor ? "preprocessor " : "layer ") + layerIdx
                    + (layerName != null ? " - layer name \"" + layerName + "\"" : "")
                    + " - class: " + clazz + ") - array is defined in incorrect workspace", e);
        }
    }

    /**
     * Runs a forward pass from layer 0 up to and including {@code layerIndex}. Returns the input
     * followed by each layer's activations. Activations are not placed in any workspace, so the returned
     * arrays are detached. Requires that no workspace is open on entry.
     *
     * @param train             whether layers run in training mode; also selects the training or inference workspace mode
     * @param fwdPassType       STANDARD, or RNN_ACTIVATE_WITH_STORED_STATE for stateful recurrent activation
     * @param storeLastForTBPTT passed to recurrent layers when using stored state
     * @param layerIndex        index of the last layer to activate
     * @param input             network input
     * @param fMask             feature mask (may be null)
     * @param lMask             label mask (may be null)
     * @param clearInputs       whether to clear each layer after activating it; when false, INPUT arrays are scoped out of workspaces
     * @return list of size {@code layerIndex + 2}: the input, then activations of layers {@code 0..layerIndex}
     * @throws IllegalStateException if {@code fwdPassType} is unsupported or a workspace check fails
     */
    protected List<INDArray> ffToLayerActivationsDetached(boolean train, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, int layerIndex, @NonNull INDArray input, INDArray fMask, INDArray lMask, boolean clearInputs) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType");
        }
        if (input == null) {
            throw new NullPointerException("input");
        }
        this.setInput(input);
        this.setLayerMaskArrays(fMask, lMask);
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsDetached");
        WorkspaceMode wsm = train ? this.layerWiseConfigurations.getTrainingWorkspaceMode() : this.layerWiseConfigurations.getInferenceWorkspaceMode();
        LayerWorkspaceMgr workspaceMgr;
        if (wsm == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            // ACTIVATIONS are kept out of workspaces so the returned list outlives this call.
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (input.isAttached()) {
                // Do not leverage the input out of the workspace it already lives in.
                workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
            }
            if (!clearInputs) {
                // Layers keep their inputs, so INPUT arrays must not live in the reused working-memory workspace.
                workspaceMgr.setScopedOutFor(ArrayType.INPUT);
            }
        }
        List<INDArray> out = new ArrayList<INDArray>();
        out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input));
        for (int i = 0; i <= layerIndex; ++i) {
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                if (this.getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                    input = this.getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, this.getInputMiniBatchSize(), workspaceMgr);
                    this.validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, "Feed forward to layer (inference)");
                }
                if (fwdPassType == FwdPassType.STANDARD) {
                    input = this.layers[i].activate(input, train, workspaceMgr);
                } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                    if (this.layers[i] instanceof RecurrentLayer) {
                        input = ((RecurrentLayer) this.layers[i]).rnnActivateUsingStoredState(input, train, storeLastForTBPTT, workspaceMgr);
                    } else if (this.layers[i] instanceof MultiLayerNetwork) {
                        // A nested network returns all of its activations; keep only its final output.
                        List<INDArray> temp = ((MultiLayerNetwork) this.layers[i]).rnnActivateUsingStoredState(input, train, storeLastForTBPTT);
                        input = temp.get(temp.size() - 1);
                    } else {
                        input = this.layers[i].activate(input, train, workspaceMgr);
                    }
                } else {
                    throw new IllegalStateException("Forward pass type not supported for this method: " + fwdPassType);
                }
                this.validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (inference)");
                out.add(input);
            }
            if (clearInputs) {
                this.layers[i].clear();
            }
        }
        return out;
    }

    /**
     * Runs a training-mode forward pass from layer 0 up to and including {@code layerIndex}. The input
     * and activations are kept in the {@code WS_ALL_LAYERS_ACT} workspace when training workspaces are enabled.
     *
     * <p>When the training workspace mode is NONE, no workspace may be open on entry. Otherwise
     * {@code WS_ALL_LAYERS_ACT} must already be open and active; the caller owns its lifetime.
     * Every layer is activated with {@code training = true}.
     *
     * @param layerIndex        index of the last layer to activate
     * @param fwdPassType       STANDARD, or RNN_ACTIVATE_WITH_STORED_STATE for stateful recurrent activation
     * @param storeLastForTBPTT passed to recurrent layers when using stored state
     * @param input             network input
     * @param fMask             feature mask (may be null)
     * @param lMask             label mask (may be null)
     * @return the input followed by the activations of layers {@code 0..layerIndex}
     * @throws IllegalStateException if a layer returns null, the pass type is unsupported, or a workspace check fails
     */
    protected List<INDArray> ffToLayerActivationsInWs(int layerIndex, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, @NonNull INDArray input, INDArray fMask, INDArray lMask) {
        LayerWorkspaceMgr workspaceMgr;
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType");
        }
        if (input == null) {
            throw new NullPointerException("input");
        }
        this.setInput(input);
        this.setLayerMaskArrays(fMask, lMask);
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE");
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            // INPUT and ACTIVATIONS share WS_ALL_LAYERS_ACT; per-layer scratch goes to WS_LAYER_WORKING_MEM.
            workspaceMgr = LayerWorkspaceMgr.builder().with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            if (input.isAttached()) {
                workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
            }
            // With caching enabled, FF_CACHE arrays go to the all-layers workspace and BP working memory to the layer workspace.
            if (this.layerWiseConfigurations.getCacheMode() != CacheMode.NONE) {
                workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
                workspaceMgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG);
            }
            WorkspaceUtils.assertOpenAndActive((String)WS_ALL_LAYERS_ACT, (String)"ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");
        }
        ArrayList<INDArray> out = new ArrayList<INDArray>();
        out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input));
        for (int i = 0; i <= layerIndex; ++i) {
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM);){
                if (this.getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                    input = this.getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, this.getInputMiniBatchSize(), workspaceMgr);
                    this.validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, "Feed forward to layer (training)");
                }
                if (fwdPassType == FwdPassType.STANDARD) {
                    input = this.layers[i].activate(input, true, workspaceMgr);
                } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                    if (this.layers[i] instanceof RecurrentLayer) {
                        input = ((RecurrentLayer)this.layers[i]).rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr);
                    } else if (this.layers[i] instanceof MultiLayerNetwork) {
                        // A nested network returns all of its activations; keep only its final output.
                        List<INDArray> temp = ((MultiLayerNetwork)this.layers[i]).rnnActivateUsingStoredState(input, true, storeLastForTBPTT);
                        input = temp.get(temp.size() - 1);
                    } else {
                        input = this.layers[i].activate(input, true, workspaceMgr);
                    }
                } else {
                    throw new IllegalStateException("FwdPassType not supported for this method: " + (Object)((Object)fwdPassType));
                }
                if (input == null) {
                    throw new IllegalStateException("Layer " + i + " returned null activations");
                }
                // Both the activations and the input retained by the layer must be in the expected workspaces.
                this.validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (training)");
                this.validateArrayWorkspaces(workspaceMgr, this.layers[i].input(), ArrayType.INPUT, i, false, "Feed forward to layer (training)");
                out.add(input);
                continue;
            }
        }
        return out;
    }

    /**
     * Runs a forward pass through layers {@code 0..layerIndex} (inclusive) and returns the output of layer
     * {@code layerIndex}.
     * <p>
     * Activations of consecutive layers alternate between the WS_LAYER_ACT_1 and WS_LAYER_ACT_2 workspaces (even
     * layers write into one, odd layers into the other), so a layer's input stays valid while it writes its output.
     * For the final layer the activations are scoped out of workspaces ({@code setScopedOutFor}), and all
     * workspaces opened here are closed before returning.
     * <p>
     * Decompiler note (CFR): "Removed try catching itself - possible behaviour change".
     *
     * @param train       whether layers run in training mode; also selects the training vs inference workspace mode
     * @param fwdPassType {@link FwdPassType#STANDARD} or {@link FwdPassType#RNN_TIMESTEP}
     * @param layerIndex  index of the last layer to run (inclusive)
     * @param input       network input; must not be null
     * @param featureMask features mask array, may be null
     * @param labelsMask  labels mask array, may be null
     * @return the activations of layer {@code layerIndex}
     * @throws NullPointerException     if {@code fwdPassType} or {@code input} is null
     * @throws IllegalArgumentException if a layer is reached with an unsupported {@code fwdPassType}
     */
    protected INDArray outputOfLayerDetached(boolean train, @NonNull FwdPassType fwdPassType, int layerIndex, @NonNull INDArray input, INDArray featureMask, INDArray labelsMask) {
        LayerWorkspaceMgr mgrOdd;
        LayerWorkspaceMgr mgrEven;
        WorkspaceMode wsm;
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType");
        }
        if (input == null) {
            throw new NullPointerException("input");
        }
        this.setInput(input);
        this.setLayerMaskArrays(featureMask, labelsMask);
        // This method manages its own workspaces and must be entered with none open
        WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active in outputOfLayerDetached");
        WorkspaceMode workspaceMode = wsm = train ? this.layerWiseConfigurations.getTrainingWorkspaceMode() : this.layerWiseConfigurations.getInferenceWorkspaceMode();
        if (wsm == WorkspaceMode.NONE) {
            mgrOdd = mgrEven = LayerWorkspaceMgr.noWorkspaces();
        } else {
            // Even layers: output -> WS_LAYER_ACT_1, input from WS_LAYER_ACT_2; odd layers use the opposite pairing
            mgrEven = LayerWorkspaceMgr.builder().with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.INPUT, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            mgrOdd = LayerWorkspaceMgr.builder().with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.INPUT, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
        }
        // Activations workspace of the previous layer: closed only after the current layer has consumed its output
        MemoryWorkspace wsActCloseNext = null;
        // Activations workspace of the current layer, until it is handed over to wsActCloseNext
        MemoryWorkspace temp = null;
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        try {
            for (int i = 0; i <= layerIndex; ++i) {
                LayerWorkspaceMgr mgr;
                LayerWorkspaceMgr layerWorkspaceMgr = mgr = i % 2 == 0 ? mgrEven : mgrOdd;
                if (i == 0 && wsm != WorkspaceMode.NONE) {
                    // The network input does not live in WS_LAYER_ACT_2; use the working-mem workspace for layer 0's INPUT
                    // (restored after the first iteration, below)
                    mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG);
                }
                try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM);){
                    temp = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                    temp.setPreviousWorkspace(initialWorkspace);
                    if (i == 0 && input.isAttached()) {
                        // Caller's input is attached to its own workspace: do not leverage it into ours
                        mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
                    }
                    if (this.getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                        input = this.getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, this.getInputMiniBatchSize(), mgr);
                        this.validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
                    }
                    if (i == layerIndex) {
                        // Final layer: allocate its activations outside of workspaces so the result outlives this method
                        mgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                    }
                    if (fwdPassType == FwdPassType.STANDARD) {
                        input = this.layers[i].activate(input, train, mgr);
                    } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                        // Recurrent layers and nested networks step their stored state; other layers run a plain inference activation
                        input = this.layers[i] instanceof RecurrentLayer ? ((RecurrentLayer)this.layers[i]).rnnTimeStep(input, mgr) : (this.layers[i] instanceof MultiLayerNetwork ? ((MultiLayerNetwork)this.layers[i]).rnnTimeStep(input) : this.layers[i].activate(input, false, mgr));
                    } else {
                        throw new IllegalArgumentException("Unsupported forward pass type for this method: " + (Object)((Object)fwdPassType));
                    }
                    // NOTE(review): presumably releases the layer's cached input/state - confirm against Layer.clear()
                    this.layers[i].clear();
                    this.validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");
                    // The previous layer's output has now been consumed, so its workspace can be closed
                    if (wsActCloseNext != null) {
                        wsActCloseNext.close();
                    }
                    wsActCloseNext = temp;
                    temp = null;
                }
                if (i != 0 || wsm == WorkspaceMode.NONE) continue;
                // Restore mgrEven's INPUT workspace to the value it was built with
                mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG);
            }
        }
        finally {
            if (wsActCloseNext != null) {
                wsActCloseNext.close();
            }
            // temp is only non-null if an exception interrupted the current layer; close all its open scopes
            if (temp != null) {
                while (temp.isScopeActive()) {
                    temp.close();
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
            WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active at the end of outputOfLayerDetached");
        }
        return input;
    }

    /**
     * Feeds the currently set input forward through the network, delegating to
     * {@code feedForward(boolean)} with the training flag set to {@code false}.
     *
     * @return the list of activations produced by {@code feedForward(false)}
     */
    public List<INDArray> feedForward() {
        final boolean trainingMode = false;
        return this.feedForward(trainingMode);
    }

    /**
     * Sets {@code input} as the network input and feeds it forward (non-training flag).
     *
     * @param input the network input; must not be null
     * @return the activations returned by {@link #feedForward()}
     * @throws IllegalStateException if {@code input} is null
     */
    public List<INDArray> feedForward(INDArray input) {
        if (input != null) {
            this.setInput(input);
            return this.feedForward();
        }
        throw new IllegalStateException("Unable to perform feed forward; no input found");
    }

    /**
     * Feeds {@code input} forward with the given mask arrays applied, then clears the masks.
     * <p>
     * The masks are cleared in a {@code finally} block so that a failing forward pass does not leave them set
     * on the layers, where they would otherwise silently affect later, unrelated calls.
     *
     * @param input        the network input; must not be null
     * @param featuresMask features mask array, may be null
     * @param labelsMask   labels mask array, may be null
     * @return the activations returned by {@link #feedForward(INDArray)}
     * @throws IllegalStateException if {@code input} is null
     */
    public List<INDArray> feedForward(INDArray input, INDArray featuresMask, INDArray labelsMask) {
        this.setLayerMaskArrays(featuresMask, labelsMask);
        try {
            return this.feedForward(input);
        } finally {
            this.clearLayerMaskArrays();
        }
    }

    /**
     * Returns the gradient currently stored on this network. No computation is performed here.
     *
     * @return the stored {@code gradient} field
     */
    @Override
    public Gradient gradient() {
        final Gradient stored = this.gradient;
        return stored;
    }

    /**
     * Pairs the stored gradient (from {@link #gradient()}) with the current score (from {@code score()}).
     *
     * @return a new pair of (gradient, score)
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = this.gradient();
        Double currentScore = this.score();
        return new Pair<Gradient, Double>(currentGradient, currentScore);
    }

    /**
     * Creates a copy of this network. The copy uses a cloned configuration and duplicated parameters.
     * <p>
     * If a solver exists and the updater has a state view array, that state is duplicated into the copy's updater.
     * Layers that are {@link FrozenLayer}s in this network (as reported by {@code hasAFrozenLayer()}) are
     * re-wrapped as frozen in the copy.
     *
     * @return the cloned network
     */
    @Override
    public MultiLayerNetwork clone() {
        MultiLayerNetwork copy = new MultiLayerNetwork(this.layerWiseConfigurations.clone());
        copy.init(this.params().dup(), false);
        if (this.solver != null) {
            INDArray stateView = this.getUpdater().getStateViewArray();
            if (stateView != null) {
                copy.getUpdater().setStateViewArray(copy, stateView.dup(), false);
            }
        }
        if (this.hasAFrozenLayer()) {
            org.deeplearning4j.nn.api.Layer[] copiedLayers = copy.getLayers();
            for (int idx = 0; idx < this.layers.length; ++idx) {
                if (this.layers[idx] instanceof FrozenLayer) {
                    copiedLayers[idx] = new FrozenLayer(copy.getLayer(idx));
                }
            }
            copy.setLayers(copiedLayers);
        }
        return copy;
    }

    /**
     * Returns whether any layer of this network is a {@link FrozenLayer}.
     * <p>
     * Every layer is checked, including the last one. This keeps the result consistent with {@link #clone()},
     * which re-wraps frozen layers at any index, so a frozen final layer is not silently unfrozen in a clone.
     *
     * @return {@code true} if at least one layer is a {@link FrozenLayer}
     */
    protected boolean hasAFrozenLayer() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            if (layer instanceof FrozenLayer) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns network parameters.
     *
     * @param backwardOnly if true, returns the flattened parameter array from {@link #params()}; otherwise builds a
     *                     new 'f'-order flattened array from every layer whose {@code params()} is non-null
     * @return the parameters as a flattened array
     */
    public INDArray params(boolean backwardOnly) {
        if (!backwardOnly) {
            List<INDArray> collected = new ArrayList<INDArray>();
            for (org.deeplearning4j.nn.api.Layer l : this.getLayers()) {
                INDArray p = l.params();
                if (p != null) {
                    collected.add(p);
                }
            }
            return Nd4j.toFlattened('f', collected);
        }
        return this.params();
    }

    /**
     * Returns the flattened parameter array held by this network. This is the field itself, not a copy.
     *
     * @return the {@code flattenedParams} field
     */
    @Override
    public INDArray params() {
        final INDArray flat = this.flattenedParams;
        return flat;
    }

    /**
     * Sets the network parameters from a flattened parameter array.
     * <p>
     * If {@code params} is already the stored flattened array, nothing happens. If a flattened array of the same
     * length exists, {@code params} is copied into it in place.
     * <p>
     * Otherwise (no stored array, or a length mismatch), each layer with parameters is given its consecutive slice
     * of {@code params}. Slices are taken from row 0 ({@code point(0)}), so a single-row layout is assumed. When no
     * stored array existed yet, a copy of {@code params} is also kept.
     * <p>
     * NOTE(review): in that second path the layers get views of {@code params}, not of the stored copy. On a length
     * mismatch the existing flattened array is left unchanged. Confirm both behaviours are intended.
     *
     * @param params the new parameters
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.flattenedParams) {
            return;
        }
        boolean copyInPlace = this.flattenedParams != null && params.length() == this.flattenedParams.length();
        if (copyInPlace) {
            // Identity was ruled out above, so an element-wise copy is always required here
            this.flattenedParams.assign(params);
            return;
        }
        if (this.flattenedParams == null) {
            this.flattenedParams = params.dup();
        }
        int offset = 0;
        for (int layerIdx = 0; layerIdx < this.getLayers().length; ++layerIdx) {
            org.deeplearning4j.nn.api.Layer layer = this.getLayer(layerIdx);
            int count = layer.numParams();
            if (count > 0) {
                // Layers without parameters consume no slice
                layer.setParams(params.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)}));
                offset += count;
            }
        }
    }

    /**
     * Not supported for this network type.
     *
     * @param params ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        final String reason = "Not yet implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the flattened gradients view array held by this network. This is the field itself, not a copy.
     *
     * @return the {@code flattenedGradients} field
     */
    @Override
    public INDArray getGradientsViewArray() {
        final INDArray view = this.flattenedGradients;
        return view;
    }

    /**
     * Hands each layer with parameters its consecutive slice (row 0) of the flattened gradients array.
     * Layers with zero parameters are skipped and consume no slice.
     *
     * @param gradients the flattened gradients view
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        int offset = 0;
        for (int li = 0; li < this.layers.length; ++li) {
            org.deeplearning4j.nn.api.Layer layer = this.layers[li];
            int n = layer.numParams();
            if (n == 0) {
                continue;
            }
            INDArray slice = gradients.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + n)});
            layer.setBackpropGradientsViewArray(slice);
            offset += n;
        }
    }

    /**
     * Returns the total number of parameters, or 0 (with an info log) if {@code init()} has not been called.
     *
     * @return the parameter count from {@code numParams(false)}, or 0 when uninitialised
     */
    @Override
    public int numParams() {
        if (!this.isInitCalled()) {
            log.info("Model is not initialized. Initialize net with init()");
            return 0;
        }
        return this.numParams(false);
    }

    /**
     * Sums {@code numParams(backwards)} over all layers.
     *
     * @param backwards flag forwarded to each layer's {@code numParams}
     * @return the summed parameter count
     */
    @Override
    public int numParams(boolean backwards) {
        int total = 0;
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            total += layer.numParams(backwards);
        }
        return total;
    }

    /**
     * Computes the F1 score on a data set by delegating to {@code f1Score(features, labels)}.
     *
     * @param data the data set supplying features and labels
     * @return the F1 score
     */
    @Override
    public double f1Score(DataSet data) {
        INDArray features = data.getFeatures();
        INDArray labels = data.getLabels();
        return this.f1Score(features, labels);
    }

    /**
     * Trains the network for {@code numEpochs} epochs by calling {@link #fit(DataSetIterator)} once per epoch.
     *
     * @param iterator  the training data; must support resetting when {@code numEpochs > 1}
     * @param numEpochs number of passes over the data; must be positive
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException (from {@code Preconditions.checkArgument}) if {@code numEpochs <= 0}, or if
     *                                  several epochs are requested with an iterator that cannot be reset
     */
    public void fit(@NonNull DataSetIterator iterator, int numEpochs) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(),
                "Cannot perform multiple epochs training using an iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < numEpochs; ++epoch) {
            this.fit(iterator);
        }
    }

    /**
     * Trains the network for one epoch over {@code iterator}.
     * <p>
     * When the iterator supports asynchronous prefetching, it is wrapped in an {@link AsyncDataSetIterator}. That
     * wrapper is shut down in a {@code finally} block, so its prefetch thread is released even if training throws.
     * <p>
     * Training listeners receive {@code onEpochStart}/{@code onEpochEnd}. Minibatches with null features or labels
     * end the epoch early. Truncated BPTT configurations are routed to {@code doTruncatedBPTT}. The epoch counter
     * is incremented only after a successful epoch.
     *
     * @param iterator the training data
     */
    @Override
    public void fit(DataSetIterator iterator) {
        DataSetIterator iter;
        boolean destructable = false;
        if (iterator.asyncSupported()) {
            iter = new AsyncDataSetIterator(iterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), this.layerWiseConfigurations.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else {
            iter = iterator;
        }
        try {
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochStart(this);
            }
            LayerWorkspaceMgr workspaceMgr = this.getLayerWiseConfigurations().getTrainingWorkspaceMode() == WorkspaceMode.NONE ? LayerWorkspaceMgr.noWorkspaces() : LayerWorkspaceMgr.builder().with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).build();
            if (this.layerWiseConfigurations.isBackprop()) {
                this.update(TaskUtils.buildTask((DataSetIterator)iter));
                // An already-exhausted iterator is rewound so the epoch sees data
                if (!iter.hasNext() && iter.resetSupported()) {
                    iter.reset();
                }
                long time1 = System.currentTimeMillis();
                while (iter.hasNext()) {
                    org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet)iter.next();
                    long time2 = System.currentTimeMillis();
                    // Time spent waiting on the iterator (ETL), excluding training
                    this.lastEtlTime.set(time2 - time1);
                    if (next.getFeatureMatrix() == null || next.getLabels() == null) break;
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                        this.doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray(), workspaceMgr);
                    } else {
                        if (hasMaskArrays) {
                            this.setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                        }
                        this.setInput(next.getFeatureMatrix());
                        this.setLabels(next.getLabels());
                        if (this.solver == null) {
                            // The solver is long-lived, so it must not be allocated inside any workspace
                            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                                this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                            }
                        }
                        this.solver.optimize(workspaceMgr);
                    }
                    if (hasMaskArrays) {
                        this.clearLayerMaskArrays();
                    }
                    time1 = System.currentTimeMillis();
                }
            }
            for (TrainingListener tl : this.trainingListeners) {
                tl.onEpochEnd(this);
            }
            this.clearLayersStates();
        } finally {
            // Always release the async prefetch thread, even if training failed
            if (destructable) {
                ((AsyncDataSetIterator)iter).shutdown();
            }
        }
        this.incrementEpochCount();
    }

    /**
     * Computes the parameter gradients and the gradient with respect to the input for one minibatch, without
     * applying an update.
     * <p>
     * A forward pass is run up to the layer before the output layer. The (pre-processed) last activation is then
     * set as the output layer's input, and backpropagation runs through all non-frozen layers. The feature and
     * label masks given here are used consistently for both the mask setup and the forward pass.
     *
     * @param features  the network input; must not be null
     * @param label     the labels; must not be null
     * @param fMask     features mask array, may be null
     * @param labelMask labels mask array, may be null
     * @return the gradient and the detached input-activation gradient (second element may be null), or
     *         {@code null} if the final layer is not an output layer
     * @throws NullPointerException if {@code features} or {@code label} is null
     */
    public Pair<Gradient, INDArray> calculateGradients(@NonNull INDArray features, @NonNull INDArray label, INDArray fMask, INDArray labelMask) {
        LayerWorkspaceMgr mgr;
        if (features == null) {
            throw new NullPointerException("features");
        }
        if (label == null) {
            throw new NullPointerException("label");
        }
        this.setInput(features);
        this.setLabels(label);
        this.setLayerMaskArrays(fMask, labelMask);
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder().with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            if (this.layerWiseConfigurations.getCacheMode() != null) {
                mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS);){
            // Pass the same (features mask, labels mask) pair applied above. Passing fMask in the labels-mask slot
            // would overwrite the output layer's labels mask with the features mask.
            // NOTE(review): assumes the last two parameters of ffToLayerActivationsInWs are (fMask, lMask).
            List<INDArray> activations = this.ffToLayerActivationsInWs(this.layers.length - 2, FwdPassType.STANDARD, false, this.input, fMask, labelMask);
            for (TrainingListener tl : this.trainingListeners) {
                tl.onForwardPass((Model)this, activations);
            }
            INDArray inputToOutputLayer = activations.get(activations.size() - 1);
            if (this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1) != null) {
                inputToOutputLayer = this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1).preProcess(inputToOutputLayer, this.getInputMiniBatchSize(), mgr);
            }
            // The output layer itself was not run in the forward pass; its input must be set for backprop
            this.getOutputLayer().setInput(inputToOutputLayer, mgr);
            Pair<Gradient, INDArray> p = this.calcBackpropGradients(null, true, false, true);
            // calcBackpropGradients returns null when the final layer is not an output layer
            if (p != null && p.getSecond() != null) {
                p.setSecond(p.getSecond().detach());
            }
            return p;
        }
    }

    /**
     * Runs backpropagation from the last layer towards the first and assembles the parameter gradients.
     * <p>
     * Layers are visited in reverse order, and the walk stops at the first {@link FrozenLayer} met from the end.
     * Each parameter gradient is registered under the key {@code "<layerIndex>_<paramName>"} in a
     * {@link DefaultGradient} backed by the flattened gradient view, which is initialised on first use.
     * <p>
     * Even and odd layers write activation gradients into different workspaces (WS_LAYER_ACT_1 / WS_LAYER_ACT_2).
     * The previous layer's activation-gradient workspace is closed only after the current layer has consumed it.
     * <p>
     * Decompiler note (CFR): "Removed try catching itself - possible behaviour change".
     *
     * @param epsilon            gradient w.r.t. the last layer's output. It may be {@code null}, in which case the
     *                           output layer presumably derives it from the labels. When workspaces are enabled and
     *                           this is {@code null}, WS_ALL_LAYERS_ACT must be the open, active and current workspace.
     * @param withOutputLayer    if true, the last layer must implement {@link IOutputLayer} and is given the labels
     * @param tbptt              if true, {@link RecurrentLayer}s use truncated BPTT with the configured backward length
     * @param returnInputActGrad if true, the gradient w.r.t. the network input is returned detached; otherwise null
     * @return the gradient and the input-activation gradient (possibly null, e.g. when the last layer is frozen), or
     *         {@code null} when {@code withOutputLayer} is true but the last layer is not an {@link IOutputLayer}
     * @throws IllegalStateException if the output layer needs labels but none are set
     */
    protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray epsilon, boolean withOutputLayer, boolean tbptt, boolean returnInputActGrad) {
        LayerWorkspaceMgr mgrOdd;
        LayerWorkspaceMgr mgrEven;
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        DefaultGradient gradient = new DefaultGradient(this.flattenedGradients);
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            mgrOdd = mgrEven = LayerWorkspaceMgr.noWorkspaces();
            WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active in calcBackpropGradients when training workspace is set to none");
        } else {
            mgrEven = LayerWorkspaceMgr.builder().with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            mgrOdd = LayerWorkspaceMgr.builder().with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
            if (epsilon == null) {
                WorkspaceUtils.assertOpenActiveAndCurrent((String)WS_ALL_LAYERS_ACT, (String)"calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT to be open when workspaces are used");
            }
        }
        LinkedList<Triple> gradientList = new LinkedList<Triple>();
        Pair<Gradient, INDArray> currPair = null;
        // Activation-gradient workspace of the previously processed layer; closed once the current layer is done
        MemoryWorkspace wsActGradCloseNext = null;
        // Activation-gradient workspace of the current layer, until handed over to wsActGradCloseNext
        MemoryWorkspace wsActGradTemp = null;
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        try {
            for (int i = this.layers.length - 1; i >= 0; --i) {
                if (this.layers[i] instanceof FrozenLayer) {
                    // No gradients for a frozen layer or any layer below it
                    break;
                }
                LayerWorkspaceMgr workspaceMgr = i % 2 == 0 ? mgrEven : mgrOdd;
                if (withOutputLayer && i == this.layers.length - 1) {
                    if (!(this.getOutputLayer() instanceof IOutputLayer)) {
                        log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
                        return null;
                    }
                    IOutputLayer outputLayer = (IOutputLayer)this.getOutputLayer();
                    if (this.labels == null && outputLayer.needsLabels()) {
                        throw new IllegalStateException("No labels found");
                    }
                    outputLayer.setLabels(this.labels);
                }
                wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
                try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM);){
                    wsActGradTemp.setPreviousWorkspace(initialWorkspace);
                    wsBPWorking.setPreviousWorkspace(initialWorkspace);
                    // Epsilon for this layer: the supplied epsilon for the last layer, otherwise the gradient from layer i+1.
                    // All branches below use eps: currPair is still null while the last layer is processed.
                    INDArray eps = i == this.layers.length - 1 ? epsilon : (INDArray)currPair.getRight();
                    if (!tbptt) {
                        currPair = this.layers[i].backpropGradient(eps, workspaceMgr);
                    } else if (this.layers[i] instanceof RecurrentLayer) {
                        currPair = ((RecurrentLayer)this.layers[i]).tbpttBackpropGradient(eps, this.layerWiseConfigurations.getTbpttBackLength(), workspaceMgr);
                    } else {
                        currPair = this.layers[i].backpropGradient(eps, workspaceMgr);
                    }
                    if (currPair.getSecond() != null) {
                        // Report the actual layer index in any workspace validation error
                        this.validateArrayWorkspaces(workspaceMgr, (INDArray)currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, false, "Backprop");
                    }
                    for (Map.Entry<String, INDArray> entry : ((Gradient)currPair.getFirst()).gradientForVariable().entrySet()) {
                        String origName = entry.getKey();
                        String multiGradientKey = String.valueOf(i) + "_" + origName;
                        gradientList.addLast(new Triple((Object)multiGradientKey, (Object)entry.getValue(), (Object)((Gradient)currPair.getFirst()).flatteningOrderForVariable(origName)));
                    }
                    if (this.getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                        currPair = new Pair(currPair.getFirst(), (Object)this.layerWiseConfigurations.getInputPreProcess(i).backprop((INDArray)currPair.getSecond(), this.getInputMiniBatchSize(), workspaceMgr));
                        if (i > 0 && currPair.getSecond() != null) {
                            this.validateArrayWorkspaces(workspaceMgr, (INDArray)currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, true, "Backprop");
                        }
                    }
                    if (i == 0) {
                        // Input-activation gradient: detach it from workspaces if requested, otherwise drop it
                        if (returnInputActGrad && currPair.getSecond() != null) {
                            currPair.setSecond(((INDArray)currPair.getSecond()).detach());
                        } else {
                            currPair.setSecond(null);
                        }
                    }
                    if (wsActGradCloseNext != null) {
                        wsActGradCloseNext.close();
                    }
                    wsActGradCloseNext = wsActGradTemp;
                    wsActGradTemp = null;
                }
            }
        }
        finally {
            if (wsActGradCloseNext != null) {
                wsActGradCloseNext.close();
            }
            if (wsActGradTemp != null) {
                wsActGradTemp.close();
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
        }
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            WorkspaceUtils.assertNoWorkspacesOpen((String)"Expected no workspace active in calcBackpropGradients when training workspace is set to none");
        } else if (epsilon == null) {
            WorkspaceUtils.assertOpenActiveAndCurrent((String)WS_ALL_LAYERS_ACT, (String)"calcBackpropGradients: WS_ALL_LAYERS_ACT is no longer the currently open/active workspace");
        }
        for (Triple triple : gradientList) {
            gradient.setGradientFor((String)triple.getFirst(), (INDArray)triple.getSecond(), (Character)triple.getThird());
        }
        // currPair stays null if the very last layer is frozen (the loop breaks before processing any layer)
        INDArray inputActGrad = currPair == null ? null : (INDArray)currPair.getSecond();
        return new Pair((Object)gradient, (Object)inputActGrad);
    }

    /**
     * Trains on one minibatch with truncated backpropagation through time.
     * <p>
     * The sequence is split along dimension 2 into consecutive segments of {@code tbpttFwdLength} steps; the last
     * segment may be shorter. One optimisation step is run per segment, and the TBPTT state is carried over to the
     * next segment. RNN state is cleared before the first and after the last segment, and mask arrays are cleared
     * at the end.
     * <p>
     * If inputs or labels are not rank 3, or their lengths along dimension 2 differ, a warning is logged and
     * nothing is trained.
     *
     * @param input             the features; expected rank 3
     * @param labels            the labels; expected rank 3
     * @param featuresMaskArray features mask, may be null
     * @param labelsMaskArray   labels mask, may be null
     * @param workspaceMgr      workspace manager passed to the solver
     */
    protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
        boolean bothThreeDimensional = input.rank() == 3 && labels.rank() == 3;
        if (!bothThreeDimensional) {
            log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(input.shape()) + "\tand labels with shape " + Arrays.toString(labels.shape()));
            return;
        }
        int inputLength = input.size(2);
        int labelLength = labels.size(2);
        if (inputLength != labelLength) {
            log.warn("Input and label time series have different lengths: {} input length, {} label length", (Object)inputLength, (Object)labelLength);
            return;
        }
        final int fwdLen = this.layerWiseConfigurations.getTbpttFwdLength();
        this.update(TaskUtils.buildTask((INDArray)input, (INDArray)labels));
        final int timeSeriesLength = input.size(2);
        // Ceiling division: a trailing partial segment still gets its own step
        int segmentCount = timeSeriesLength / fwdLen + (timeSeriesLength % fwdLen == 0 ? 0 : 1);
        this.rnnClearPreviousState();
        for (int segment = 0; segment < segmentCount; ++segment) {
            int from = segment * fwdLen;
            int to = Math.min(from + fwdLen, timeSeriesLength);
            INDArray[] slice = this.getSubsetsForTbptt(from, to, input, labels, featuresMaskArray, labelsMaskArray);
            this.setInput(slice[0]);
            this.setLabels(slice[1]);
            this.setLayerMaskArrays(slice[2], slice[3]);
            if (this.solver == null) {
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                }
            }
            this.solver.optimize(workspaceMgr);
            this.updateRnnStateWithTBPTTState();
        }
        this.rnnClearPreviousState();
        this.clearLayerMaskArrays();
    }

    /**
     * Slices features, labels and (optional) masks to the time range {@code [startTimeIdx, endTimeIdx)}.
     * <p>
     * Features and labels are sliced along dimension 2; masks are sliced along dimension 1.
     * NOTE(review): this assumes 3d [mb, size, time] data and 2d [mb, time] masks — confirm with callers.
     *
     * @return {features, labels, featuresMask, labelsMask}; a mask entry is null when the corresponding mask is null
     */
    private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels, INDArray fMask, INDArray lMask) {
        INDArray featureSlice = input.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        INDArray labelSlice = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        INDArray featureMaskSlice = fMask == null ? null : fMask.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        INDArray labelMaskSlice = lMask == null ? null : lMask.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        return new INDArray[]{featureSlice, labelSlice, featureMaskSlice, labelMaskSlice};
    }

    /**
     * Copies each recurrent layer's TBPTT state into its "previous state" so the next segment continues from it.
     * Nested {@link MultiLayerNetwork}s are updated recursively; other layers are left untouched.
     */
    public void updateRnnStateWithTBPTTState() {
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                RecurrentLayer recurrent = (RecurrentLayer)layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork)layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Returns the network's own training-listener collection. This is the live collection, not a copy.
     *
     * @return the {@code trainingListeners} field
     */
    @Override
    public Collection<TrainingListener> getListeners() {
        final Collection<TrainingListener> live = this.trainingListeners;
        return live;
    }

    /**
     * Same as {@link #getListeners()}: returns the live training-listener collection.
     *
     * @return the {@code trainingListeners} field
     */
    public Collection<TrainingListener> getTrainingListeners() {
        final Collection<TrainingListener> live = this.trainingListeners;
        return live;
    }

    /**
     * Replaces the training listeners on this network, on every layer, and on the solver if one exists.
     * Calls {@code init()} first if the layers have not been created yet.
     * <p>
     * The argument is snapshotted before the network's own listener collection is cleared. Without this, passing
     * that same collection (e.g. the result of {@link #getListeners()}) would clear it and then re-add from the
     * now-empty collection, silently dropping every listener.
     *
     * @param listeners the new listeners; {@code null} results in no listeners on the network
     */
    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        if (this.layers == null) {
            this.init();
        }
        List<TrainingListener> snapshot = listeners == null
                ? Collections.<TrainingListener>emptyList()
                : new ArrayList<TrainingListener>(listeners);
        for (org.deeplearning4j.nn.api.Layer layer : this.layers) {
            layer.setListeners(listeners);
        }
        if (this.solver != null) {
            this.solver.setListeners(listeners);
        }
        this.trainingListeners.clear();
        this.trainingListeners.addAll(snapshot);
    }

    /**
     * Appends listeners to the network's listener collection and, if a solver exists, hands it the updated
     * collection. Layers are not updated here.
     *
     * @param listeners listeners to append
     */
    @Override
    public void addListeners(TrainingListener ... listeners) {
        for (TrainingListener listener : listeners) {
            this.trainingListeners.add(listener);
        }
        if (this.solver == null) {
            return;
        }
        this.solver.setListeners(this.trainingListeners);
    }

    /**
     * Replaces the training listeners with the non-null entries of {@code listeners}, delegating to
     * {@code setListeners(Collection)}. A null array results in an empty listener list.
     *
     * @param listeners the new listeners; null entries are ignored
     */
    @Override
    public void setListeners(TrainingListener ... listeners) {
        List<TrainingListener> nonNull = new ArrayList<TrainingListener>();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                if (listener != null) {
                    nonNull.add(listener);
                }
            }
        }
        this.setListeners(nonNull);
    }

    /**
     * Fits only the output layer on the activations of a feed-forward pass, using the stored labels.
     * <p>
     * It logs a warning and returns if backprop is disabled or the last layer is not an {@link IOutputLayer}.
     * The gradient view is initialised on first use.
     *
     * @throws IllegalStateException         if no labels are set
     * @throws UnsupportedOperationException if the output layer is configured for HESSIAN_FREE optimisation
     */
    public void finetune() {
        if (!this.layerWiseConfigurations.isBackprop()) {
            log.warn("Warning: finetune is not applied.");
            return;
        }
        boolean lastIsOutputLayer = this.getOutputLayer() instanceof IOutputLayer;
        if (!lastIsOutputLayer) {
            log.warn("Output layer not instance of output layer returning.");
            return;
        }
        if (this.flattenedGradients == null) {
            this.initGradientsView();
        }
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        log.info("Finetune phase");
        IOutputLayer outputLayer = (IOutputLayer)this.getOutputLayer();
        boolean hessianFree = outputLayer.conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE;
        if (hessianFree) {
            throw new UnsupportedOperationException();
        }
        this.feedForward();
        outputLayer.fit(outputLayer.input(), this.labels);
    }

    /**
     * Predicts a class index per example: the index of the maximum absolute value (BLAS {@code iamax}) of each
     * output row. A row-vector or scalar input is treated as a single example.
     *
     * @param d the input; one example per row
     * @return one predicted class index per row of {@code d}
     */
    @Override
    public int[] predict(INDArray d) {
        final INDArray scores = this.output(d, Layer.TrainingMode.TEST);
        final int[] predictions = new int[d.size(0)];
        if (!d.isRowVectorOrScalar()) {
            for (int row = 0; row < predictions.length; ++row) {
                predictions[row] = Nd4j.getBlasWrapper().iamax(scores.getRow(row));
            }
        } else {
            predictions[0] = Nd4j.getBlasWrapper().iamax(scores);
        }
        return predictions;
    }

    /**
     * Predicts a label name per example, mapping each predicted class index through {@code getLabelName}.
     *
     * @param dataSet supplies the features and the label names
     * @return predicted label names in example order
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] classIndices = this.predict(dataSet.getFeatures());
        List<String> names = new ArrayList<String>(classIndices.length);
        for (int classIndex : classIndices) {
            names.add(dataSet.getLabelName(classIndex));
        }
        return names;
    }

    /**
     * Feeds {@code examples} forward and passes the last returned activation to the output layer's
     * {@code labelProbabilities}.
     *
     * @param examples the input examples
     * @return the output layer's label probabilities
     * @throws ClassCastException if the last layer is not an {@link IOutputLayer}
     */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> activations = this.feedForward(examples);
        IOutputLayer outputLayer = (IOutputLayer)this.getOutputLayer();
        INDArray finalActivation = activations.get(activations.size() - 1);
        return outputLayer.labelProbabilities(finalActivation);
    }

    /** Fits the network on {@code data} and {@code labels} without any mask arrays. */
    @Override
    public void fit(INDArray data, INDArray labels) {
        INDArray noFeatureMask = null;
        INDArray noLabelMask = null;
        fit(data, labels, noFeatureMask, noLabelMask);
    }

    /**
     * Trains the network on a single minibatch.
     *
     * <p>Sets the input, labels and (optional) mask arrays. If backprop is enabled, it then runs
     * either truncated BPTT or one solver optimization call, depending on the configured backprop
     * type. Mask arrays and layer states are cleared afterwards. A network with zero parameters is
     * left untouched.
     *
     * @param features     input features
     * @param labels       training labels
     * @param featuresMask features mask array; may be null
     * @param labelsMask   labels mask array; may be null
     */
    public void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        if (this.numParams() == 0) {
            return;
        }
        this.setInput(features);
        this.setLabels(labels);
        this.setLayerMaskArrays(featuresMask, labelsMask);
        this.update(TaskUtils.buildTask((INDArray)features, (INDArray)labels));
        // Honour WorkspaceMode.NONE, as computeGradientAndScore() and score() do. Previously only a
        // null mode disabled workspaces, so NONE still allocated activation/updater workspaces.
        WorkspaceMode trainingMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        LayerWorkspaceMgr workspaceMgr;
        if (trainingMode == null || trainingMode == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        if (this.layerWiseConfigurations.isBackprop()) {
            if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                this.doTruncatedBPTT(features, labels, featuresMask, labelsMask, workspaceMgr);
            } else {
                if (this.solver == null) {
                    // Build the solver outside any workspace so it does not hold workspace-scoped arrays.
                    try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                        this.solver = new Solver.Builder().configure(this.conf()).listeners(this.getListeners()).model(this).build();
                    }
                }
                this.solver.optimize(workspaceMgr);
            }
        }
        this.clearLayerMaskArrays();
        this.clearLayersStates();
    }

    /**
     * Unsupported: unsupervised fitting of the whole network is not available here.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        String msg = "Not supported: use pretrainLayer";
        throw new UnsupportedOperationException(msg);
    }

    /** Fits the network on the features, labels and mask arrays of {@code data}. */
    @Override
    public void fit(DataSet data) {
        INDArray features = data.getFeatures();
        INDArray targets = data.getLabels();
        INDArray featureMask = data.getFeaturesMaskArray();
        INDArray labelMask = data.getLabelsMaskArray();
        fit(features, targets, featureMask, labelMask);
    }

    /**
     * Fits the network using integer class indices as labels.
     *
     * <p>The indices are converted to an outcome matrix with {@code nOut} columns, where
     * {@code nOut} is taken from the output layer's configuration.
     *
     * @param examples input features
     * @param labels   class index per example
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
        // Cast to FeedForwardLayer (which provides getNOut()) rather than OutputLayer. Output layers of
        // other FeedForwardLayer subtypes then work instead of failing with ClassCastException.
        FeedForwardLayer layerConf = (FeedForwardLayer)this.getOutputLayer().conf().getLayer();
        this.fit(examples, FeatureUtil.toOutcomeMatrix((int[])labels, (int)layerConf.getNOut()));
    }

    /** Computes the network output for {@code input}; TRAIN mode maps to {@code train == true}. */
    public INDArray output(INDArray input, Layer.TrainingMode train) {
        boolean isTraining = train == Layer.TrainingMode.TRAIN;
        return output(input, isTraining);
    }

    /** Computes the detached output of the last layer for {@code input} without mask arrays. */
    public INDArray output(INDArray input, boolean train) {
        int lastLayerIdx = layers.length - 1;
        return outputOfLayerDetached(train, FwdPassType.STANDARD, lastLayerIdx, input, null, null);
    }

    /** Computes the detached output of the last layer for {@code input} with the given masks. */
    public INDArray output(INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask) {
        int lastLayerIdx = layers.length - 1;
        return outputOfLayerDetached(train, FwdPassType.STANDARD, lastLayerIdx, input, featuresMask, labelsMask);
    }

    /** Computes the network output for {@code input} in TEST mode. */
    public INDArray output(INDArray input) {
        Layer.TrainingMode mode = Layer.TrainingMode.TEST;
        return output(input, mode);
    }

    /**
     * Computes the network output for every batch of {@code iterator}, using each batch's mask
     * arrays, and concatenates the results along dimension 0. Batches with null features are skipped.
     */
    public INDArray output(DataSetIterator iterator, boolean train) {
        List<INDArray> perBatch = new ArrayList<INDArray>();
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iterator.next();
            INDArray features = batch.getFeatures();
            if (features != null) {
                perBatch.add(output(features, train, batch.getFeaturesMaskArray(), batch.getLabelsMaskArray()));
            }
        }
        INDArray[] asArray = perBatch.toArray(new INDArray[perBatch.size()]);
        return Nd4j.concat(0, asArray);
    }

    /** Computes the network output for every batch of {@code iterator} in inference mode. */
    public INDArray output(DataSetIterator iterator) {
        boolean training = false;
        return output(iterator, training);
    }

    /**
     * Feeds {@code x} forward and returns the entry at index {@code layerNum - 1} of the resulting
     * activation list.
     */
    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> activations = feedForward(x);
        int idx = layerNum - 1;
        return activations.get(idx);
    }

    /** Logs (at INFO) every per-layer configuration on a single line, prefixed by its layer index. */
    public void printConfiguration() {
        StringBuilder summary = new StringBuilder();
        int layerIdx = 0;
        for (NeuralNetConfiguration layerConf : getLayerWiseConfigurations().getConfs()) {
            summary.append(" Layer ").append(layerIdx).append(" conf ").append(layerConf);
            layerIdx++;
        }
        log.info(summary.toString());
    }

    /**
     * Copies state from {@code network} into this network.
     *
     * <p>The default configuration is cloned, the input is duplicated (via {@link #setInput}) and
     * every layer is cloned. Labels are shared by reference. If {@code network} has a solver with
     * non-null updater state, that state is duplicated into a new {@link MultiLayerUpdater}.
     * Otherwise, when {@code network} has no solver, this network's solver is cleared.
     *
     * @param network source network
     */
    public void update(MultiLayerNetwork network) {
        // The decompiled original also assigned this to an unused local; only the field matters.
        this.defaultConfiguration = network.defaultConfiguration != null ? network.defaultConfiguration.clone() : null;
        if (network.input != null) {
            this.setInput(network.input.dup());
        }
        this.labels = network.labels;
        if (network.layers != null) {
            this.layers = new org.deeplearning4j.nn.api.Layer[network.layers.length];
            for (int i = 0; i < this.layers.length; ++i) {
                this.layers[i] = network.layers[i].clone();
            }
        } else {
            this.layers = null;
        }
        if (network.solver != null) {
            INDArray updaterView = network.getUpdater().getStateViewArray();
            if (updaterView != null) {
                MultiLayerUpdater newUpdater = new MultiLayerUpdater(this);
                newUpdater.setStateViewArray(this, updaterView.dup(), false);
                this.setUpdater(newUpdater);
            }
        } else {
            this.solver = null;
        }
    }

    /**
     * Computes the F1 score of the network's label probabilities on {@code input} against
     * {@code labels}. Also stores {@code labels} on the network.
     *
     * @param input  input features
     * @param labels expected labels
     * @return the F1 score from a fresh {@link Evaluation}
     */
    @Override
    public double f1Score(INDArray input, INDArray labels) {
        // labelProbabilities(input) runs feedForward(input) itself. The former extra
        // feedForward(input) here computed the same activations a second time and discarded them.
        this.setLabels(labels);
        Evaluation eval = new Evaluation();
        eval.eval(labels, this.labelProbabilities(input));
        return eval.f1();
    }

    /** Returns the number of columns of the currently set labels array (which must be non-null). */
    @Override
    public int numLabels() {
        INDArray current = labels;
        return current.columns();
    }

    /** Scores {@code data} in inference (non-training) mode. */
    public double score(org.nd4j.linalg.dataset.DataSet data) {
        boolean training = false;
        return score(data, training);
    }

    /**
     * Computes the output layer's score (including L1/L2 terms) for {@code data}.
     *
     * <p>The workspace mode is the training mode when {@code training} is true, otherwise the
     * inference mode. Masks are applied (and cleared afterwards) only when {@code data} has mask
     * arrays. Layer states are always cleared at the end.
     *
     * @throws IllegalStateException if the final layer is not an {@link IOutputLayer}
     */
    public double score(org.nd4j.linalg.dataset.DataSet data, boolean training) {
        boolean masked = data.hasMaskArrays();
        if (masked) {
            setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        }
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot calculate score if final layer is not an instance of IOutputLayer. Final layer is of type: " + getOutputLayer().getClass());
        }
        WorkspaceMode wsMode;
        if (training) {
            wsMode = layerWiseConfigurations.getTrainingWorkspaceMode();
        } else {
            wsMode = layerWiseConfigurations.getInferenceWorkspaceMode();
        }
        LayerWorkspaceMgr mgr;
        if (wsMode == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .build();
        }
        int outIdx = layers.length - 1;
        INDArray outLayerInput = outputOfLayerDetached(training, FwdPassType.STANDARD, outIdx - 1, data.getFeatures(), data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        IOutputLayer outLayer = (IOutputLayer) getOutputLayer();
        InputPreProcessor preProc = getLayerWiseConfigurations().getInputPreProcess(outIdx);
        if (preProc != null) {
            outLayerInput = preProc.preProcess(outLayerInput, data.getFeatures().size(0), mgr);
        }
        outLayer.setInput(outLayerInput, mgr);
        outLayer.setLabels(data.getLabels());
        double result;
        try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
            result = outLayer.computeScore(calcL1(true), calcL2(true), training, mgr);
        }
        if (masked) {
            clearLayerMaskArrays();
        }
        clearLayersStates();
        return result;
    }

    /**
     * Computes per-example scores for every batch of {@code iter} and flattens them (in 'f' order)
     * into a single array.
     */
    public INDArray scoreExamples(DataSetIterator iter, boolean addRegularizationTerms) {
        List<INDArray> perBatch = new ArrayList<INDArray>();
        while (iter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet batch = (org.nd4j.linalg.dataset.DataSet) iter.next();
            perBatch.add(scoreExamples(batch, addRegularizationTerms));
        }
        return Nd4j.toFlattened('f', perBatch);
    }

    /**
     * Computes the output layer's per-example scores for {@code data}, without workspaces.
     *
     * @param addRegularizationTerms when true, L1/L2 terms are included; otherwise they are 0
     * @throws UnsupportedOperationException if the final layer is not an {@link IOutputLayer}
     */
    public INDArray scoreExamples(org.nd4j.linalg.dataset.DataSet data, boolean addRegularizationTerms) {
        int outIdx = layers.length - 1;
        INDArray lastInput = outputOfLayerDetached(false, FwdPassType.STANDARD, outIdx - 1, data.getFeatures(), data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        setLabels(data.getLabels());
        setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            throw new UnsupportedOperationException("Cannot calculate score with respect to labels without an OutputLayer");
        }
        IOutputLayer outLayer = (IOutputLayer) getOutputLayer();
        InputPreProcessor preProc = layerWiseConfigurations.getInputPreProcess(outIdx);
        if (preProc != null) {
            lastInput = preProc.preProcess(lastInput, data.getFeatures().size(0), mgr);
        }
        outLayer.setLabels(data.getLabels());
        outLayer.setInput(lastInput, mgr);
        double l1Term = 0.0;
        double l2Term = 0.0;
        if (addRegularizationTerms) {
            l1Term = calcL1(true);
            l2Term = calcL2(true);
        }
        INDArray perExample = outLayer.computeScoreForExamples(l1Term, l2Term, mgr);
        clearLayersStates();
        clearLayerMaskArrays();
        return perExample;
    }

    /** Fits the network on the currently stored input and labels. */
    @Override
    public void fit() {
        INDArray currentInput = this.input;
        INDArray currentLabels = this.labels;
        fit(currentInput, currentLabels);
    }

    /**
     * Not implemented for a whole network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        String msg = "Not implemented";
        throw new UnsupportedOperationException(msg);
    }

    /** Returns the most recently stored score value. */
    @Override
    public double score() {
        double current = this.score;
        return current;
    }

    /** Overwrites the stored score value. */
    public void setScore(double score) {
        this.score = score;
    }

    /**
     * Delegates to {@link #computeGradientAndScore()}. The supplied workspace manager is ignored;
     * the no-arg variant builds its own from the configuration.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
        computeGradientAndScore();
    }

    /**
     * Computes the gradient and score for the currently set input and labels.
     *
     * <p>Steps:
     * <ol>
     *   <li>Feed forward up to the layer before the output layer (using stored RNN state under
     *       truncated BPTT) and notify listeners via {@code onForwardPass}.</li>
     *   <li>Set the (pre-processed) input on the output layer, then backpropagate into
     *       {@code this.gradient}.</li>
     *   <li>Compute {@code this.score}, including L1/L2 terms.</li>
     *   <li>Notify listeners via {@code onBackwardPass}.</li>
     * </ol>
     * Afterwards the output layer's noise weight params are cleared.
     *
     * @throws DL4JException if the final layer is not an {@link IOutputLayer}
     */
    public void computeGradientAndScore() {
        // Rewritten from the decompiler's labelled-break / synthetic-Throwable form, which did not
        // compile (undeclared var10_12). Control flow and side-effect order are unchanged.
        LayerWorkspaceMgr mgr;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (this.layerWiseConfigurations.getCacheMode() != null) {
                mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        boolean tbptt = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        FwdPassType fwdType = tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD;
        this.synchronizeIterEpochCounts();
        try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            // Only the output layer's input is needed for backprop, so stop one layer short.
            List<INDArray> activations = this.ffToLayerActivationsInWs(this.layers.length - 2, fwdType, tbptt, this.input, this.mask, null);
            if (!this.trainingListeners.isEmpty()) {
                for (TrainingListener tl : this.trainingListeners) {
                    tl.onForwardPass((Model)this, activations);
                }
            }
            INDArray inputToOutputLayer = activations.get(activations.size() - 1);
            InputPreProcessor outPreProc = this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1);
            if (outPreProc != null) {
                inputToOutputLayer = outPreProc.preProcess(inputToOutputLayer, this.getInputMiniBatchSize(), mgr);
            }
            this.getOutputLayer().setInput(inputToOutputLayer, mgr);
            Pair<Gradient, INDArray> pair = this.calcBackpropGradients(null, true, false, false);
            this.gradient = pair == null ? null : pair.getFirst();
            if (!(this.getOutputLayer() instanceof IOutputLayer)) {
                throw new DL4JException("Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer");
            }
            try (MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                this.score = ((IOutputLayer)this.getOutputLayer()).computeScore(this.calcL1(true), this.calcL2(true), true, mgr);
            }
            if (!this.trainingListeners.isEmpty()) {
                // Listeners may keep references to arrays; run them outside any workspace.
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener tl : this.trainingListeners) {
                        tl.onBackwardPass(this);
                    }
                }
            }
        }
        this.getOutputLayer().clearNoiseWeightParams();
    }

    /** No-op in this implementation: the supplied value is ignored. */
    @Override
    public void accumulateScore(double accum) {
        // intentionally empty
    }

    /** Clears every layer, then drops the stored input, labels and solver. */
    @Override
    public void clear() {
        for (org.deeplearning4j.nn.api.Layer current : layers) {
            current.clear();
        }
        input = null;
        labels = null;
        solver = null;
    }

    /** Applies parameter constraints on every layer, in layer order. */
    @Override
    public void applyConstraints(int iteration, int epoch) {
        org.deeplearning4j.nn.api.Layer[] all = layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].applyConstraints(iteration, epoch);
        }
    }

    /**
     * Stores {@code input}, initializing the network first if its layers do not exist yet. For a
     * non-null input, also propagates {@code input.size(0)} as the minibatch size to every layer.
     *
     * @throws IllegalArgumentException if {@code input} is non-null but has length 0
     */
    public void setInput(INDArray input) {
        this.input = input;
        if (layers == null) {
            init();
        }
        if (input == null) {
            return;
        }
        if (input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")");
        }
        setInputMiniBatchSize(input.size(0));
    }

    /**
     * Unsupported on the network; use {@code setInput(INDArray)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setInput(INDArray input, LayerWorkspaceMgr mgr) {
        String msg = "Not supported";
        throw new UnsupportedOperationException(msg);
    }

    /**
     * Returns the last layer. If it is a {@link FrozenLayerWithBackprop}, the wrapped inner layer
     * is returned instead.
     */
    public org.deeplearning4j.nn.api.Layer getOutputLayer() {
        org.deeplearning4j.nn.api.Layer[] all = getLayers();
        org.deeplearning4j.nn.api.Layer last = all[all.length - 1];
        if (!(last instanceof FrozenLayerWithBackprop)) {
            return last;
        }
        return ((FrozenLayerWithBackprop) last).getInsideLayer();
    }

    /** Alias for {@code setParams(params)}. */
    public void setParameters(INDArray params) {
        setParams(params);
    }

    /** Returns the network-level default configuration (may be null). */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /** Returns the currently stored labels (may be null). */
    public INDArray getLabels() {
        return labels;
    }

    /** Returns the currently stored input (may be null). */
    public INDArray getInput() {
        return input;
    }

    /** Stores {@code labels} by reference. */
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Returns the number of per-layer configurations in the network configuration. */
    public int getnLayers() {
        List<NeuralNetConfiguration> confs = layerWiseConfigurations.getConfs();
        return confs.size();
    }

    /** Returns the internal layer array itself (not a copy). */
    public synchronized org.deeplearning4j.nn.api.Layer[] getLayers() {
        return layers;
    }

    /**
     * Returns the layer at index {@code i}.
     *
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, layers.length - 1]}
     */
    public org.deeplearning4j.nn.api.Layer getLayer(int i) {
        int maxIdx = layers.length - 1;
        boolean valid = i >= 0 && i <= maxIdx;
        Preconditions.checkArgument(valid, "Invalid layer index: layer index must be 0 to %s (inclusive), got index %s", maxIdx, i);
        return layers[i];
    }

    /** Looks up a layer by name; returns null if no layer has that name. */
    public org.deeplearning4j.nn.api.Layer getLayer(String name) {
        return layerMap.get(name);
    }

    /** Returns a new, mutable list holding the keys of the layer-name map. */
    public List<String> getLayerNames() {
        List<String> names = new ArrayList<String>(layerMap.keySet());
        return names;
    }

    /** Replaces the internal layer array (stored by reference). */
    public void setLayers(org.deeplearning4j.nn.api.Layer[] layers) {
        this.layers = layers;
    }

    /** Returns the stored network-level mask (may be null). */
    public INDArray getMask() {
        return mask;
    }

    /** Stores the network-level mask by reference. */
    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /** Returns the same stored mask as {@link #getMask()}. */
    @Override
    public INDArray getMaskArray() {
        return mask;
    }

    /** The network as a whole never reports itself as a pretrain layer. */
    @Override
    public boolean isPretrainLayer() {
        boolean pretrainable = false;
        return pretrainable;
    }

    /** Calls {@code clearNoiseWeightParams()} on every layer, in layer order. */
    @Override
    public void clearNoiseWeightParams() {
        org.deeplearning4j.nn.api.Layer[] all = layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].clearNoiseWeightParams();
        }
    }

    /**
     * Propagates a mask array through every layer (and each layer's input preprocessor, if any).
     *
     * <p>With a null mask, every layer is told there is no mask, and {@code (null, currentMaskState)}
     * is returned. Otherwise the mask and state are threaded through preprocessor and layer in
     * turn. A null result from either resets both to null.
     *
     * @return the mask and mask state emerging from the last layer
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        if (maskArray == null) {
            for (org.deeplearning4j.nn.api.Layer l : layers) {
                l.feedForwardMaskArray(null, null, minibatchSize);
            }
            return new Pair<INDArray, MaskState>(null, currentMaskState);
        }
        INDArray curMask = maskArray;
        MaskState curState = currentMaskState;
        for (int layerIdx = 0; layerIdx < layers.length; layerIdx++) {
            InputPreProcessor preProc = getLayerWiseConfigurations().getInputPreProcess(layerIdx);
            if (preProc != null) {
                Pair<INDArray, MaskState> fromPre = preProc.feedForwardMaskArray(curMask, curState, minibatchSize);
                curMask = fromPre == null ? null : fromPre.getFirst();
                curState = fromPre == null ? null : fromPre.getSecond();
            }
            Pair<INDArray, MaskState> fromLayer = layers[layerIdx].feedForwardMaskArray(curMask, curState, minibatchSize);
            curMask = fromLayer == null ? null : fromLayer.getFirst();
            curState = fromLayer == null ? null : fromLayer.getSecond();
        }
        return new Pair<INDArray, MaskState>(curMask, curState);
    }

    /** Returns {@code Layer.Type.MULTILAYER}. */
    @Override
    public Layer.Type type() {
        Layer.Type kind = Layer.Type.MULTILAYER;
        return kind;
    }

    /** Computes the output for the stored input; TRAIN mode maps to {@code train == true}. */
    public INDArray activate(Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return output(input, isTraining);
    }

    /** Computes the output for {@code input}; TRAIN mode maps to {@code train == true}. */
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return output(input, isTraining);
    }

    /**
     * Not supported for a network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public org.deeplearning4j.nn.api.Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /**
     * Backpropagates an externally supplied {@code epsilon} through the network.
     *
     * @throws UnsupportedOperationException if the final layer is an {@link IOutputLayer}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        boolean endsInOutputLayer = getOutputLayer() instanceof IOutputLayer;
        if (endsInOutputLayer) {
            throw new UnsupportedOperationException("Cannot calculate gradients based on epsilon with OutputLayer");
        }
        return calcBackpropGradients(epsilon, false, false, true);
    }

    /** Sets this network's index when it is used as a layer. */
    @Override
    public void setIndex(int index) {
        layerIndex = index;
    }

    /** Returns this network's index when it is used as a layer. */
    @Override
    public int getIndex() {
        return layerIndex;
    }

    /** Returns the iteration count stored in the network configuration. */
    @Override
    public int getIterationCount() {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        return cfg.getIterationCount();
    }

    /** Returns the epoch count stored in the network configuration. */
    @Override
    public int getEpochCount() {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        return cfg.getEpochCount();
    }

    /** Stores {@code iterationCount} in the network configuration. */
    @Override
    public void setIterationCount(int iterationCount) {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        cfg.setIterationCount(iterationCount);
    }

    /** Stores {@code epochCount} in the network configuration. */
    @Override
    public void setEpochCount(int epochCount) {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        cfg.setEpochCount(epochCount);
    }

    /** Sums each layer's L2 regularization term, in layer order. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        double total = 0.0;
        for (org.deeplearning4j.nn.api.Layer l : layers) {
            total += l.calcL2(backpropParamsOnly);
        }
        return total;
    }

    /** Sums each layer's L1 regularization term, in layer order. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        double total = 0.0;
        for (org.deeplearning4j.nn.api.Layer l : layers) {
            total += l.calcL1(backpropParamsOnly);
        }
        return total;
    }

    /**
     * Applies a full-network gradient.
     *
     * <p>Each variable key must be {@code "<layerIndex>_<paramType>"}. Each per-variable array is
     * recorded in this network's gradient and forwarded to the matching layer. Finally the
     * flattened gradient is installed as the backprop gradient view.
     *
     * @throws IllegalArgumentException if the flattened gradient length differs from {@code numParams(true)}
     * @throws IllegalStateException    if a key has no {@code '_'} separator
     */
    @Override
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
        }
        for (Map.Entry<String, INDArray> variable : gradient.gradientForVariable().entrySet()) {
            String key = variable.getKey();
            int sep = key.indexOf('_');
            if (sep < 0) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            int layerId = Integer.parseInt(key.substring(0, sep));
            String paramType = key.substring(sep + 1);
            INDArray value = variable.getValue();
            this.gradient.gradientForVariable().put(key, value);
            layers[layerId].update(value, paramType);
        }
        setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * Not supported; use {@code output(...)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr mgr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported; use {@code output(...)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr) {
        throw new UnsupportedOperationException();
    }

    /** Propagates the minibatch size to every layer; does nothing if layers are not yet created. */
    @Override
    public void setInputMiniBatchSize(int size) {
        if (layers == null) {
            return;
        }
        for (org.deeplearning4j.nn.api.Layer l : layers) {
            l.setInputMiniBatchSize(size);
        }
    }

    /**
     * Returns {@code input.size(0)}, or 1 when the configuration is not in minibatch mode. The
     * stored input must be non-null in minibatch mode.
     */
    @Override
    public int getInputMiniBatchSize() {
        boolean miniBatch = conf().isMiniBatch();
        return miniBatch ? input.size(0) : 1;
    }

    /**
     * Not supported; use {@code setLayerMaskArrays(...)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setMaskArray(INDArray maskArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Runs one RNN_TIMESTEP forward pass (inference) through all layers.
     *
     * <p>When the input is rank 2, the output is rank 3 and the last layer is RECURRENT, the result
     * is reduced via {@code tensorAlongDimension(0, 1, 0)} to a rank-2 array.
     */
    public INDArray rnnTimeStep(INDArray input) {
        boolean rank2Input = input.rank() == 2;
        org.deeplearning4j.nn.api.Layer last = layers[layers.length - 1];
        INDArray result = outputOfLayerDetached(false, FwdPassType.RNN_TIMESTEP, layers.length - 1, input, null, null);
        boolean squeeze = rank2Input && result.rank() == 3 && last.type() == Layer.Type.RECURRENT;
        if (!squeeze) {
            return result;
        }
        return result.tensorAlongDimension(0, new int[]{1, 0});
    }

    /**
     * Returns the stored RNN state of layer {@code layer}.
     *
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        boolean inRange = layer >= 0 && layer < layers.length;
        if (!inRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        org.deeplearning4j.nn.api.Layer target = layers[layer];
        if (!(target instanceof RecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        return ((RecurrentLayer) target).rnnGetPreviousState();
    }

    /**
     * Sets the stored RNN state of layer {@code layer}.
     *
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        boolean inRange = layer >= 0 && layer < layers.length;
        if (!inRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        org.deeplearning4j.nn.api.Layer target = layers[layer];
        if (!(target instanceof RecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        ((RecurrentLayer) target).rnnSetPreviousState(state);
    }

    /**
     * Clears stored RNN state in every recurrent layer, recursing into nested
     * {@code MultiLayerNetwork} layers. Does nothing if layers are not yet created.
     */
    public void rnnClearPreviousState() {
        if (layers == null) {
            return;
        }
        for (org.deeplearning4j.nn.api.Layer l : layers) {
            if (l instanceof RecurrentLayer) {
                ((RecurrentLayer) l).rnnClearPreviousState();
            } else if (l instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) l).rnnClearPreviousState();
            }
        }
    }

    /**
     * Feeds {@code input} forward through all layers using stored RNN state and the network mask,
     * returning the detached per-layer activations.
     */
    public List<INDArray> rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) {
        int lastIdx = layers.length - 1;
        return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, storeLastForTBPTT, lastIdx, input, mask, null, false);
    }

    /**
     * Returns the optimizer's updater. On first use, builds the solver and attaches an updater
     * created by {@code UpdaterCreator}.
     */
    public synchronized Updater getUpdater() {
        if (solver != null) {
            return solver.getOptimizer().getUpdater();
        }
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
        return solver.getOptimizer().getUpdater();
    }

    /** Installs {@code updater} on the optimizer, building the solver first if necessary. */
    public void setUpdater(Updater updater) {
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.getOptimizer().setUpdater(updater);
    }

    /**
     * Propagates a features mask through all layers and sets a labels mask on the last layer.
     *
     * <p>Either mask may be null. The labels mask is applied only when the output layer is an
     * {@link IOutputLayer}.
     */
    public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
        if (featuresMaskArray != null) {
            feedForwardMaskArray(featuresMaskArray, MaskState.Active, featuresMaskArray.size(0));
        }
        boolean applyLabelMask = labelsMaskArray != null && getOutputLayer() instanceof IOutputLayer;
        if (applyLabelMask) {
            layers[layers.length - 1].setMaskArray(labelsMaskArray);
        }
    }

    /** Removes the mask array from every layer. */
    public void clearLayerMaskArrays() {
        org.deeplearning4j.nn.api.Layer[] all = layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].setMaskArray(null);
        }
    }

    /** Classification evaluation using the iterator's own label names. */
    public Evaluation evaluate(DataSetIterator iterator) {
        return evaluate(iterator, (List<String>) null);
    }

    /** Regression evaluation with one column per {@code iterator.totalOutcomes()}. */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator) {
        RegressionEvaluation regEval = new RegressionEvaluation(iterator.totalOutcomes());
        return doEvaluation(iterator, new RegressionEvaluation[]{regEval})[0];
    }

    /** ROC evaluation with {@code rocThresholdSteps = 0}. */
    public ROC evaluateROC(DataSetIterator iterator) {
        int thresholdSteps = 0;
        return evaluateROC(iterator, thresholdSteps);
    }

    /** Binary ROC evaluation over all batches of {@code iterator}. */
    public ROC evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        ROC roc = new ROC(rocThresholdSteps);
        return doEvaluation(iterator, new ROC[]{roc})[0];
    }

    /** Multi-class ROC evaluation with {@code rocThresholdSteps = 0}. */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator) {
        int thresholdSteps = 0;
        return evaluateROCMultiClass(iterator, thresholdSteps);
    }

    /** Multi-class ROC evaluation over all batches of {@code iterator}. */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        ROCMultiClass roc = new ROCMultiClass(rocThresholdSteps);
        return doEvaluation(iterator, new ROCMultiClass[]{roc})[0];
    }

    /**
     * Runs every supplied evaluation over all batches of {@code iterator}.
     *
     * <p>An exhausted, resettable iterator is reset first. An iterator that supports async prefetch
     * is wrapped in an {@link AsyncDataSetIterator}. Networks configured for truncated BPTT are
     * evaluated in time segments of {@code tbpttFwdLength} via {@link #rnnTimeStep(INDArray)}, with
     * RNN state cleared at the start of each batch. While evaluating, the configuration's training
     * workspace mode is temporarily replaced by its inference workspace mode.
     *
     * @param iterator    evaluation data; batches with null features or labels are skipped
     * @param evaluations evaluations updated in place
     * @return the same {@code evaluations} array
     */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T ... evaluations) {
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }
        boolean async = iterator.asyncSupported();
        DataSetIterator iter = async ? new AsyncDataSetIterator(iterator, 2, true) : iterator;
        WorkspaceMode cMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        this.layerWiseConfigurations.setTrainingWorkspaceMode(this.layerWiseConfigurations.getInferenceWorkspaceMode());
        boolean useRnnSegments = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        // try/finally: previously an exception thrown mid-evaluation left the configuration's
        // training workspace mode overwritten and the async prefetch iterator never shut down.
        try {
            while (iter.hasNext()) {
                org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet)iter.next();
                if (next.getFeatureMatrix() == null || next.getLabels() == null) {
                    continue;
                }
                INDArray features = next.getFeatures();
                INDArray labels = next.getLabels();
                INDArray fMask = next.getFeaturesMaskArray();
                INDArray lMask = next.getLabelsMaskArray();
                if (!useRnnSegments) {
                    INDArray out = this.outputOfLayerDetached(false, FwdPassType.STANDARD, this.layers.length - 1, features, fMask, lMask);
                    // Evaluations keep state across batches, so feed them outside any workspace.
                    try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                        for (T evaluation : evaluations) {
                            evaluation.eval(labels, out, lMask);
                        }
                    }
                } else {
                    this.rnnClearPreviousState();
                    int fwdLen = this.layerWiseConfigurations.getTbpttFwdLength();
                    int tsLength = features.size(2);
                    int nSubsets = tsLength / fwdLen;
                    if (tsLength % fwdLen != 0) {
                        ++nSubsets;
                    }
                    for (int i = 0; i < nSubsets; ++i) {
                        int startTimeIdx = i * fwdLen;
                        int endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);
                        INDArray[] subsets = this.getSubsetsForTbptt(startTimeIdx, endTimeIdx, features, labels, fMask, lMask);
                        this.setLayerMaskArrays(subsets[2], subsets[3]);
                        INDArray outSub = this.rnnTimeStep(subsets[0]);
                        try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                            for (T evaluation : evaluations) {
                                evaluation.eval(subsets[1], outSub, subsets[3]);
                            }
                        }
                    }
                }
                this.clearLayersStates();
            }
        } finally {
            if (async) {
                ((AsyncDataSetIterator)iter).shutdown();
            }
            this.layerWiseConfigurations.setTrainingWorkspaceMode(cMode);
        }
        return evaluations;
    }

    /** Top-1 classification evaluation with the given label names. */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList) {
        int topN = 1;
        return evaluate(iterator, labelsList, topN);
    }

    /** Returns the updater's state view array, or null if there is no updater. */
    @Override
    public INDArray updaterState() {
        Updater current = getUpdater();
        if (current == null) {
            return null;
        }
        return getUpdater().getStateViewArray();
    }

    /**
     * Fits on a {@code MultiDataSet} that holds exactly one features array and one labels array.
     * Its first features/labels mask arrays are used when present.
     *
     * @throws DL4JInvalidInputException if there is not exactly one features and one labels array
     */
    @Override
    public void fit(MultiDataSet dataSet) {
        if (dataSet.getFeatures().length != 1 || dataSet.getLabels().length != 1) {
            throw new DL4JInvalidInputException("MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array. Please consider use of ComputationGraph");
        }
        INDArray features = dataSet.getFeatures(0);
        INDArray labels = dataSet.getLabels(0);
        INDArray[] fMasks = dataSet.getFeaturesMaskArrays();
        // Fix: the labels mask used to be guarded by a null check on the *features* masks, so
        // features masks present with null labels masks threw a NullPointerException.
        INDArray[] lMasks = dataSet.getLabelsMaskArrays();
        INDArray fMask = fMasks != null && fMasks.length > 0 ? fMasks[0] : null;
        INDArray lMask = lMasks != null && lMasks.length > 0 ? lMasks[0] : null;
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(features, labels, fMask, lMask);
        this.fit((DataSet)ds);
    }

    /**
     * Fits the network for {@code numEpochs} passes over {@code iterator}.
     *
     * @throws NullPointerException     if {@code iterator} is null
     * @throws IllegalArgumentException if {@code numEpochs <= 0}, or if more than one epoch is
     *                                  requested for an iterator that cannot be reset
     */
    public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        // Messages corrected ("much be" -> "must be", "usingiterator thas" -> "using iterator that").
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            this.fit(iterator);
        }
    }

    /** Fits on a {@code MultiDataSetIterator} by adapting it to a {@code DataSetIterator}. */
    @Override
    public void fit(MultiDataSetIterator iterator) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(iterator);
        fit(adapted);
    }

    /** Runs the evaluations over a {@code MultiDataSetIterator} adapted to a {@code DataSetIterator}. */
    @Override
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T[] evaluations) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(iterator);
        return doEvaluation(adapted, evaluations);
    }

    /**
     * Evaluates the network as a classifier over {@code iterator} with a single
     * {@link Evaluation} configured for top-N accuracy.
     *
     * @param iterator   the data to evaluate on
     * @param labelsList class label names; when {@code null}, {@code iterator.getLabels()} is used
     * @param topN       the N used for top-N accuracy
     * @return the populated evaluation
     * @throws IllegalStateException if the network has no layers or its output layer is not an
     *                               {@link IOutputLayer}
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        boolean hasOutputLayer = layers != null && getOutputLayer() instanceof IOutputLayer;
        if (!hasOutputLayer) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        List<String> effectiveLabels = (labelsList != null) ? labelsList : iterator.getLabels();
        Evaluation evaluation = new Evaluation(effectiveLabels, topN);
        doEvaluation(iterator, new Evaluation[]{evaluation});
        return evaluation;
    }

    /**
     * Reports a one-time {@code Event.STANDALONE} heartbeat for this model.
     * <p>
     * Does nothing once {@code initDone} is true. Otherwise it sets the flag, then reports the
     * event with an environment from {@code EnvironmentUtils.buildEnvironment()}.
     * NOTE(review): the {@code task} argument is never used; the reported task is always
     * {@code ModelSerializer.taskByModel(this)}.
     *
     * @param task ignored (see note above)
     */
    protected void update(Task task) {
        if (initDone) {
            return;
        }
        initDone = true;
        Heartbeat heartbeat = Heartbeat.getInstance();
        Task modelTask = ModelSerializer.taskByModel(this);
        Environment environment = EnvironmentUtils.buildEnvironment();
        heartbeat.reportEvent(Event.STANDALONE, environment, modelTask);
    }

    /**
     * Returns the layer summary table without input/output shape columns.
     *
     * @return the formatted summary
     */
    public String summary() {
        InputType noInputType = null;
        return summary(noInputType);
    }

    /**
     * Returns a fixed-width text table describing every layer of this network.
     * <p>
     * Each row shows the layer name and type, nIn/nOut, parameter count and parameter shapes. The
     * table ends with total, trainable and frozen parameter counts.
     *
     * @param inputType the network's input type, or {@code null}. When non-null, two extra
     *                  columns show each layer's input and output type. The type is propagated
     *                  layer by layer through any configured input preprocessors.
     * @return the formatted summary
     */
    public String summary(InputType inputType) {
        // Decide the layout once; inputType itself is overwritten while walking the layers.
        final boolean showShapes = inputType != null;
        // Header and rows share one format so the columns line up. Previously the 4-column rows
        // used %-12s/%-10s while the header used %-10s/%-12s. String.format ignores surplus
        // arguments, so the two shape values can always be passed.
        final String rowFormat = showShapes ? "%-40s%-10s%-12s%-40s%-75s%-75s" : "%-40s%-10s%-12s%-40s";
        final String doubleRule = StringUtils.repeat("=", 250);

        StringBuilder sb = new StringBuilder("\n");
        sb.append(doubleRule).append('\n');
        sb.append(String.format(rowFormat, "LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape", "InputShape", "OutputShape")).append('\n');
        sb.append(doubleRule).append('\n');

        long frozenParams = 0;
        for (org.deeplearning4j.nn.api.Layer currentLayer : this.getLayers()) {
            Layer layerConf = currentLayer.conf().getLayer();
            String name = layerConf.getLayerName();
            if (name == null) {
                name = String.valueOf(currentLayer.getIndex());
            }
            String[] classNameArr = currentLayer.getClass().getName().split("\\.");
            String className = classNameArr[classNameArr.length - 1];
            String paramCount = String.valueOf(currentLayer.numParams());

            String inShape = "";
            String outShape = "";
            if (showShapes) {
                InputPreProcessor preProcessor = this.getLayerWiseConfigurations().getInputPreProcess(currentLayer.getIndex());
                inShape = inputType.toString();
                if (preProcessor != null) {
                    inputType = preProcessor.getOutputType(inputType);
                    inShape = inShape + "--> " + inputType.toString();
                }
                InputType outType = layerConf.getOutputType(currentLayer.getIndex(), inputType);
                outShape = outType.toString();
                inputType = outType;
            }

            String in = "-";
            String out = "-";
            String paramShape = "-";
            if (currentLayer.numParams() > 0) {
                // Nothing guarantees that a layer with parameters is configured as a
                // FeedForwardLayer (layerSize() guards the same cast). An unchecked cast here
                // threw ClassCastException; such layers now just show "-" for nIn/nOut.
                if (layerConf instanceof FeedForwardLayer) {
                    FeedForwardLayer ffl = (FeedForwardLayer) layerConf;
                    in = String.valueOf(ffl.getNIn());
                    out = String.valueOf(ffl.getNOut());
                }
                List<String> shapeParts = new ArrayList<>();
                for (Map.Entry<String, INDArray> entry : currentLayer.paramTable().entrySet()) {
                    shapeParts.add(entry.getKey() + ":" + ArrayUtils.toString((Object) entry.getValue().shape()));
                }
                // Joining avoids the old trailing-separator trim, which threw on an empty table.
                paramShape = String.join(", ", shapeParts);
            }

            if (currentLayer instanceof FrozenLayer) {
                frozenParams += currentLayer.numParams();
                String[] innerArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                className = "Frozen " + innerArr[innerArr.length - 1];
            }

            sb.append(String.format(rowFormat, name + " (" + className + ")", in + "," + out, paramCount, paramShape, inShape, outShape)).append('\n');
        }

        long totalParams = this.params().length();
        sb.append(StringUtils.repeat("-", 250));
        sb.append(String.format("\n%30s %d", "Total Parameters: ", totalParams));
        sb.append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams));
        sb.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
        sb.append('\n').append(doubleRule).append('\n');
        return sb.toString();
    }

    /**
     * Calls {@code clear()} followed by {@code clearNoiseWeightParams()} on every layer, in
     * index order.
     */
    protected void clearLayersStates() {
        for (int idx = 0; idx < layers.length; idx++) {
            org.deeplearning4j.nn.api.Layer target = layers[idx];
            target.clear();
            target.clearNoiseWeightParams();
        }
    }

    /**
     * Advances the epoch counter stored on the network configuration by one.
     */
    public void incrementEpochCount() {
        int nextEpoch = layerWiseConfigurations.getEpochCount() + 1;
        layerWiseConfigurations.setEpochCount(nextEpoch);
    }

    /**
     * Copies the network's current iteration and epoch counts onto every layer.
     */
    protected void synchronizeIterEpochCounts() {
        final int iteration = getIterationCount();
        final int epoch = getEpochCount();
        for (int idx = 0; idx < layers.length; idx++) {
            org.deeplearning4j.nn.api.Layer target = layers[idx];
            target.setIterationCount(iteration);
            target.setEpochCount(epoch);
        }
    }

    /**
     * Writes this network to {@code f}, including the updater state.
     *
     * @param f destination file
     * @throws IOException if writing fails
     */
    public void save(File f) throws IOException {
        final boolean includeUpdater = true;
        save(f, includeUpdater);
    }

    /**
     * Writes this network to {@code f} via {@code ModelSerializer.writeModel}.
     *
     * @param f           destination file
     * @param saveUpdater whether the updater state is written as well
     * @throws IOException if writing fails
     */
    public void save(File f, boolean saveUpdater) throws IOException {
        Model model = this;
        ModelSerializer.writeModel(model, f, saveUpdater);
    }

    /**
     * Restores a network previously written with {@code save}, via
     * {@code ModelSerializer.restoreMultiLayerNetwork}.
     *
     * @param f           file to read
     * @param loadUpdater whether the updater state is restored as well
     * @return the restored network
     * @throws IOException if reading fails
     */
    public static MultiLayerNetwork load(File f, boolean loadUpdater) throws IOException {
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(f, loadUpdater);
        return restored;
    }

    /**
     * Converts this network into a {@link ComputationGraph} using
     * {@code NetworkUtils.toComputationGraph}.
     *
     * @return the equivalent computation graph
     */
    public ComputationGraph toComputationGraph() {
        MultiLayerNetwork source = this;
        return NetworkUtils.toComputationGraph(source);
    }

    /**
     * Sets a fixed learning rate network-wide (no layer index is given) by delegating to
     * {@code NetworkUtils.setLearningRate}.
     *
     * @param newLr the new learning rate
     */
    public void setLearningRate(double newLr) {
        MultiLayerNetwork network = this;
        NetworkUtils.setLearningRate(network, newLr);
    }

    /**
     * Sets a learning-rate schedule network-wide (no layer index is given) by delegating to
     * {@code NetworkUtils.setLearningRate}.
     *
     * @param newLr the new learning-rate schedule
     */
    public void setLearningRate(ISchedule newLr) {
        MultiLayerNetwork network = this;
        NetworkUtils.setLearningRate(network, newLr);
    }

    /**
     * Sets a fixed learning rate for a single layer by delegating to
     * {@code NetworkUtils.setLearningRate}.
     *
     * @param layerNumber index of the layer to update
     * @param newLr       the new learning rate
     */
    public void setLearningRate(int layerNumber, double newLr) {
        MultiLayerNetwork network = this;
        NetworkUtils.setLearningRate(network, layerNumber, newLr);
    }

    /**
     * Sets a learning-rate schedule for a single layer by delegating to
     * {@code NetworkUtils.setLearningRate}.
     *
     * @param layerNumber index of the layer to update
     * @param newLr       the new learning-rate schedule
     */
    public void setLearningRate(int layerNumber, ISchedule newLr) {
        MultiLayerNetwork network = this;
        NetworkUtils.setLearningRate(network, layerNumber, newLr);
    }

    /**
     * Returns the output size (nOut) of the layer at the given index.
     *
     * @param layer zero-based layer index
     * @return the layer's nOut, or 0 if the layer is not configured as a {@link FeedForwardLayer}
     * @throws IllegalArgumentException if {@code layer} is negative or not less than the number
     *                                  of layers
     */
    public int layerSize(int layer) {
        // The upper bound is exclusive. With '>' an index equal to layers.length slipped through
        // and failed below with ArrayIndexOutOfBoundsException instead of this exception.
        if (layer < 0 || layer >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        Layer conf = this.layers[layer].conf().getLayer();
        // instanceof is false for null, so a missing configuration also yields 0.
        if (!(conf instanceof FeedForwardLayer)) {
            return 0;
        }
        return ((FeedForwardLayer) conf).getNOut();
    }

    /**
     * Returns the input size (nIn) of the layer at the given index.
     *
     * @param layer zero-based layer index
     * @return the layer's nIn, or 0 if the layer is not configured as a {@link FeedForwardLayer}
     * @throws IllegalArgumentException if {@code layer} is negative or not less than the number
     *                                  of layers
     */
    public int layerInputSize(int layer) {
        // The upper bound is exclusive. With '>' an index equal to layers.length slipped through
        // and failed below with ArrayIndexOutOfBoundsException instead of this exception.
        if (layer < 0 || layer >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        Layer conf = this.layers[layer].conf().getLayer();
        // instanceof is false for null, so a missing configuration also yields 0.
        if (!(conf instanceof FeedForwardLayer)) {
            return 0;
        }
        return ((FeedForwardLayer) conf).getNIn();
    }

    /**
     * Two networks are equal when their flattened parameters, their layer-wise configurations
     * and their updaters are all equal.
     * <p>
     * NOTE(review): no {@code hashCode()} override is visible in this part of the file; confirm
     * one exists so the equals/hashCode contract holds.
     *
     * @param obj the object to compare against
     * @return true if {@code obj} is a MultiLayerNetwork matching on all three criteria
     */
    @Override
    public boolean equals(Object obj) {
        // instanceof is false for null, which covers the null check as well.
        if (!(obj instanceof MultiLayerNetwork)) {
            return false;
        }
        MultiLayerNetwork other = (MultiLayerNetwork) obj;
        // All three comparisons are evaluated unconditionally and in this order, as before.
        // NOTE(review): getUpdater() may have side effects, so this is deliberately not short-circuited.
        boolean sameParams = other.params().equals(params());
        boolean sameConf = getLayerWiseConfigurations().equals(other.getLayerWiseConfigurations());
        boolean sameUpdater = getUpdater().equals(other.getUpdater());
        return sameParams & sameConf & sameUpdater;
    }

    /**
     * Overrides the {@code initDone} flag. While it is false, {@code update(Task)} sends its
     * one-time heartbeat report.
     *
     * @param initDone the new flag value
     */
    public void setInitDone(boolean initDone) {
        final boolean flag = initDone;
        this.initDone = flag;
    }

    /**
     * Returns the {@code flattenedGradients} field as-is (no copy is made).
     *
     * @return the flattened gradients array; may be whatever the field currently holds
     */
    public INDArray getFlattenedGradients() {
        INDArray gradients = flattenedGradients;
        return gradients;
    }
}

