/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import com.google.flatbuffers.FlatBufferBuilder;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.ArgumentInterceptor;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDBitwise;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDImage;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.TrainingUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.linalg.util.ND4JFileUtils;
import org.nd4j.shade.guava.base.Predicate;
import org.nd4j.shade.guava.base.Predicates;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Maps;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.GraphDef;

public class SameDiff
extends SDBaseOps {
    // Class logger. Declared static final with no initializer here, so it must be assigned in a static initializer
    // outside this view.
    private static final Logger log;
    // NOTE(review): presumably the name under which the gradient function instance is stored; usage not visible here.
    protected static final String GRAD_FN_KEY = "grad";
    // All variables, keyed by name, in insertion order. Each Variable wraps the SDVariable plus its graph links
    // (producing op via getOutputOfOp, consuming ops via getInputsForOp).
    private final Map<String, Variable> variables = new LinkedHashMap<String, Variable>();
    // All ops, keyed by op name, in insertion order.
    private final Map<String, SameDiffOp> ops = new LinkedHashMap<String, SameDiffOp>();
    // Inference sessions keyed by thread id; ARRAY-type variable values are read from / written to these.
    private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<Long, InferenceSession>();
    // Arrays of CONSTANT-type variables, keyed by variable name.
    private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<String, DeviceLocalNDArray>();
    // Arrays of VARIABLE-type variables, keyed by variable name.
    private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<String, DeviceLocalNDArray>();
    // Placeholder arrays: thread id -> (placeholder name -> array).
    private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<Long, Map<String, INDArray>>();
    // Names of the variables treated as losses.
    private final List<String> lossVariables = new ArrayList<String>();
    // Listeners registered via setListeners/addListeners.
    private final List<Listener> listeners = new ArrayList<Listener>();
    // Currently open name scopes, outermost first; joined with "/" by currentNameScope().
    private final List<NameScope> nameScopes = new ArrayList<NameScope>();
    // Training state; usage not visible in this section.
    private TrainingConfig trainingConfig;
    private boolean initializedTraining;
    private Map<String, GradientUpdater> updaterMap;
    private Map<String, String> baseNameForFunctionInstanceId;
    // Factory returned by f(); created in the constructor.
    private DifferentialFunctionFactory functionFactory;
    // Deprecated shape store, consulted when no array is available for a variable.
    @Deprecated
    private Map<String, long[]> variableNameToShape;
    // Deprecated; keyed by variable name (renamed by updateVariableName).
    @Deprecated
    private Map<String, SDVariable> forwardVarForGrad;
    private int variableId = 0;
    // Op namespaces, also exposed through the like-named accessor methods.
    public final SDMath math = new SDMath(this);
    public final SDRandom random = new SDRandom(this);
    public final SDNN nn = new SDNN(this);
    public final SDCNN cnn = new SDCNN(this);
    public final SDRNN rnn = new SDRNN(this);
    public final SDLoss loss = new SDLoss(this);
    public final SDImage image = new SDImage(this);
    public final SDBitwise bitwise = new SDBitwise(this);
    // Function property bookkeeping; usage not visible in this section.
    private Map<String, List<String>> propertiesToResolve;
    private Map<String, Map<String, Object>> propertiesForFunction;
    @Deprecated
    private Map<String, long[]> placeHolderOriginalShapes;
    private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
    // Named sub-graph instances (see putSubFunction); associateArrayWithVariable propagates arrays into these.
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Set<String> placeHolderFunctions;
    // Op name -> reflective Method, used by the deprecated invoke(Op, SDVariable, SDVariable).
    private static Map<String, Method> opMethods;
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
    // Toggled by enableDebugMode()/disableDebugging().
    private boolean debugMode;
    private Map<int[], Op> opsForResult;
    private boolean resolvedVariables = false;
    // Argument interceptor state; usage not visible in this section.
    private Stack<ArgumentInterceptor> argumentInterceptors = new Stack();
    private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<ArgumentInterceptor>();
    private Set<String> blockNames = new HashSet<String>();
    boolean logExecution = true;
    private SameDiff parent;
    private SameDiff child;

    /**
     * @return the math op namespace bound to this instance (same object as the public {@code math} field)
     */
    public SDMath math() {
        return math;
    }

    /**
     * @return the random op namespace bound to this instance (same object as the public {@code random} field)
     */
    public SDRandom random() {
        return random;
    }

    /**
     * @return the neural-network op namespace bound to this instance (same object as the public {@code nn} field)
     */
    public SDNN nn() {
        return nn;
    }

    /**
     * @return the CNN op namespace bound to this instance (same object as the public {@code cnn} field)
     */
    public SDCNN cnn() {
        return cnn;
    }

    /**
     * @return the RNN op namespace bound to this instance (same object as the public {@code rnn} field)
     */
    public SDRNN rnn() {
        return rnn;
    }

    /**
     * @return the loss-function op namespace bound to this instance (same object as the public {@code loss} field)
     */
    public SDLoss loss() {
        return loss;
    }

    /**
     * @return the image op namespace bound to this instance (same object as the public {@code image} field)
     */
    public SDImage image() {
        return image;
    }

    /**
     * @return the bitwise op namespace bound to this instance (same object as the public {@code bitwise} field)
     */
    public SDBitwise bitwise() {
        return bitwise;
    }

    /**
     * Rename a variable everywhere this instance refers to it by name.
     * <p>
     * The following are updated:
     * <ul>
     *   <li>the variable map, the {@link Variable} entry and the {@link SDVariable} itself;</li>
     *   <li>the input/output name lists of every op;</li>
     *   <li>the deprecated shape and forward-variable maps;</li>
     *   <li>the X/Y/Z vertex ids of any {@link BaseOp} that consumes or produces the variable;</li>
     *   <li>the stored constant, variable and per-thread placeholder arrays;</li>
     *   <li>the loss-variable list.</li>
     * </ul>
     * Renaming a variable to its current name is a no-op.
     *
     * @param varName  current name; a variable with this name must exist
     * @param withName new name; must be non-null and not already used by another variable
     * @throws IllegalStateException if {@code varName} does not exist or {@code withName} is already taken
     * @throws NullPointerException  if {@code withName} is null
     */
    public void updateVariableName(String varName, String withName) {
        Preconditions.checkState(this.variables.containsKey(varName), "Cannot rename variable \"%s\": no variable with this name exists", varName);
        Preconditions.checkNotNull(withName, "Cannot rename variable \"%s\": the new name must not be null", varName);
        if (varName.equals(withName)) {
            return;
        }
        // Silently overwriting another variable's entry would leave the graph inconsistent.
        Preconditions.checkState(!this.variables.containsKey(withName), "Cannot rename variable \"%s\" to \"%s\": a variable with that name already exists", varName, withName);

        SDVariable oldVarNameRef = this.getVariable(varName);
        Variable v = this.variables.remove(varName);
        oldVarNameRef.setVarName(withName);
        v.setName(withName);
        this.variables.put(withName, v);

        // Ops refer to their inputs and outputs by name.
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            replaceNameInList(sameDiffOp.getOutputsOfOp(), varName, withName);
            replaceNameInList(sameDiffOp.getInputsToOp(), varName, withName);
        }

        if (this.variableNameToShape.containsKey(varName)) {
            this.variableNameToShape.put(withName, this.variableNameToShape.remove(varName));
        }
        if (this.forwardVarForGrad.containsKey(varName)) {
            this.forwardVarForGrad.put(withName, this.forwardVarForGrad.remove(varName));
        }

        // Stored arrays are keyed by variable name. If they are not moved, the renamed variable loses its array
        // (for VARIABLE type, getArrForVarName would then allocate a fresh one).
        DeviceLocalNDArray constantArr = this.constantArrays.remove(varName);
        if (constantArr != null) {
            this.constantArrays.put(withName, constantArr);
        }
        DeviceLocalNDArray variableArr = this.variablesArrays.remove(varName);
        if (variableArr != null) {
            this.variablesArrays.put(withName, variableArr);
        }
        for (Map<String, INDArray> threadPlaceholders : this.placeholdersPerThread.values()) {
            if (threadPlaceholders.containsKey(varName)) {
                threadPlaceholders.put(withName, threadPlaceholders.remove(varName));
            }
        }
        Collections.replaceAll(this.lossVariables, varName, withName);
        // NOTE(review): updaterMap / trainingConfig may also hold variable names; they are not updated here - confirm.

        // BaseOps additionally cache their input/output names as vertex ids.
        if (v.getInputsForOp() != null) {
            for (String opName : v.getInputsForOp()) {
                replaceVertexIdsOnRename(this.ops.get(opName).getOp(), varName, withName);
            }
        }
        if (v.getOutputOfOp() != null) {
            replaceVertexIdsOnRename(this.ops.get(v.getOutputOfOp()).getOp(), varName, withName);
        }
    }

    /** Replace every occurrence of {@code oldName} in {@code names} (which may be null) with {@code newName}. */
    private static void replaceNameInList(List<String> names, String oldName, String newName) {
        if (names != null && !names.isEmpty()) {
            Collections.replaceAll(names, oldName, newName);
        }
    }

    /** If {@code func} is a {@link BaseOp}, rewrite whichever of its X/Y/Z vertex ids equal {@code oldName}. */
    private static void replaceVertexIdsOnRename(DifferentialFunction func, String oldName, String newName) {
        if (!(func instanceof BaseOp)) {
            return;
        }
        BaseOp baseOp = (BaseOp) func;
        if (oldName.equals(baseOp.getXVertexId())) {
            baseOp.setXVertexId(newName);
        }
        if (oldName.equals(baseOp.getYVertexId())) {
            baseOp.setYVertexId(newName);
        }
        if (oldName.equals(baseOp.getZVertexId())) {
            baseOp.setZVertexId(newName);
        }
    }

    /**
     * Switch debug mode off.
     *
     * @return this instance, for call chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Switch debug mode on.
     *
     * @return this instance, for call chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * @return the function factory owned by this instance
     */
    @Override
    public DifferentialFunctionFactory f() {
        return functionFactory;
    }

    /**
     * Replace all registered listeners with the given ones.
     *
     * @param listeners the new listeners
     */
    public void setListeners(Listener ... listeners) {
        // Drop the current registrations first, then register the new set.
        this.listeners.clear();
        addListeners(listeners);
    }

    /**
     * Replace all registered listeners with the given collection.
     *
     * @param listeners the new listeners
     */
    public void setListeners(Collection<? extends Listener> listeners) {
        // Drop the current registrations first, then register the new set.
        this.listeners.clear();
        addListeners(listeners);
    }

    /**
     * Append listeners to the ones already registered.
     *
     * @param listeners listeners to add
     */
    public void addListeners(Listener ... listeners) {
        List<Listener> asList = Arrays.asList(listeners);
        addListeners(asList);
    }

    /**
     * Append all listeners in the given collection to the registered listeners.
     *
     * @param listeners listeners to add
     */
    public void addListeners(Collection<? extends Listener> listeners) {
        // addAll copies the source first, so passing getListeners() itself is safe.
        this.listeners.addAll(listeners);
    }

    /**
     * @return the live, mutable list of registered listeners (not a copy)
     */
    public List<Listener> getListeners() {
        return listeners;
    }

    /**
     * Build the full path of the currently open name scopes.
     *
     * @return the open scope names joined with "/" (outermost first), or {@code null} if no scope is open
     */
    public String currentNameScope() {
        if (nameScopes.isEmpty()) {
            return null;
        }
        StringBuilder path = new StringBuilder();
        for (int i = 0; i < nameScopes.size(); i++) {
            if (i > 0) {
                path.append('/');
            }
            path.append(nameScopes.get(i).getName());
        }
        return path.toString();
    }

    /**
     * Qualify a name with the current name scope.
     *
     * @param name a bare or already-qualified name
     * @return {@code name} unchanged if no scope is open or it already starts with "scope/";
     *         otherwise "scope/name"
     */
    protected String nameWithScope(String name) {
        String scope = currentNameScope();
        if (scope == null || name.startsWith(scope + "/")) {
            return name;
        }
        return scope + "/" + name;
    }

    /**
     * Push a name scope as the new innermost scope.
     *
     * @param nameScope scope to open
     */
    void addNameScope(NameScope nameScope) {
        nameScopes.add(nameScope);
    }

    /**
     * Close the innermost name scope. Scopes must be closed in reverse order of opening.
     *
     * @param nameScope the scope being closed; must equal the innermost open scope
     * @throws IllegalStateException if no scope is open or {@code nameScope} is not the innermost one
     */
    void closeNameScope(NameScope nameScope) {
        Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined");
        int lastIdx = nameScopes.size() - 1;
        NameScope innermost = nameScopes.get(lastIdx);
        Preconditions.checkState(innermost.equals(nameScope), "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope());
        nameScopes.remove(lastIdx);
    }

    /**
     * Open a new name scope nested inside the currently open ones.
     *
     * @param nameScope name of the new scope
     * @return the opened scope
     */
    public NameScope withNameScope(String nameScope) {
        NameScope opened = new NameScope(this, nameScope);
        addNameScope(opened);
        return opened;
    }

    /**
     * Return all ops whose name lies within the given name scope.
     * <p>
     * An op is in scope {@code s} if its name equals {@code s} or starts with {@code s + "/"}. Matching on the "/"
     * separator keeps sibling scopes with a shared textual prefix (e.g. "layer1" and "layer10") apart; a plain
     * prefix match would mix them. An empty scope name matches every op.
     *
     * @param scope the name scope; {@link NameScope#getName()} is treated as the full scope path
     * @return a new list of matching ops, in op-insertion order
     */
    public List<SameDiffOp> getOpsInScope(NameScope scope) {
        String scopeName = scope.getName();
        String prefix = scopeName.isEmpty() || scopeName.endsWith("/") ? scopeName : scopeName + "/";
        ArrayList<SameDiffOp> inScope = new ArrayList<SameDiffOp>();
        for (SameDiffOp op : this.ops.values()) {
            String opName = op.getName();
            if (opName.startsWith(prefix) || opName.equals(scopeName)) {
                inScope.add(op);
            }
        }
        return inScope;
    }

    /**
     * Convenience overload of {@link #getOpsInScope(NameScope)} taking the scope path as a string.
     *
     * @param scope full scope path
     * @return ops within that scope
     */
    public List<SameDiffOp> getOpsInScope(String scope) {
        NameScope wrapped = new NameScope(this, scope);
        return getOpsInScope(wrapped);
    }

    /**
     * Return all variables whose name lies within the given name scope.
     * <p>
     * A variable is in scope {@code s} if its name equals {@code s} or starts with {@code s + "/"}. Matching on the
     * "/" separator keeps sibling scopes with a shared textual prefix (e.g. "layer1" and "layer10") apart; a plain
     * prefix match would mix them. An empty scope name matches every variable.
     *
     * @param scope the name scope; {@link NameScope#getName()} is treated as the full scope path
     * @return a new list of matching variables, in the order returned by {@code variables()}
     */
    public List<SDVariable> getVariablesInScope(NameScope scope) {
        String scopeName = scope.getName();
        String prefix = scopeName.isEmpty() || scopeName.endsWith("/") ? scopeName : scopeName + "/";
        ArrayList<SDVariable> inScope = new ArrayList<SDVariable>();
        for (SDVariable v : this.variables()) {
            String varName = v.getVarName();
            if (varName.startsWith(prefix) || varName.equals(scopeName)) {
                inScope.add(v);
            }
        }
        return inScope;
    }

    /**
     * Convenience overload of {@link #getVariablesInScope(NameScope)} taking the scope path as a string.
     *
     * @param scope full scope path
     * @return variables within that scope
     */
    public List<SDVariable> getVariablesInScope(String scope) {
        NameScope wrapped = new NameScope(this, scope);
        return getVariablesInScope(wrapped);
    }

    /**
     * Copy this instance's variables and ops into another {@link SameDiff} instance.
     * <p>
     * Every variable is cloned and registered in {@code sameDiff} via {@code var(...)}. Non-ARRAY variables that
     * currently have an array also get that array associated in the target. Every op is cloned through
     * {@link FlatBuffersMapper#cloneViaSerialize}, re-owned by {@code sameDiff}, and its argument/output links are
     * registered there.
     *
     * @param sameDiff target instance to copy into
     * @return the last variable in {@code sameDiff}'s variable list after copying
     *         (NOTE(review): fails if that list is empty)
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        for (SDVariable var : this.variables()) {
            SDVariable clone = var.clone(this);
            SDVariable newVar = sameDiff.var(clone);
            INDArray arr = var.getArr();
            if (arr != null && var.getVariableType() != VariableType.ARRAY) {
                sameDiff.associateArrayWithVariable(arr, newVar);
            }
            clone.setSameDiff(sameDiff);
        }
        // Index of each variable in this instance's insertion-ordered variable map, as used by the
        // serialization-based op clone below.
        HashMap<String, Integer> reverseMap = new HashMap<String, Integer>();
        int count = 0;
        for (Variable v : this.variables.values()) {
            reverseMap.put(v.getName(), count++);
        }
        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction function = op.getOp();
            DifferentialFunction clone = FlatBuffersMapper.cloneViaSerialize(this, function, reverseMap);
            clone.setSameDiff(sameDiff);
            clone.setOwnName(function.getOwnName());
            // NOTE(review): putOpForId is only reached when the op already exists in the target, where it either does
            // nothing or throws; it is also given the original function rather than the clone. Looks unintended.
            if (sameDiff.opExists(function.getOwnName())) {
                sameDiff.putOpForId(function.getOwnName(), function);
            }
            SDVariable[] argsForFunction = function.args();
            SDVariable[] outputsForFunction = function.outputVariables();
            sameDiff.addArgsFor(argsForFunction, clone);
            // NOTE(review): outputs are registered against the original function, whereas args use the clone.
            sameDiff.addOutgoingFor(outputsForFunction, function);
            for (SDVariable arg : clone.args()) {
                arg.setSameDiff(sameDiff);
            }
            for (SDVariable output : clone.outputVariables()) {
                output.setSameDiff(sameDiff);
            }
            // NOTE(review): stores this instance's SameDiffOp (wrapping the original function), not the clone.
            sameDiff.ops.put(function.getOwnName(), op);
        }
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * @param id op name
     * @return whether an op with that name is registered in this instance
     */
    public boolean opExists(String id) {
        return ops.containsKey(id);
    }

    /**
     * Look up the op that produces the given variable.
     *
     * @param variableName name of an existing variable
     * @return the producing op, or {@code null} if the variable is not the output of any op
     * @throws IllegalStateException if no variable with that name exists
     */
    public DifferentialFunction getVariableOutputOp(String variableName) {
        Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
        String producerName = variables.get(variableName).getOutputOfOp();
        return producerName == null ? null : ops.get(producerName).getOp();
    }

    /**
     * Look up a registered op by name.
     *
     * @param id op name; must not be null
     * @return the op's function
     * @throws ND4JIllegalStateException if no op with that name is registered
     */
    public DifferentialFunction getOpById(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked @NonNull but is null");
        }
        if (ops.containsKey(id)) {
            return ops.get(id).getOp();
        }
        throw new ND4JIllegalStateException("No function with id " + id + " found!");
    }

    /**
     * Register {@code function} under {@code id} if no op with that name exists yet.
     * <p>
     * If an entry already exists, nothing changes, except that an existing entry with a {@code null} op causes an
     * exception.
     *
     * @param id       op name
     * @param function op to register
     * @throws ND4JIllegalStateException if an entry for {@code id} exists but holds a null op
     */
    public void putOpForId(String id, DifferentialFunction function) {
        if (!ops.containsKey(id)) {
            ops.put(id, SameDiffOp.builder().name(id).op(function).build());
            return;
        }
        // NOTE(review): this condition looks inverted relative to its message, but invokeGraphOn calls this method
        // for already-existing ops and relies on it not throwing - confirm intent before changing.
        if (ops.get(id).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
    }

    /**
     * @param function an op registered in this instance
     * @return the names of the op's inputs, or {@code null} if none are recorded
     * @throws ND4JIllegalStateException if the op is not registered in this instance
     */
    public String[] getInputsForOp(DifferentialFunction function) {
        String opName = function.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> inputNames = ops.get(opName).getInputsToOp();
        if (inputNames == null) {
            return null;
        }
        return inputNames.toArray(new String[0]);
    }

    /**
     * @param function an op registered in this instance
     * @return the names of the op's outputs, or {@code null} if none are recorded
     * @throws ND4JIllegalStateException if the op is not registered in this instance
     */
    public String[] getOutputsForOp(DifferentialFunction function) {
        String opName = function.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> outputNames = ops.get(opName).getOutputsOfOp();
        if (outputNames == null) {
            return null;
        }
        return outputNames.toArray(new String[0]);
    }

    /**
     * Resolve the output variables of an op registered in this instance.
     *
     * @param function op whose outputs should be returned
     * @return the output variables, in the op's output order; each entry is whatever {@code getVariable} returns for
     *         that name (not null-checked here, unlike {@link #getInputVariablesForOp})
     * @throws ND4JIllegalStateException if the op is not registered or has no recorded outputs
     */
    public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
        String[] outputs = this.getOutputsForOp(function);
        if (outputs == null) {
            throw new ND4JIllegalStateException("No outputs found for function " + function);
        }
        SDVariable[] vars = new SDVariable[outputs.length];
        for (int i = 0; i < outputs.length; ++i) {
            vars[i] = this.getVariable(outputs[i]);
        }
        return vars;
    }

    /**
     * Resolve the input variables of an op registered in this instance.
     *
     * @param function op whose inputs should be returned
     * @return the input variables, in the op's input order
     * @throws ND4JIllegalStateException if the op has no recorded inputs or an input name cannot be resolved
     */
    public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
        String[] inputNames = getInputsForOp(function);
        if (inputNames == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }
        SDVariable[] resolved = new SDVariable[inputNames.length];
        for (int idx = 0; idx < inputNames.length; idx++) {
            SDVariable input = getVariable(inputNames[idx]);
            if (input == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            resolved[idx] = input;
        }
        return resolved;
    }

    /**
     * Store an array for a CONSTANT, VARIABLE or PLACEHOLDER variable.
     * <p>
     * Unlike {@code associateArrayWithVariable}, this does no dtype cast, no detach/dup, and no placeholder shape
     * check. Placeholder arrays are stored for the calling thread only.
     *
     * @param varName name of an existing variable
     * @param arr     array to store
     * @throws UnsupportedOperationException for other variable types (e.g. ARRAY)
     */
    public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        if (arr == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName);
        SDVariable target = getVariable(varName);
        if (target.isConstant()) {
            constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
            return;
        }
        if (target.getVariableType() == VariableType.VARIABLE) {
            variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
            return;
        }
        if (target.isPlaceHolder()) {
            long threadId = Thread.currentThread().getId();
            Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
            if (threadPlaceholders == null) {
                threadPlaceholders = new HashMap<String, INDArray>();
                placeholdersPerThread.put(threadId, threadPlaceholders);
            }
            threadPlaceholders.put(varName, arr);
            return;
        }
        throw new UnsupportedOperationException("Cannot set variable of type " + target.getVariableType() + " using this method");
    }

    /**
     * @param varName variable name
     * @return the shape of the variable's current array if one exists, otherwise the shape recorded in the
     *         (deprecated) shape map, which may be {@code null}
     */
    public long[] getShapeForVarName(String varName) {
        return arrayAlreadyExistsForVarName(varName)
                ? getVariable(varName).getArr().shape()
                : variableNameToShape.get(varName);
    }

    /**
     * Describe the shape (and dtype) of a variable.
     *
     * @param varName variable name
     * @return the current array's descriptor if the variable has an array; otherwise a descriptor built from the
     *         recorded shape with the default {@code Nd4j.dataType()}
     * @throws IllegalStateException if the variable has neither an array nor a recorded shape
     */
    public LongShapeDescriptor getShapeDescriptorForVarName(String varName) {
        // Fetch once: getArr() may do work (e.g. lazy allocation), so avoid calling it twice.
        INDArray arr = this.getVariable(varName).getArr();
        if (arr != null) {
            return arr.shapeDescriptor();
        }
        long[] shape = this.variableNameToShape.get(varName);
        Preconditions.checkState(shape != null, "No array or shape information is available for variable \"%s\"", varName);
        return LongShapeDescriptor.fromShape(shape, Nd4j.dataType());
    }

    /**
     * Record a shape for a variable in the shape map.
     *
     * @param varName variable name
     * @param shape   shape to record; must not be null
     * @throws ND4JIllegalStateException if {@code shape} is null or a shape is already recorded for {@code varName}
     * @deprecated the name-to-shape map is deprecated
     */
    @Deprecated
    public void putShapeForVarName(String varName, long[] shape) {
        if (shape == null) {
            throw new ND4JIllegalStateException("Shape must not be null!");
        }
        boolean alreadyRecorded = variableNameToShape.containsKey(varName);
        if (alreadyRecorded) {
            throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
        }
        variableNameToShape.put(varName, shape);
    }

    /**
     * Record a variable's shape and set its datatype from a shape descriptor.
     *
     * @param varName variable name
     * @param shape   descriptor providing both shape and datatype
     */
    public void putShapeForVarName(String varName, LongShapeDescriptor shape) {
        SDVariable target = getVariable(varName);
        long[] dims = shape.getShape();
        putShapeForVarName(varName, dims);
        target.setDataType(shape.dataType());
    }

    /**
     * Record a shape for a variable if none is recorded yet.
     * <p>
     * Despite the name, an already-recorded shape is left unchanged, and {@code clearArrayOnShapeMismatch} is
     * currently ignored.
     *
     * @param varName                   variable name
     * @param shape                     shape to record; must not be null
     * @param clearArrayOnShapeMismatch unused
     * @deprecated the name-to-shape map is deprecated
     */
    @Deprecated
    public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) {
        Preconditions.checkNotNull(shape, "Cannot put null shape for variable: %s", varName);
        if (variableNameToShape.containsKey(varName)) {
            // NOTE(review): the "update" path is not implemented.
            return;
        }
        putShapeForVarName(varName, shape);
    }

    /**
     * @param varName variable name
     * @return whether a shape is recorded for the variable or it currently has an array
     */
    public boolean shapeAlreadyExistsForVarName(String varName) {
        if (variableNameToShape.containsKey(varName)) {
            return true;
        }
        return arrayAlreadyExistsForVarName(varName);
    }

    /**
     * Check whether an array is currently available for a variable, without allocating one.
     * <p>
     * ARRAY values are looked up in the calling thread's session, and PLACEHOLDER arrays in the calling thread's
     * placeholder map.
     *
     * @param varName variable name
     * @return whether an array is stored for the variable
     * @throws RuntimeException for an unrecognized variable type
     */
    public boolean arrayAlreadyExistsForVarName(String varName) {
        VariableType type = getVariable(varName).getVariableType();
        long threadId = Thread.currentThread().getId();
        switch (type) {
            case VARIABLE:
                return variablesArrays.containsKey(varName);
            case CONSTANT:
                return constantArrays.containsKey(varName);
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                return session != null && session.contains(varName, "main", 0, null);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
                return threadPlaceholders != null && threadPlaceholders.containsKey(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Fetch the array currently stored for a variable.
     * <p>
     * VARIABLE arrays are allocated on first access. For CONSTANT, ARRAY and PLACEHOLDER variables, {@code null} is
     * returned when nothing is stored. ARRAY and PLACEHOLDER lookups are for the calling thread.
     *
     * @param varName name of an existing variable
     * @return the array, or {@code null} as described above
     * @throws IllegalStateException if no variable with that name exists
     * @throws RuntimeException      for an unrecognized variable type
     */
    public INDArray getArrForVarName(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", varName);
        SDVariable variable = variables.get(varName).getVariable();
        VariableType type = variable.getVariableType();
        switch (type) {
            case VARIABLE:
                if (!variablesArrays.containsKey(varName)) {
                    // Presumably populates variablesArrays for this name, since it is read immediately below.
                    variable.storeAndAllocateNewArray();
                }
                return variablesArrays.get(varName).get();
            case CONSTANT: {
                DeviceLocalNDArray stored = constantArrays.get(varName);
                return stored == null ? null : stored.get();
            }
            case ARRAY: {
                InferenceSession session = sessions.get(Thread.currentThread().getId());
                return session == null ? null : (INDArray) session.get(varName, "main", 0, null, false);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(Thread.currentThread().getId());
                return threadPlaceholders == null ? null : threadPlaceholders.get(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Name-based overload of {@link #associateArrayWithVariable(INDArray, SDVariable)}.
     *
     * @param arr      array to associate
     * @param variable name of an existing variable; must not be null
     * @throws IllegalStateException if no variable with that name exists
     */
    public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": variable \"%s\" does not exist in this SameDiff instance", variable, variable);
        SDVariable target = getVariable(variable);
        associateArrayWithVariable(arr, target);
    }

    /**
     * Associate an array with a variable, storing it according to the variable's type.
     * <p>
     * The array is prepared in this order:
     * <ol>
     *   <li>cast to the variable's datatype if they differ;</li>
     *   <li>detached if attached;</li>
     *   <li>duplicated if it is a view;</li>
     *   <li>for VARIABLE type only, duplicated if the same instance is already stored for any VARIABLE.</li>
     * </ol>
     * The prepared array is then stored:
     * <ul>
     *   <li>VARIABLE arrays in {@code variablesArrays};</li>
     *   <li>CONSTANT arrays in {@code constantArrays};</li>
     *   <li>ARRAY values in the calling thread's inference session;</li>
     *   <li>PLACEHOLDER arrays in the calling thread's placeholder map, after a shape check.</li>
     * </ul>
     * Finally, the array is associated with any same-named variable in each sub-function instance.
     *
     * @param arr      array to associate; must not be null
     * @param variable target variable; must not be null
     * @throws ND4JIllegalArgumentException if {@code arr} or {@code variable} is null
     * @throws IllegalStateException        if the placeholder shape check fails or the variable type is unknown
     */
    public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
        if (variable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (arr == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }
        // Convert to the variable's declared datatype; the precondition below re-verifies the result.
        if (variable.dataType() != arr.dataType()) {
            arr = arr.castTo(variable.dataType());
        }
        Preconditions.checkState((variable.dataType() == arr.dataType() ? 1 : 0) != 0, (String)"Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", (Object)variable.getVarName(), (Object)variable.dataType(), (Object)arr.dataType());
        // Ensure the calling thread has an inference session (used by the ARRAY case below).
        // NOTE(review): the session is created regardless of the variable's type.
        if (this.sessions.get(Thread.currentThread().getId()) == null) {
            this.sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
        }
        // Records whether arr has already been replaced (by detach or dup) in this call; if so, the identity check
        // below is skipped.
        boolean duped = false;
        // NOTE(review): presumably so the stored array is not tied to a workspace's lifetime - confirm.
        if (arr.isAttached()) {
            arr = arr.detach();
            duped = true;
        }
        // Store a standalone copy rather than a view into another array.
        if (arr.isView()) {
            arr = arr.dup();
            duped = true;
        }
        // Avoid storing the same array instance for more than one VARIABLE entry (re-associating an array already
        // stored for this same variable also triggers a dup).
        if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
            for (DeviceLocalNDArray deviceLocalNDArray : this.variablesArrays.values()) {
                if (deviceLocalNDArray.get() != arr) continue;
                arr = arr.dup();
                break;
            }
        }
        switch (variable.getVariableType()) {
            case VARIABLE: {
                this.variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
                break;
            }
            case CONSTANT: {
                this.constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
                break;
            }
            case ARRAY: {
                // Written directly into this thread's session outputs under frame "main", iteration 0.
                InferenceSession session = this.sessions.get(Thread.currentThread().getId());
                AbstractSession.VarId varId = session.newVarId(variable.getVarName(), "main", 0, null);
                session.getNodeOutputs().put(varId, arr);
                break;
            }
            case PLACEHOLDER: {
                // A null placeholder shape accepts any array shape.
                long[] phShape = variable.placeholderShape();
                Preconditions.checkState((phShape == null || Shape.shapeMatchesPlaceholder(phShape, arr.shape()) ? 1 : 0) != 0, (String)"Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:shape is wrong rank or does not match on one or more dimensions", (Object)arr, (Object)phShape);
                long tid = Thread.currentThread().getId();
                if (!this.placeholdersPerThread.containsKey(tid)) {
                    this.placeholdersPerThread.put(tid, new HashMap());
                }
                this.placeholdersPerThread.get(tid).put(variable.getVarName(), arr);
                break;
            }
            default: {
                throw new IllegalStateException("Unknown variable type: " + (Object)((Object)variable.getVariableType()));
            }
        }
        // Propagate into sub-function instances that have a variable with the same name (each call propagates further).
        if (this.sameDiffFunctionInstances != null && this.sameDiffFunctionInstances.size() > 0) {
            for (Map.Entry entry : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = (SameDiff)entry.getValue();
                SDVariable v = sd.getVariable(variable.getVarName());
                if (v == null) continue;
                sd.associateArrayWithVariable(arr, v);
            }
        }
    }

    /**
     * Set the value of a VARIABLE or CONSTANT from an array.
     * <p>
     * If an array is already stored for the variable, it is updated via {@code DeviceLocalNDArray.update}. If none is
     * stored yet, the array is associated as by {@link #associateArrayWithVariable(INDArray, SDVariable)}. Previously
     * this case failed with a NullPointerException.
     *
     * @param arr      source array; views are duplicated first
     * @param variable a VARIABLE or CONSTANT type variable
     * @throws IllegalStateException if {@code variable} is of another type
     */
    public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable) {
        if (arr == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        VariableType type = variable.getVariableType();
        Preconditions.checkState(type == VariableType.VARIABLE || type == VariableType.CONSTANT, "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), type);
        if (arr.isView()) {
            arr = arr.dup();
        }
        Map<String, DeviceLocalNDArray> store = type == VariableType.VARIABLE ? this.variablesArrays : this.constantArrays;
        DeviceLocalNDArray existing = store.get(variable.getVarName());
        if (existing == null) {
            // Nothing stored yet: register the array instead of dereferencing a missing entry.
            this.associateArrayWithVariable(arr, variable);
        } else {
            existing.update(arr);
        }
    }

    /**
     * Register a named sub-function instance.
     *
     * @param name      sub-function name
     * @param nameSpace the instance to register; re-registering the same instance under the same name is allowed
     * @throws ND4JIllegalStateException if a different instance is already registered under {@code name}
     */
    public void putSubFunction(String name, SameDiff nameSpace) {
        boolean clash = sameDiffFunctionInstances.containsKey(name)
                && sameDiffFunctionInstances.get(name) != nameSpace;
        if (clash) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        sameDiffFunctionInstances.put(name, nameSpace);
    }

    /**
     * Builds a new insertion-ordered map from each variable's name to its {@link SDVariable}.
     *
     * @return a fresh {@link LinkedHashMap}; modifying it does not affect this instance
     */
    public Map<String, SDVariable> variableMap() {
        Map<String, SDVariable> byName = new LinkedHashMap<String, SDVariable>();
        Iterator<Variable> it = this.variables.values().iterator();
        while (it.hasNext()) {
            Variable current = it.next();
            byName.put(current.getName(), current.getVariable());
        }
        return byName;
    }

    /**
     * Reflectively invokes the method registered in {@code opMethods} for {@code op.opName()}.
     * <p>
     * When both {@code x} and {@code y} are non-null the method is called with two arguments;
     * otherwise it is called with {@code x} only.
     *
     * @param op op whose name selects the method
     * @param x  first argument
     * @param y  optional second argument
     * @return the {@link SDVariable} returned by the invoked method
     * @throws ND4JIllegalStateException if no method is registered for the op name, or if the
     *                                   reflective call fails (the failure is attached as the cause)
     * @deprecated reflective op dispatch; kept for backward compatibility
     */
    @Deprecated
    public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
        String opName = op.opName();
        if (!opMethods.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal method opName " + opName);
        }
        Method method = opMethods.get(opName);
        try {
            if (x != null && y != null) {
                return (SDVariable) method.invoke((Object) this, x, y);
            }
            return (SDVariable) method.invoke((Object) this, x);
        } catch (Exception e) {
            // Previously swallowed silently; keep the original failure as the cause so it is diagnosable
            throw new ND4JIllegalStateException("Illegal method opName " + opName, e);
        }
    }

    /**
     * Returns the names under which sub-function instances are registered.
     *
     * @return the live key set of the sub-function instance map (not a copy)
     */
    public Collection<String> definedFunctionNames() {
        Collection<String> names = this.sameDiffFunctionInstances.keySet();
        return names;
    }

    /**
     * Single-argument form of {@link #invoke(Op, SDVariable, SDVariable)}; delegates with a null
     * second argument.
     */
    public SDVariable invoke(Op op, SDVariable x) {
        SDVariable noSecondArg = null;
        return this.invoke(op, x, noSecondArg);
    }

    /**
     * Creates an empty instance; use {@link #create()} from outside the class.
     * Initializes the function factory and the bookkeeping maps/sets used by this class.
     */
    private SameDiff() {
        functionFactory = new DifferentialFunctionFactory(this);
        sameDiffFunctionDefinitionMap = new LinkedHashMap<String, SameDiffFunctionDefinition>();
        sameDiffFunctionInstances = new LinkedHashMap<String, SameDiff>();
        forwardVarForGrad = new LinkedHashMap<String, SDVariable>();
        opsForResult = new IntArrayKeyMap();
        variableNameToShape = new LinkedHashMap<String, long[]>();
        placeHolderOriginalShapes = new LinkedHashMap<String, long[]>();
        placeHolderFunctions = new LinkedHashSet<String>();
        baseNameForFunctionInstanceId = new LinkedHashMap<String, String>();
        propertiesToResolve = new LinkedHashMap<String, List<String>>();
        propertiesForFunction = new LinkedHashMap<String, Map<String, Object>>();
        // table of (function own name, field name) -> variable name
        fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Appends {@code arrayName} to the list of properties to resolve for {@code forFunction},
     * creating the list on first use.
     */
    public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
        String ownName = forFunction.getOwnName();
        if (this.propertiesToResolve.containsKey(ownName)) {
            this.propertiesToResolve.get(ownName).add(arrayName);
            return;
        }
        List<String> created = new ArrayList<String>();
        created.add(arrayName);
        this.propertiesToResolve.put(ownName, created);
    }

    /**
     * Removes {@code arrayName} from the properties to resolve for {@code forFunction}.
     * Does nothing if the function has no entry.
     */
    public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) {
        String ownName = forFunction.getOwnName();
        if (!this.propertiesToResolve.containsKey(ownName)) {
            return;
        }
        List<String> pending = this.propertiesToResolve.get(ownName);
        pending.remove(arrayName);
    }

    /**
     * Returns the properties still to resolve for {@code function}.
     *
     * @return the stored (mutable) list, or an immutable empty list if the function has no entry
     */
    public List<String> propertiesToResolveForFunction(DifferentialFunction function) {
        String ownName = function.getOwnName();
        return this.propertiesToResolve.containsKey(ownName)
                ? this.propertiesToResolve.get(ownName)
                : Collections.<String>emptyList();
    }

    /**
     * Records a property value for {@code functionFor}, creating its property map on first use.
     *
     * @throws ND4JIllegalStateException if the property is already set for that function
     */
    private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
        String ownName = functionFor.getOwnName();
        if (this.propertiesForFunction.containsKey(ownName)) {
            Map<String, Object> existing = this.propertiesForFunction.get(ownName);
            if (existing.containsKey(propertyName)) {
                throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
            }
            existing.put(propertyName, propertyValue);
            return;
        }
        Map<String, Object> fresh = new LinkedHashMap<String, Object>();
        fresh.put(propertyName, propertyValue);
        this.propertiesForFunction.put(ownName, fresh);
    }

    /**
     * Maps ({@code function}'s own name, {@code fieldName}) to {@code varName}.
     */
    public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) {
        String rowKey = function.getOwnName();
        this.fieldVariableResolutionMapping.put(rowKey, fieldName, varName);
    }

    /**
     * Looks up the variable name mapped to ({@code function}'s own name, {@code fieldName}).
     *
     * @return the mapped variable name, or null if none is mapped
     */
    public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) {
        String rowKey = function.getOwnName();
        Object mapped = this.fieldVariableResolutionMapping.get(rowKey, fieldName);
        return (String) mapped;
    }

    /**
     * Associates {@code baseName} with {@code function}'s own name, replacing any previous value.
     */
    public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
        String instanceId = function.getOwnName();
        this.baseNameForFunctionInstanceId.put(instanceId, baseName);
    }

    /**
     * Returns the base name recorded for {@code function}'s own name, or null if none.
     */
    public String getBaseNameForFunction(DifferentialFunction function) {
        String instanceId = function.getOwnName();
        return this.baseNameForFunctionInstanceId.get(instanceId);
    }

    /**
     * Ensures {@code function} is attached to this SameDiff instance.
     *
     * @param function variable to attach; must not be null
     * @return the same object, now associated with this instance
     */
    public <X extends SDVariable> X setupFunction(X function) {
        Preconditions.checkNotNull(function, "Passed in function must not be null!");
        // function is non-null past this point, so it is always an SDVariable
        boolean foreign = function.getSameDiff() != this;
        if (foreign) {
            function.setSameDiff(this);
        }
        return function;
    }

    /**
     * Variable-object overload of {@link #addOutgoingFor(String[], DifferentialFunction)};
     * converts each variable to its name and delegates.
     */
    public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
        String[] names = new String[variables.length];
        int idx = 0;
        for (SDVariable current : variables) {
            names[idx++] = current.getVarName();
        }
        this.addOutgoingFor(names, function);
    }

    /**
     * Declares {@code varNames} as the outputs of {@code function}'s op, and marks each named
     * variable as produced by that op.
     *
     * @throws ND4JIllegalStateException if the function's own name is null, its outputs were
     *                                   already declared, {@code varNames} is null, or any element
     *                                   of {@code varNames} is null
     */
    public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
        String ownName = function.getOwnName();
        if (ownName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        SameDiffOp target = this.ops.get(ownName);
        List<String> declared = target.getOutputsOfOp();
        if (declared != null && !declared.isEmpty()) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
        }
        if (varNames == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String candidate : varNames) {
            if (candidate == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
        }
        target.setOutputsOfOp(Arrays.asList(varNames));
        for (String produced : varNames) {
            this.variables.get(produced).setOutputOfOp(ownName);
        }
    }

    /**
     * Pushes {@code interceptor} onto the top of the interceptor stack.
     *
     * @throws NullPointerException if {@code interceptor} is null
     */
    public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor != null) {
            this.argumentInterceptors.push(interceptor);
            return;
        }
        throw new NullPointerException("interceptor is marked @NonNull but is null");
    }

    /**
     * Reports whether {@code interceptor} is currently in the paused collection.
     *
     * @throws NullPointerException if {@code interceptor} is null
     */
    private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor != null) {
            return this.pausedArgumentInterceptors.contains(interceptor);
        }
        throw new NullPointerException("interceptor is marked @NonNull but is null");
    }

    /**
     * Finds the interceptor that should currently be applied: the one nearest the top of the
     * stack that is not paused.
     *
     * @return that interceptor, or null if the stack is empty or every interceptor is paused
     */
    private ArgumentInterceptor getArgumentInterceptorToUse() {
        // Scan from the most recently pushed element down to the bottom of the stack
        for (int idx = this.argumentInterceptors.size() - 1; idx >= 0; --idx) {
            ArgumentInterceptor candidate = (ArgumentInterceptor) this.argumentInterceptors.elementAt(idx);
            if (!this.isArgumentInterceptorPaused(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Pops the top interceptor off the stack; a no-op when the stack is empty.
     */
    public void removeArgumentInterceptor() {
        if (this.argumentInterceptors.isEmpty()) {
            return;
        }
        this.argumentInterceptors.pop();
    }

    /**
     * Marks the interceptor at the top of the stack as paused.
     * Throws if the stack is empty ({@code peek()} on an empty {@link java.util.Stack}).
     */
    public void pauseArgumentInterceptor() {
        ArgumentInterceptor top = this.argumentInterceptors.peek();
        this.pausedArgumentInterceptors.add(top);
    }

    /**
     * Marks {@code interceptor} as paused so {@code getArgumentInterceptorToUse()} skips it.
     *
     * @throws NullPointerException if {@code interceptor} is null
     */
    public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor != null) {
            this.pausedArgumentInterceptors.add(interceptor);
            return;
        }
        throw new NullPointerException("interceptor is marked @NonNull but is null");
    }

    /**
     * Removes the pause mark from the interceptor at the top of the stack.
     * Throws if the stack is empty ({@code peek()} on an empty {@link java.util.Stack}).
     */
    public void unpauseArgumentInterceptor() {
        ArgumentInterceptor top = this.argumentInterceptors.peek();
        this.pausedArgumentInterceptors.remove(top);
    }

    /**
     * Removes the pause mark from {@code interceptor}.
     *
     * @throws NullPointerException if {@code interceptor} is null
     */
    public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor != null) {
            this.pausedArgumentInterceptors.remove(interceptor);
            return;
        }
        throw new NullPointerException("interceptor is marked @NonNull but is null");
    }

    /**
     * Registers {@code variables} as the inputs of {@code function}'s op.
     * <p>
     * If an active (non-paused) argument interceptor exists, each name is first replaced in place
     * with the name of the variable returned by the interceptor. The caller's array is therefore
     * modified. Any function taking a placeholder input is recorded in {@code placeHolderFunctions}.
     * The op entry is created if missing, and each input variable records the op as a consumer.
     *
     * @throws ND4JIllegalStateException if {@code function}'s own name is null
     */
    public void addArgsFor(String[] variables, DifferentialFunction function) {
        ArgumentInterceptor interceptor = this.getArgumentInterceptorToUse();
        if (interceptor != null) {
            // Paused while intercepting so it is skipped by getArgumentInterceptorToUse();
            // presumably guards against re-intercepting ops created inside intercept() - confirm.
            this.pauseArgumentInterceptor(interceptor);
            try {
                for (int i = 0; i < variables.length; ++i) {
                    variables[i] = interceptor.intercept(this.getVariable(variables[i])).getVarName();
                }
            } finally {
                // Always unpause: previously an exception from intercept() left it paused forever
                this.unpauseArgumentInterceptor(interceptor);
            }
        }
        String ownName = function.getOwnName();
        if (ownName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        for (String varName : variables) {
            if (this.isPlaceHolder(varName)) {
                this.placeHolderFunctions.add(ownName);
            }
        }
        if (!this.ops.containsKey(ownName)) {
            this.ops.put(ownName, SameDiffOp.builder().name(ownName).op(function).build());
        }
        this.ops.get(ownName).setInputsToOp(Arrays.asList(variables));
        for (String variableName : variables) {
            Variable input = this.variables.get(variableName);
            List<String> consumers = input.getInputsForOp();
            if (consumers == null) {
                consumers = new ArrayList<String>();
                input.setInputsForOp(consumers);
            }
            if (!consumers.contains(ownName)) {
                consumers.add(ownName);
            }
        }
    }

    /**
     * Variable-object overload of {@link #addArgsFor(String[], DifferentialFunction)}.
     *
     * @throws ND4JIllegalStateException if any element of {@code variables} is null
     */
    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
        String[] names = new String[variables.length];
        int idx = 0;
        for (SDVariable current : variables) {
            if (current == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            names[idx++] = current.getVarName();
        }
        this.addArgsFor(names, function);
    }

    /**
     * Replaces the {@code i}-th input of {@code function} with {@code newArg}.
     * <p>
     * The method does the following:
     * <ul>
     *   <li>Updates the placeholder-function bookkeeping for the change.</li>
     *   <li>Rewrites the op's input list.</li>
     *   <li>Registers the op as a consumer of {@code newArg}.</li>
     *   <li>Removes the op from the old variable's consumers once the old variable is no longer
     *       among the function's arguments.</li>
     * </ul>
     *
     * @param i        index of the argument to replace
     * @param newArg   replacement variable (must not be null)
     * @param function function whose argument is replaced (must not be null)
     * @throws NullPointerException     if {@code newArg} or {@code function} is null
     * @throws IllegalArgumentException if {@code i} is negative or not less than the argument count
     */
    public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function) {
        if (newArg == null) {
            throw new NullPointerException("newArg is marked @NonNull but is null");
        }
        if (function == null) {
            throw new NullPointerException("function is marked @NonNull but is null");
        }
        int numArgs = function.args().length;
        if (i < 0 || i >= numArgs) {
            // Message built only on failure (it was previously concatenated on every call)
            throw new IllegalArgumentException("Index out of range: function " + function.getOwnName() + " only has "
                    + numArgs + " args but you are trying to replace the argument at " + i);
        }
        String ownName = function.getOwnName();
        String oldName = function.arg(i).getVarName();
        String newName = newArg.getVarName();
        boolean oldIsPlaceholder = function.arg(i).isPlaceHolder();
        boolean newIsPlaceholder = newArg.isPlaceHolder();
        if (oldIsPlaceholder && !newIsPlaceholder) {
            // Only drop the function from placeHolderFunctions if no other argument is a placeholder
            String[] currentArgNames = function.argNames();
            boolean otherPlaceholders = false;
            for (int j = 0; j < currentArgNames.length && !otherPlaceholders; ++j) {
                if (j != i && function.arg(j).isPlaceHolder()) {
                    otherPlaceholders = true;
                }
            }
            if (!otherPlaceholders) {
                this.placeHolderFunctions.remove(ownName);
            }
        } else if (!oldIsPlaceholder && newIsPlaceholder) {
            this.placeHolderFunctions.add(ownName);
        }
        SameDiffOp target = this.ops.get(ownName);
        // Copy before modifying so the previously stored list is not mutated in place
        List<String> newInputs = new ArrayList<String>(target.getInputsToOp());
        newInputs.set(i, newName);
        target.setInputsToOp(newInputs);
        List<String> newConsumers = this.variables.get(newName).getInputsForOp();
        if (newConsumers == null) {
            newConsumers = new ArrayList<String>();
            this.variables.get(newName).setInputsForOp(newConsumers);
        }
        if (!newConsumers.contains(ownName)) {
            newConsumers.add(ownName);
        }
        // argNames() is re-read here, after the inputs were rewritten above
        List<String> oldConsumers = this.variables.get(oldName).getInputsForOp();
        if (oldConsumers != null && !ArrayUtils.contains((Object[]) function.argNames(), (Object) oldName)) {
            oldConsumers.remove(ownName);
        }
    }

    /**
     * Reports whether {@code function}'s op has a non-empty list of inputs.
     */
    public boolean hasArgs(DifferentialFunction function) {
        List<String> inputs = this.ops.get(function.getOwnName()).getInputsToOp();
        if (inputs == null) {
            return false;
        }
        return !inputs.isEmpty();
    }

    /**
     * Clears placeholder arrays for every thread or only for the calling thread, then applies the
     * same call recursively to every sub-function instance.
     *
     * @param allThreads true to clear all threads' placeholders; false for the current thread only
     */
    public void clearPlaceholders(boolean allThreads) {
        if (!allThreads) {
            long currentThreadId = Thread.currentThread().getId();
            this.placeholdersPerThread.remove(currentThreadId);
        } else {
            this.placeholdersPerThread.clear();
        }
        for (SameDiff child : this.sameDiffFunctionInstances.values()) {
            child.clearPlaceholders(allThreads);
        }
    }

    /**
     * Removes input-array references held by ops, then recurses into every sub-function instance.
     * <ul>
     *   <li>For {@link Op} instances: x is set to null, and y is set to null when present.</li>
     *   <li>For {@link DynamicCustomOp} instances: the input arguments are set to null.</li>
     * </ul>
     * Other op types are left untouched.
     */
    public void clearOpInputs() {
        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction fn = op.getOp();
            // Typed locals: the decompiled form declared "Object o" and called Op methods on it,
            // which does not compile
            if (fn instanceof Op) {
                Op legacyOp = (Op) fn;
                legacyOp.setX(null);
                if (legacyOp.y() != null) {
                    legacyOp.setY(null);
                }
            } else if (fn instanceof DynamicCustomOp) {
                ((DynamicCustomOp) fn).setInputArguments(null);
            }
        }
        for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
            sd.clearOpInputs();
        }
    }

    /**
     * Returns the functions of all registered ops, in the iteration order of the op map.
     *
     * @return a new array
     */
    public DifferentialFunction[] ops() {
        List<DifferentialFunction> collected = new ArrayList<DifferentialFunction>(this.ops.size());
        Iterator<SameDiffOp> it = this.ops.values().iterator();
        while (it.hasNext()) {
            collected.add(it.next().getOp());
        }
        return collected.toArray(new DifferentialFunction[0]);
    }

    /**
     * Returns a hash code consistent with {@link #equals(Object)}.
     * <p>
     * {@code equals} compares content ({@code variables} and the function definition/instance
     * maps). The previous implementation mixed in {@code super.hashCode()} (presumably the
     * identity-based {@link Object#hashCode()}), so instances that {@code equals} reported as equal
     * generally hashed differently, breaking the hashCode contract. Hashing only {@code variables}
     * is sufficient: objects that are equal have equal {@code variables}.
     */
    @Override
    public int hashCode() {
        return this.variables != null ? this.variables.hashCode() : 0;
    }

    /**
     * Creates a new instance that reuses state from {@code originalSameDiff}.
     * <p>
     * The original's sub-function instance map object is passed to the builder, and all of the
     * original's variable entries are copied into the new instance's variable map. The Variable
     * objects themselves are shared, not cloned. A fresh function factory bound to the new
     * instance is installed.
     */
    public static SameDiff create(SameDiff originalSameDiff) {
        SameDiff copy = SameDiff.builder()
                .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
                .build();
        copy.variables.putAll(originalSameDiff.variables);
        copy.functionFactory = new DifferentialFunctionFactory(copy);
        return copy;
    }

    /**
     * Two instances of the same class are equal when their {@code variables},
     * {@code sameDiffFunctionDefinitionMap} and {@code sameDiffFunctionInstances} are all equal
     * (null-safe).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o != null && o.getClass() == this.getClass())) {
            return false;
        }
        SameDiff other = (SameDiff) o;
        boolean sameVariables = this.variables == null
                ? other.variables == null
                : this.variables.equals(other.variables);
        if (!sameVariables) {
            return false;
        }
        boolean sameDefinitions = this.sameDiffFunctionDefinitionMap == null
                ? other.sameDiffFunctionDefinitionMap == null
                : this.sameDiffFunctionDefinitionMap.equals(other.sameDiffFunctionDefinitionMap);
        if (!sameDefinitions) {
            return false;
        }
        return this.sameDiffFunctionInstances == null
                ? other.sameDiffFunctionInstances == null
                : this.sameDiffFunctionInstances.equals(other.sameDiffFunctionInstances);
    }

    /**
     * Creates a new, empty SameDiff instance.
     */
    public static SameDiff create() {
        SameDiff fresh = new SameDiff();
        return fresh;
    }

    /**
     * Copies this instance by serializing it with {@code asFlatBuffers(true)} and deserializing
     * the result. NOTE(review): the meaning of the {@code true} flag is defined elsewhere; confirm.
     *
     * @throws RuntimeException wrapping any IOException raised while deserializing
     */
    public SameDiff dup() {
        ByteBuffer serialized = this.asFlatBuffers(true);
        SameDiff copy;
        try {
            copy = SameDiff.fromFlatBuffers(serialized);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
        return copy;
    }

    /**
     * Sums the element counts (product of shape dimensions) of all variables with a known shape.
     * Variables whose shape is null are skipped.
     *
     * @return the total element count across those variables
     */
    public long numElements() {
        long total = 0L;
        for (SDVariable variable : this.variables()) {
            long[] shape = variable.getShape();
            if (shape == null) {
                continue;
            }
            // prodLong rather than prod: prod(long...) returns an int and silently overflows
            // for any single variable with more than Integer.MAX_VALUE elements
            total += ArrayUtil.prodLong(shape);
        }
        return total;
    }

    /**
     * Returns the names of all placeholder variables, in variable-map order.
     *
     * @return a new mutable list
     */
    public List<String> inputs() {
        List<String> placeholderNames = new ArrayList<String>();
        Iterator<String> names = this.variables.keySet().iterator();
        while (names.hasNext()) {
            String candidate = names.next();
            if (this.isPlaceHolder(candidate)) {
                placeholderNames.add(candidate);
            }
        }
        return placeholderNames;
    }

    /**
     * Returns the names of variables treated as graph outputs.
     * <p>
     * A variable qualifies when all of the following hold:
     * <ul>
     *   <li>it is neither a constant nor a placeholder;</li>
     *   <li>it is not an input to any op;</li>
     *   <li>it has no op or variable control dependencies;</li>
     *   <li>it is not produced by an {@code Assert} or {@code Switch} op.</li>
     * </ul>
     *
     * @return a new mutable list, in variable-map order
     */
    public List<String> outputs() {
        List<String> result = new ArrayList<String>();
        for (Variable v : this.variables.values()) {
            SDVariable sdv = v.getVariable();
            if (sdv.isConstant() || sdv.isPlaceHolder()) {
                continue;
            }
            if (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) {
                continue;
            }
            if (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) {
                continue;
            }
            if (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty()) {
                continue;
            }
            String producerName = v.getOutputOfOp();
            if (producerName != null) {
                DifferentialFunction producer = this.ops.get(producerName).getOp();
                if (producer instanceof Assert || producer instanceof Switch) {
                    continue;
                }
            }
            result.add(v.getName());
        }
        return result;
    }

    /**
     * Returns all variables, in variable-map order, as a new mutable list.
     */
    public List<SDVariable> variables() {
        Collection<SDVariable> all = this.variableMap().values();
        return new ArrayList<SDVariable>(all);
    }

    /**
     * Returns a read-only view of the loss variable names.
     */
    public List<String> getLossVariables() {
        List<String> readOnlyView = Collections.unmodifiableList(this.lossVariables);
        return readOnlyView;
    }

    /**
     * Replaces the set of loss variables with {@code lossVariableNames}. Each name is validated by
     * {@link #addLossVariable(String)}. Afterwards, the sub-function instance stored under
     * {@code GRAD_FN_KEY} is removed, presumably because it depends on the loss set; confirm.
     *
     * @throws NullPointerException if {@code lossVariableNames} is null
     */
    public void setLossVariables(String ... lossVariableNames) {
        if (lossVariableNames == null) {
            throw new NullPointerException("lossVariableNames is marked @NonNull but is null");
        }
        this.lossVariables.clear();
        for (int idx = 0; idx < lossVariableNames.length; idx++) {
            this.addLossVariable(lossVariableNames[idx]);
        }
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
    }

    /**
     * Variable-object overload of {@link #setLossVariables(String...)}.
     *
     * @throws NullPointerException if {@code lossVariables} (or any element) is null
     */
    public void setLossVariables(SDVariable ... lossVariables) {
        if (lossVariables == null) {
            throw new NullPointerException("lossVariables is marked @NonNull but is null");
        }
        String[] names = new String[lossVariables.length];
        int idx = 0;
        for (SDVariable lossVar : lossVariables) {
            names[idx++] = lossVar.getVarName();
        }
        this.setLossVariables(names);
    }

    /**
     * Marks an existing floating-point, ARRAY-type variable as a loss to be minimized.
     * Adding a name that is already present is a no-op.
     *
     * @throws NullPointerException  if {@code variableName} is null
     * @throws IllegalStateException if the variable does not exist, is not floating point, or is
     *                               not ARRAY type
     */
    public void addLossVariable(@NonNull String variableName) {
        if (variableName == null) {
            throw new NullPointerException("variableName is marked @NonNull but is null");
        }
        Preconditions.checkState(this.hasVariable(variableName), "No variable with name \"%s\" exists", (Object) variableName);
        SDVariable candidate = this.getVariable(variableName);
        Preconditions.checkState(candidate.dataType().isFPType(),
                "Only floating point type variables can be marked as losses to be minimized. SDVariable \"%s\" has datatype %s",
                (Object) variableName, (Object) candidate.dataType());
        VariableType kind = candidate.getVariableType();
        Preconditions.checkState(kind == VariableType.ARRAY,
                "Only ARRAY type SDVariables can be marked as losses to be minimized. SDVariable \"%s\" has variable type %s",
                (Object) variableName, (Object) kind);
        if (this.lossVariables.contains(variableName)) {
            return;
        }
        this.lossVariables.add(variableName);
    }

    /**
     * Variable-object overload of {@link #addLossVariable(String)}.
     *
     * @throws NullPointerException if {@code variable} is null
     */
    public void addLossVariable(@NonNull SDVariable variable) {
        if (variable != null) {
            this.addLossVariable(variable.getVarName());
            return;
        }
        throw new NullPointerException("variable is marked @NonNull but is null");
    }

    /**
     * Sets the training configuration used by the {@code fit} methods; null is accepted here.
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        TrainingConfig config = trainingConfig;
        this.trainingConfig = config;
    }

    /**
     * Trains for one epoch on a single {@link DataSet}, converted to a MultiDataSet, without
     * incrementing the epoch counter and without validation.
     *
     * @throws NullPointerException if {@code dataSet} or {@code listeners} is null
     */
    public History fit(@NonNull DataSet dataSet, Listener ... listeners) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        return this.fit(single, 1, false, null, 1, listeners);
    }

    /**
     * Trains for one epoch on a single {@link MultiDataSet}, without incrementing the epoch
     * counter and without validation.
     *
     * @throws NullPointerException if {@code dataSet} or {@code listeners} is null
     */
    public History fit(@NonNull MultiDataSet dataSet, Listener ... listeners) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet);
        return this.fit(single, 1, false, null, 1, listeners);
    }

    /**
     * Trains on {@code iter} for {@code numEpochs} epochs with periodic validation, using the
     * {@link FitConfig} builder.
     *
     * @throws NullPointerException if {@code iter} or {@code listeners} is null
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .validate(validationIter, validationFrequency)
                .listeners(listeners)
                .exec();
    }

    /**
     * Trains on {@code iter} for {@code numEpochs} epochs without validation, using the
     * {@link FitConfig} builder.
     *
     * @throws NullPointerException if {@code iter} or {@code listeners} is null
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .listeners(listeners)
                .exec();
    }

    /**
     * Trains on {@code iter} for {@code numEpochs} epochs, incrementing the epoch counter and
     * validating on {@code validationIter} every {@code validationFrequency} epochs.
     *
     * @throws NullPointerException if {@code iter} or {@code listeners} is null
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        boolean incrementEpochCount = true;
        return this.fit(iter, numEpochs, incrementEpochCount, validationIter, validationFrequency, listeners);
    }

    /**
     * Trains on {@code iter} for {@code numEpochs} epochs without validation, using the
     * {@link FitConfig} builder.
     *
     * @throws NullPointerException if {@code iter} or {@code listeners} is null
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .listeners(listeners)
                .exec();
    }

    /**
     * Starts a {@link FitConfig} builder bound to this instance.
     */
    public FitConfig fit() {
        FitConfig config = new FitConfig(this);
        return config;
    }

    /**
     * Core training entry point shared by the public {@code fit} overloads.
     * <p>
     * Any iterator that reports {@code asyncSupported()} is wrapped in
     * {@code AsyncMultiDataSetIterator(iter, 3, true)}. The wrapper is always shut down in a
     * {@code finally} block once {@code fitHelper} returns or throws.
     *
     * @throws NullPointerException if {@code iter} or {@code listeners} is null
     */
    protected synchronized History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        boolean async = iter.asyncSupported();
        boolean validationAsync = validationData != null && validationData.asyncSupported();
        MultiDataSetIterator trainIter = async ? new AsyncMultiDataSetIterator(iter, 3, true) : iter;
        MultiDataSetIterator validIter = validationAsync ? new AsyncMultiDataSetIterator(validationData, 3, true) : validationData;
        try {
            return this.fitHelper(trainIter, numEpochs, incrementEpochCount, validIter, validationFrequency, Arrays.asList(listeners));
        } finally {
            if (async) {
                ((AsyncMultiDataSetIterator) trainIter).shutdown();
            }
            if (validationAsync) {
                ((AsyncMultiDataSetIterator) validIter).shutdown();
            }
        }
    }

    protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked @NonNull but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        Preconditions.checkNotNull((Object)iter, (String)"Iterator must not be null");
        Preconditions.checkState((numEpochs > 0 ? 1 : 0) != 0, (String)"Number of training epochs must be a positive number. Got: %s", (int)numEpochs);
        Preconditions.checkState((this.trainingConfig != null ? 1 : 0) != 0, (String)"No training configuration has been set. A training configuration must be set before training. Use setTrainingConfig(TrainingConfig)");
        Preconditions.checkState((numEpochs == 1 || iter.resetSupported() ? 1 : 0) != 0, (String)"Cannot train for multiple epochs on an iterator that does not support resetting");
        HistoryListener history = new HistoryListener(this.trainingConfig);
        ArrayList<Listener> activeListeners = new ArrayList<Listener>();
        if (!history.evaluations().isEmpty()) {
            activeListeners.add(history);
        }
        for (Listener l : this.listeners) {
            if (!l.isActive(Operation.TRAINING)) continue;
            activeListeners.add(l);
        }
        for (Listener l : listeners) {
            if (!l.isActive(Operation.TRAINING)) continue;
            activeListeners.add(l);
        }
        this.validateListenerActivations(activeListeners, Operation.TRAINING);
        this.validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION);
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        boolean performedValidation = false;
        int trainThreadNum = 0;
        long jThreadId = Thread.currentThread().getId();
        boolean hasListeners = !activeListeners.isEmpty();
        At at = At.builder().epoch(this.trainingConfig.getEpochCount()).iteration(this.trainingConfig.getIterationCount()).trainingThreadNum(trainThreadNum).javaThreadNum(jThreadId).operation(Operation.TRAINING).build();
        LossCurve lossCurve = null;
        HashSet<String> requiredVars = new HashSet<String>();
        for (Listener l : activeListeners) {
            requiredVars.addAll(l.requiredVariables(this).trainingVariables());
        }
        for (int i = 0; i < numEpochs; ++i) {
            if (incrementEpochCount && hasListeners) {
                at.setEpoch(this.trainingConfig.getEpochCount());
                for (Listener l : activeListeners) {
                    l.epochStart(this, at);
                }
            }
            long epochStartTime = System.currentTimeMillis();
            double[] lossSums = null;
            List<String> lossNames = null;
            int lossCount = 0;
            while (iter.hasNext()) {
                Map<String, INDArray> placeholders;
                long dataEnd;
                long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
                MultiDataSet ds = (MultiDataSet)iter.next();
                long l = dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
                if (!performedValidation) {
                    Preconditions.checkState((this.trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays() ? 1 : 0) != 0, (String)"The number of dataset feature mapping variables set in the training configuration (%s) must match the number of dataset feature arrays (%s)", (int)this.trainingConfig.getDataSetFeatureMapping().size(), (int)ds.numFeatureArrays());
                    List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
                    int lblSize = labelMapping == null ? 0 : labelMapping.size();
                    Preconditions.checkState((lblSize == ds.numLabelsArrays() ? 1 : 0) != 0, (String)"The number of dataset label mapping variables set in the training configuration (%s) must match the number of dataset label arrays (%s)", (int)lblSize, (int)ds.numLabelsArrays());
                    performedValidation = true;
                }
                if (hasListeners) {
                    at.setIteration(this.trainingConfig.getIterationCount());
                    for (Listener l2 : activeListeners) {
                        l2.iterationStart(this, at, ds, dataEnd - dataStart);
                    }
                }
                Preconditions.checkState(((placeholders = this.toPlaceholderMap(ds)).size() > 0 ? 1 : 0) != 0, (String)"No placeholder variables were set for training");
                this.resolveVariablesWith(placeholders);
                this.execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners);
                if (!this.initializedTraining) {
                    this.initializeTraining();
                }
                HashMap regScore = null;
                if (hasListeners) {
                    regScore = new HashMap();
                }
                int iteration = this.trainingConfig.getIterationCount();
                int e = this.trainingConfig.getEpochCount();
                for (Variable v : this.variables.values()) {
                    double lr;
                    SDVariable sdv = v.getVariable();
                    if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) continue;
                    INDArray param = sdv.getArr();
                    SDVariable sDVariable = sdv.getGradient();
                    if (sDVariable == null) continue;
                    INDArray grad = sDVariable.getArr();
                    List<Regularization> r = this.trainingConfig.getRegularization();
                    int iterCount = this.trainingConfig.getIterationCount();
                    int epochCount = this.trainingConfig.getEpochCount();
                    double d = lr = this.trainingConfig.getUpdater().hasLearningRate() ? this.trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
                    if (r != null && r.size() > 0) {
                        for (Regularization reg : r) {
                            if (reg.applyStep() != Regularization.ApplyStep.BEFORE_UPDATER) continue;
                            reg.apply(param, grad, lr, iterCount, epochCount);
                        }
                    }
                    INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1L, grad.length()}, grad.ordering() == 'f');
                    Preconditions.checkState((reshapedView != null ? 1 : 0) != 0, (String)"Error reshaping array for parameter \"%s\": array is a view?", (Object)sdv);
                    GradientUpdater u = this.updaterMap.get(sdv.getVarName());
                    try {
                        u.applyUpdater(reshapedView, iteration, e);
                    }
                    catch (Throwable t) {
                        throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName() + "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t);
                    }
                    if (r != null && r.size() > 0) {
                        for (Regularization reg : r) {
                            if (reg.applyStep() != Regularization.ApplyStep.POST_UPDATER) continue;
                            reg.apply(param, grad, lr, iterCount, epochCount);
                            if (!hasListeners) continue;
                            double score = reg.score(param, iterCount, epochCount);
                            if (!regScore.containsKey(reg.getClass())) {
                                regScore.put(reg.getClass(), new AtomicDouble());
                            }
                            ((AtomicDouble)regScore.get(reg.getClass())).addAndGet(score);
                        }
                    }
                    if (hasListeners) {
                        for (Listener l3 : activeListeners) {
                            if (!l3.isActive(at.operation())) continue;
                            l3.preUpdate(this, at, v, reshapedView);
                        }
                    }
                    if (this.trainingConfig.isMinimize()) {
                        param.subi(grad);
                        continue;
                    }
                    param.addi(grad);
                }
                if (hasListeners) {
                    List<String> lossVars;
                    double[] d = new double[this.lossVariables.size() + regScore.size()];
                    if (regScore.size() > 0) {
                        lossVars = new ArrayList<String>(this.lossVariables.size() + regScore.size());
                        lossVars.addAll(this.lossVariables);
                        int s = regScore.size();
                        for (Map.Entry entry : regScore.entrySet()) {
                            lossVars.add(((Class)entry.getKey()).getSimpleName());
                            d[s] = ((AtomicDouble)entry.getValue()).get();
                        }
                    } else {
                        lossVars = this.lossVariables;
                    }
                    SameDiff gradFn = this.sameDiffFunctionInstances.get(GRAD_FN_KEY);
                    int count = 0;
                    for (String s : this.lossVariables) {
                        INDArray arr = gradFn.getArrForVarName(s);
                        double l4 = arr.isScalar() ? arr.getDouble(0L) : arr.sumNumber().doubleValue();
                        d[count++] = l4;
                    }
                    Loss loss = new Loss(lossVars, d);
                    if (lossNames == null) {
                        lossNames = lossVars;
                    } else {
                        Preconditions.checkState((boolean)lossNames.equals(lossVars), (String)"Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
                    }
                    if (lossSums == null) {
                        lossSums = d;
                    } else {
                        Preconditions.checkState((boolean)lossNames.equals(lossVars), (String)"Loss size mismatch, expected: %s, got: %s", (int)lossSums.length, (int)d.length);
                        for (int j = 0; j < lossSums.length; ++j) {
                            int n = j;
                            lossSums[n] = lossSums[n] + d[j];
                        }
                    }
                    ++lossCount;
                    for (Listener l5 : activeListeners) {
                        l5.iterationDone(this, at, ds, loss);
                    }
                }
                this.trainingConfig.incrementIterationCount();
            }
            long epochTime = System.currentTimeMillis() - epochStartTime;
            if (incrementEpochCount && hasListeners) {
                int j = 0;
                while (j < lossSums.length) {
                    int n = j++;
                    lossSums[n] = lossSums[n] / (double)lossCount;
                }
                lossCurve = lossCurve != null ? lossCurve.addLossAndCopy(lossSums, lossNames) : new LossCurve(lossSums, lossNames);
            }
            if (incrementEpochCount) {
                if (hasListeners) {
                    boolean doStop = false;
                    Listener stopped = null;
                    for (Listener l : activeListeners) {
                        ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
                        if (res != ListenerResponse.STOP || i >= numEpochs - 1) continue;
                        doStop = true;
                        stopped = l;
                    }
                    if (doStop) {
                        log.info("Stopping training early.  Listener " + stopped + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                        for (Listener l1 : activeListeners) {
                            l1.operationEnd(this, Operation.TRAINING);
                        }
                        if (i < numEpochs - 1) {
                            iter.reset();
                        }
                        if (incrementEpochCount) {
                            this.trainingConfig.incrementEpochCount();
                        }
                        return history.getReport();
                    }
                    if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) {
                        long validationStart = System.currentTimeMillis();
                        this.outputHelper(validationData, new At(at.epoch(), 0, 0, 0L, Operation.TRAINING_VALIDATION), listeners, new String[0]);
                        long validationTime = System.currentTimeMillis() - validationStart;
                        boolean doStopV = false;
                        Listener stoppedV = null;
                        for (Listener l : activeListeners) {
                            ListenerResponse res = l.validationDone(this, at, validationTime);
                            if (res != ListenerResponse.STOP || i >= numEpochs - 1) continue;
                            doStopV = true;
                            stoppedV = l;
                        }
                        if (doStopV) {
                            log.info("Stopping training early from validation.  Listener " + stoppedV + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                            for (Listener l1 : activeListeners) {
                                l1.operationEnd(this, Operation.TRAINING);
                            }
                            if (i < numEpochs - 1) {
                                iter.reset();
                            }
                            if (incrementEpochCount) {
                                this.trainingConfig.incrementEpochCount();
                            }
                            return history.getReport();
                        }
                    }
                }
                this.trainingConfig.incrementEpochCount();
            }
            if (i >= numEpochs - 1) continue;
            iter.reset();
        }
        for (Listener l1 : activeListeners) {
            l1.operationEnd(this, Operation.TRAINING);
        }
        return history.getReport();
    }

    /**
     * Checks that every variable a listener declares as required for {@code op} exists in this graph.
     *
     * @param listeners listeners whose declared required variables are checked
     * @param op        the operation whose required-variable set is queried
     * @throws IllegalStateException (via Preconditions) on the first listener/variable pair not found in {@code variables}
     */
    private void validateListenerActivations(List<Listener> listeners, Operation op) {
        for (Listener listener : listeners) {
            for (String requiredName : listener.requiredVariables(this).requiredVariables(op)) {
                boolean known = this.variables.containsKey(requiredName);
                if (!known) {
                    Preconditions.checkState((boolean)false, (String)"Listener %s requested variable %s that is not defined in this SameDiff graph", (Object)listener, (Object)requiredName);
                }
            }
        }
    }

    /**
     * Sums the regularization score of every trainable floating-point variable under every configured
     * {@link Regularization}.
     *
     * @return the total score, or 0.0 when no regularization is configured
     * @throws IllegalStateException (via Preconditions) if no training configuration has been set
     */
    public double calcRegularizationScore() {
        Preconditions.checkState(this.trainingConfig != null, (String)"No training configuration has been set. A training configuration must be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
        List<Regularization> regularizations = this.trainingConfig.getRegularization();
        if (regularizations == null || regularizations.isEmpty()) {
            return 0.0;
        }
        double total = 0.0;
        for (Variable variable : this.variables.values()) {
            SDVariable sdv = variable.getVariable();
            // Only VARIABLE-typed, floating-point variables contribute (constants/placeholders/ints are skipped)
            boolean trainableFloat = sdv.getVariableType() == VariableType.VARIABLE && sdv.dataType().isFPType();
            if (!trainableFloat) {
                continue;
            }
            for (Regularization regularization : regularizations) {
                total += regularization.score(sdv.getArr(), this.trainingConfig.getIterationCount(), this.trainingConfig.getEpochCount());
            }
        }
        return total;
    }

    /**
     * Lazily creates one {@link GradientUpdater} per trainable floating-point variable, the first time it is called.
     * Subsequent calls are no-ops once {@code initializedTraining} is set.
     *
     * @throws ND4JIllegalStateException if no training configuration has been set
     */
    protected void initializeTraining() {
        if (this.initializedTraining) {
            return;
        }
        if (this.trainingConfig == null) {
            throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
        }
        this.updaterMap = new HashMap<String, GradientUpdater>();
        for (Variable variable : this.variables.values()) {
            SDVariable sdv = variable.getVariable();
            if (sdv.getVariableType() == VariableType.VARIABLE && sdv.dataType().isFPType()) {
                INDArray paramArr = sdv.getArr();
                long stateSize = this.trainingConfig.getUpdater().stateSize(paramArr.length());
                // Updaters with no state (stateSize == 0) are instantiated with a null state view
                INDArray stateView = null;
                if (stateSize != 0L) {
                    stateView = Nd4j.createUninitialized(paramArr.dataType(), 1L, stateSize);
                }
                this.updaterMap.put(variable.getName(), this.trainingConfig.getUpdater().instantiate(stateView, true));
            }
        }
        this.initializedTraining = true;
    }

    /**
     * Maps the arrays of a {@link MultiDataSet} onto placeholder names using the training config's
     * feature/label/mask mappings. Array i of each kind is bound to the i-th configured name.
     * A null name in a mask mapping skips that index without binding anything.
     *
     * @param ds the batch to convert
     * @return a new mutable map from placeholder name to array
     */
    private Map<String, INDArray> toPlaceholderMap(MultiDataSet ds) {
        Map<String, INDArray> placeholders = new HashMap<String, INDArray>();

        int featureIdx = 0;
        for (String featureName : this.trainingConfig.getDataSetFeatureMapping()) {
            placeholders.put(featureName, ds.getFeatures(featureIdx));
            featureIdx++;
        }

        List<String> labelNames = this.trainingConfig.getDataSetLabelMapping();
        if (labelNames != null) {
            int labelIdx = 0;
            for (String labelName : labelNames) {
                placeholders.put(labelName, ds.getLabels(labelIdx));
                labelIdx++;
            }
        }

        List<String> featureMaskNames = this.trainingConfig.getDataSetFeatureMaskMapping();
        if (featureMaskNames != null) {
            int maskIdx = 0;
            for (String maskName : featureMaskNames) {
                if (maskName != null) {
                    placeholders.put(maskName, ds.getFeaturesMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }

        List<String> labelMaskNames = this.trainingConfig.getDataSetLabelMaskMapping();
        if (labelMaskNames != null) {
            int maskIdx = 0;
            for (String maskName : labelMaskNames) {
                if (maskName != null) {
                    placeholders.put(maskName, ds.getLabelsMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }
        return placeholders;
    }

    /**
     * Evaluates {@code outputVariable} over {@code iterator} with the given evaluations and extra listeners.
     *
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if no evaluations are supplied
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List<Listener> listeners, IEvaluation ... evaluations) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputVariable, "outputVariable is marked @NonNull but is null");
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        java.util.Objects.requireNonNull(evaluations, "evaluations is marked @NonNull but is null");
        Preconditions.checkArgument(evaluations.length > 0, (String)"No evaluations were passed to the evaluate method");
        EvaluationConfig config = this.evaluate();
        config.data(iterator).evaluate(outputVariable, evaluations).listeners(listeners.toArray(new Listener[0])).exec();
    }

    /**
     * Evaluates {@code outputVariable} over {@code iterator} with the given evaluations.
     *
     * @throws NullPointerException if any argument is null
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, IEvaluation ... evaluations) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputVariable, "outputVariable is marked @NonNull but is null");
        java.util.Objects.requireNonNull(evaluations, "evaluations is marked @NonNull but is null");
        EvaluationConfig config = this.evaluate();
        config.data(iterator).evaluate(outputVariable, evaluations).exec();
    }

    /**
     * Evaluates several variables, each with a single evaluation, over a {@link DataSetIterator}.
     * Every variable is compared against label array 0.
     *
     * @throws NullPointerException if any argument is null
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map<String, IEvaluation> variableEvals, Listener ... listeners) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(variableEvals, "variableEvals is marked @NonNull but is null");
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        Map<String, Integer> labelIndices = new HashMap<String, Integer>();
        Map<String, List<IEvaluation>> evalLists = new HashMap<String, List<IEvaluation>>();
        for (Map.Entry<String, IEvaluation> entry : variableEvals.entrySet()) {
            labelIndices.put(entry.getKey(), 0);
            evalLists.put(entry.getKey(), Collections.singletonList(entry.getValue()));
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        this.evaluate(adapted, evalLists, labelIndices, listeners);
    }

    /**
     * Evaluates several variables, each with a list of evaluations, over a {@link DataSetIterator}.
     * Every variable is compared against label array 0.
     *
     * @throws NullPointerException if {@code listeners} is null
     */
    public void evaluateMultiple(DataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Listener ... listeners) {
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        Map<String, Integer> labelIndices = new HashMap<String, Integer>();
        for (String varName : variableEvals.keySet()) {
            labelIndices.put(varName, 0);
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        this.evaluate(adapted, variableEvals, labelIndices, listeners);
    }

    /**
     * Evaluates {@code outputVariable} against label array {@code labelIndex} of each batch, with extra listeners.
     *
     * @throws NullPointerException     if any reference argument is null
     * @throws IllegalArgumentException if no evaluations are supplied
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List<Listener> listeners, IEvaluation ... evaluations) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputVariable, "outputVariable is marked @NonNull but is null");
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        java.util.Objects.requireNonNull(evaluations, "evaluations is marked @NonNull but is null");
        Preconditions.checkArgument(evaluations.length > 0, (String)"No evaluations were passed to the evaluate method");
        EvaluationConfig config = this.evaluate();
        config.data(iterator).evaluate(outputVariable, labelIndex, evaluations).listeners(listeners.toArray(new Listener[0])).exec();
    }

    /**
     * Evaluates {@code outputVariable} against label array {@code labelIndex} of each batch.
     *
     * @throws NullPointerException if any reference argument is null
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, IEvaluation ... evaluations) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputVariable, "outputVariable is marked @NonNull but is null");
        java.util.Objects.requireNonNull(evaluations, "evaluations is marked @NonNull but is null");
        EvaluationConfig config = this.evaluate();
        config.data(iterator).evaluate(outputVariable, labelIndex, evaluations).exec();
    }

    /**
     * Evaluates each variable in {@code variableEvals} against the label array given by
     * {@code predictionLabelMapping}, starting from the default EVALUATION {@link At}.
     */
    public void evaluate(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, Listener ... listeners) {
        At start = At.defaultAt(Operation.EVALUATION);
        this.evaluateHelper(iterator, variableEvals, predictionLabelMapping, start, listeners);
    }

    /**
     * Starts a fluent evaluation configuration bound to this graph.
     *
     * @return a new {@link EvaluationConfig} for this instance
     */
    public EvaluationConfig evaluate() {
        EvaluationConfig config = new EvaluationConfig(this);
        return config;
    }

    /**
     * Runs every batch of {@code iterator} through the graph and feeds each requested prediction, together with
     * its label and label mask, into the evaluations registered for that variable.
     *
     * <p>Listeners passed in and listeners registered on this instance that are active for {@code at.operation()}
     * receive operationStart/iterationStart/iterationDone/operationEnd callbacks. Their declared evaluation
     * variables are also computed on each batch.</p>
     *
     * @param iterator               batches to evaluate; reset first if exhausted and reset is supported
     * @param variableEvals          variable name to the evaluations applied to its prediction
     * @param predictionLabelMapping variable name to the index of the label (and label mask) array to compare with;
     *                               must have exactly the same keys as {@code variableEvals}
     * @param at                     position info; its iteration counter is advanced once per batch
     * @param listeners              extra listeners for this call (must not be null)
     * @throws NullPointerException  if {@code listeners} is null
     * @throws IllegalStateException (via Preconditions) if no training config is set or the key sets differ
     */
    private void evaluateHelper(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, At at, Listener ... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, (String)"Training config has not been set");
        Preconditions.checkState((boolean)variableEvals.keySet().equals(predictionLabelMapping.keySet()), (String)"Keysets for variable evaluations and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());

        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        this.validateListenerActivations(activeListeners, at.operation());
        for (Listener l : activeListeners) {
            l.operationStart(this, at.operation());
        }
        boolean hasListeners = !activeListeners.isEmpty();

        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        // Outputs to compute per batch: the evaluated variables plus anything listeners want to observe
        Set<String> requiredVars = new HashSet<String>(variableEvals.keySet());
        for (Listener l : activeListeners) {
            requiredVars.addAll(l.requiredVariables(this).evaluationVariables());
        }
        String[] requiredVarsArr = requiredVars.toArray(new String[0]);

        while (iterator.hasNext()) {
            long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
            MultiDataSet ds = (MultiDataSet)iterator.next();
            long dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
            Map<String, INDArray> placeholderMap = this.toPlaceholderMap(ds);

            for (Listener l : activeListeners) {
                l.iterationStart(this, at, ds, dataEnd - dataStart);
            }
            Map<String, INDArray> outputs = this.directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr);

            for (Map.Entry<String, List<IEvaluation>> entry : variableEvals.entrySet()) {
                String varName = entry.getKey();
                INDArray prediction = outputs.get(varName);
                for (IEvaluation eval : entry.getValue()) {
                    INDArray label = ds.getLabels(predictionLabelMapping.get(varName));
                    INDArray mask = ds.getLabelsMaskArray(predictionLabelMapping.get(varName));
                    eval.eval(label, prediction, mask);
                }
            }

            // The previous version built a filtered copy of a map that was always null here
            // (Maps.filterKeys(null, ...) throws NullPointerException) and never used the result;
            // listeners only receive the iterationDone callback.
            for (Listener l : activeListeners) {
                l.iterationDone(this, at, ds, null);
            }
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, at.operation());
        }
    }

    /**
     * Runs a single {@link DataSet} through the graph and returns the requested outputs.
     *
     * @throws NullPointerException if {@code dataSet} or {@code outputs} is null
     */
    public Map<String, INDArray> output(@NonNull DataSet dataSet, String ... outputs) {
        java.util.Objects.requireNonNull(dataSet, "dataSet is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        List<Map<String, INDArray>> batches = this.outputBatches(single, outputs);
        return batches.get(0);
    }

    /**
     * Runs a single {@link MultiDataSet} through the graph and returns the requested outputs.
     *
     * @throws NullPointerException if {@code dataSet} or {@code outputs} is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSet dataSet, String ... outputs) {
        java.util.Objects.requireNonNull(dataSet, "dataSet is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet);
        List<Map<String, INDArray>> batches = this.outputBatches(single, outputs);
        return batches.get(0);
    }

    /**
     * Runs every batch of {@code iterator} through the graph with extra listeners and returns the requested outputs
     * (aggregated by {@link OutputConfig#exec()}).
     *
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, String ... outputs) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        OutputConfig config = this.output();
        return config.data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).exec();
    }

    /**
     * Runs every batch of {@code dataSet} through the graph and returns the requested outputs
     * (aggregated by {@link OutputConfig#exec()}).
     *
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator dataSet, String ... outputs) {
        java.util.Objects.requireNonNull(dataSet, "dataSet is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        OutputConfig config = this.output();
        return config.data(dataSet).output(outputs).exec();
    }

    /**
     * Runs every batch of {@code iterator} through the graph with extra listeners, returning one output map per batch.
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String ... outputs) {
        OutputConfig config = this.output();
        return config.data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).execBatches();
    }

    /**
     * Runs every batch of {@code iterator} through the graph, returning one output map per batch.
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String ... outputs) {
        OutputConfig config = this.output();
        return config.data(iterator).output(outputs).execBatches();
    }

    /**
     * Runs every batch of {@code iterator} through the graph with extra listeners and stacks the per-batch outputs
     * via {@code TrainingUtils.stackOutputs}.
     *
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull List<Listener> listeners, String ... outputs) {
        java.util.Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        List<Map<String, INDArray>> perBatch = this.outputHelper(iterator, At.defaultAt(Operation.INFERENCE), listeners, outputs);
        return TrainingUtils.stackOutputs(perBatch);
    }

    /**
     * Runs every batch of {@code dataSet} through the graph and returns the requested outputs
     * (aggregated by {@link OutputConfig#exec()}).
     *
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator dataSet, String ... outputs) {
        java.util.Objects.requireNonNull(dataSet, "dataSet is marked @NonNull but is null");
        java.util.Objects.requireNonNull(outputs, "outputs is marked @NonNull but is null");
        OutputConfig config = this.output();
        return config.data(dataSet).output(outputs).exec();
    }

    /**
     * Runs every batch of {@code iterator} through the graph (INFERENCE operation) with extra listeners,
     * returning one output map per batch.
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String ... outputs) {
        At start = At.defaultAt(Operation.INFERENCE);
        return this.outputHelper(iterator, start, listeners, outputs);
    }

    /**
     * Runs every batch of {@code iterator} through the graph, returning one output map per batch.
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String ... outputs) {
        OutputConfig config = this.output();
        return config.data(iterator).output(outputs).execBatches();
    }

    /**
     * Starts a fluent output/inference configuration bound to this graph.
     *
     * @return a new {@link OutputConfig} for this instance
     */
    public OutputConfig output() {
        OutputConfig config = new OutputConfig(this);
        return config;
    }

    /**
     * Runs every batch of {@code iterator} through the graph and collects the requested outputs, one map per batch.
     *
     * <p>Listeners passed in and listeners registered on this instance that are active for {@code at.operation()}
     * receive operationStart/iterationStart/iterationDone/operationEnd callbacks. Their declared validation
     * variables (for TRAINING_VALIDATION) or inference variables (otherwise) are passed as required activations.</p>
     *
     * @param iterator  batches to run; reset first if exhausted and reset is supported
     * @param at        position info; its iteration counter is advanced once per batch
     * @param listeners extra listeners for this call (must not be null)
     * @param outputs   names of the variables to return; an empty array means "use {@link #outputs()}"
     * @return one map of output name to array per batch
     * @throws NullPointerException  if {@code listeners} or {@code outputs} is null
     * @throws IllegalStateException (via Preconditions) if no training config has been set
     */
    private List<Map<String, INDArray>> outputHelper(MultiDataSetIterator iterator, At at, @NonNull List<Listener> listeners, String ... outputs) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, (String)"Training config has not been set");
        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        this.validateListenerActivations(activeListeners, at.operation());
        for (Listener l : activeListeners) {
            l.operationStart(this, at.operation());
        }
        boolean hasListeners = !activeListeners.isEmpty();

        // The previous fallback was guarded by "outputs != null" and so was unreachable.
        // Callers such as the validation step of fit pass an EMPTY array, which then hit the
        // "No outputs were specified" check in directExecHelper. Fall back to the graph's
        // outputs() in that case.
        List<String> neededOutputs = outputs.length > 0 ? Arrays.asList(outputs) : this.outputs();
        String[] neededOutputsArr = neededOutputs.toArray(new String[0]);

        List<Map<String, INDArray>> predictions = new ArrayList<Map<String, INDArray>>();
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        Set<String> requiredVars = new HashSet<String>();
        for (Listener l : activeListeners) {
            if (at.operation() == Operation.TRAINING_VALIDATION) {
                requiredVars.addAll(l.requiredVariables(this).validationVariables());
            } else {
                requiredVars.addAll(l.requiredVariables(this).inferenceVariables());
            }
        }

        while (iterator.hasNext()) {
            long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
            MultiDataSet ds = (MultiDataSet)iterator.next();
            long dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
            Map<String, INDArray> placeholderMap = this.toPlaceholderMap(ds);

            for (Listener l : activeListeners) {
                l.iterationStart(this, at, ds, dataEnd - dataStart);
            }
            Map<String, INDArray> outs = this.directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr);
            for (Listener l : activeListeners) {
                l.iterationDone(this, at, ds, null);
            }
            predictions.add(outs);
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, at.operation());
        }
        return predictions;
    }

    /**
     * Starts a fluent single-batch (placeholder map) output configuration bound to this graph.
     *
     * @return a new {@link BatchOutputConfig} for this instance
     */
    public BatchOutputConfig batchOutput() {
        BatchOutputConfig config = new BatchOutputConfig(this);
        return config;
    }

    /**
     * @deprecated use {@link #outputAll(Map)}; this method delegates to it unchanged.
     */
    @Deprecated
    public Map<String, INDArray> execAll(Map<String, INDArray> placeholders) {
        Map<String, INDArray> result = this.outputAll(placeholders);
        return result;
    }

    /**
     * Executes the graph once with the given placeholder values, requesting all outputs via
     * {@code BatchOutputConfig.outputAll()}.
     */
    public Map<String, INDArray> outputAll(Map<String, INDArray> placeholders) {
        BatchOutputConfig config = this.batchOutput();
        return config.outputAll().inputs(placeholders).exec();
    }

    /**
     * @deprecated use {@link #outputSingle(Map, String)}; this method delegates to it unchanged.
     */
    @Deprecated
    public INDArray execSingle(Map<String, INDArray> placeholders, String output) {
        INDArray result = this.outputSingle(placeholders, output);
        return result;
    }

    /**
     * Executes the graph once with the given placeholder values and returns the single requested output array.
     */
    public INDArray outputSingle(Map<String, INDArray> placeholders, String output) {
        BatchOutputConfig config = this.batchOutput();
        return config.output(output).inputs(placeholders).execSingle();
    }

    /**
     * @deprecated use {@link #output(Map, List)}; this method delegates to it unchanged.
     */
    @Deprecated
    public Map<String, INDArray> exec(Map<String, INDArray> placeholders, List<String> outputs) {
        Map<String, INDArray> result = this.output(placeholders, outputs);
        return result;
    }

    /**
     * Executes the graph once with the given placeholder values and returns the named outputs.
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<String> outputs) {
        String[] requested = outputs.toArray(new String[0]);
        BatchOutputConfig config = this.batchOutput();
        return config.output(requested).inputs(placeholders).exec();
    }

    /**
     * @deprecated use {@link #output(Map, String...)}; this method delegates to it unchanged.
     */
    @Deprecated
    public Map<String, INDArray> exec(Map<String, INDArray> placeholders, String ... outputs) {
        Map<String, INDArray> result = this.output(placeholders, outputs);
        return result;
    }

    /**
     * Executes the graph once with the given placeholder values and returns the named outputs.
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, String ... outputs) {
        BatchOutputConfig config = this.batchOutput();
        return config.output(outputs).inputs(placeholders).exec();
    }

    /**
     * Executes the graph once with the given placeholder values and extra listeners, returning the named outputs.
     *
     * @throws NullPointerException if {@code listeners} is null
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String ... outputs) {
        java.util.Objects.requireNonNull(listeners, "listeners is marked @NonNull but is null");
        Map<String, INDArray> result = this.batchOutputHelper(placeholders, listeners, outputs);
        return result;
    }

    /**
     * Executes the graph once (INFERENCE operation) with the given placeholder values and returns the named outputs.
     *
     * <p>Listeners registered on this instance, then those passed in, that are active for INFERENCE receive
     * operationStart/operationEnd callbacks around the execution.</p>
     *
     * @param placeholders placeholder values; may be null (see directExecHelper for the per-thread fallback)
     * @param listeners    extra listeners for this call (must not be null)
     * @param outputs      names of the variables to return
     * @return map of output name to array
     * @throws NullPointerException if {@code listeners} is null
     */
    protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String ... outputs) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : this.listeners) {
            if (l.isActive(Operation.INFERENCE)) {
                activeListeners.add(l);
            }
        }
        for (Listener l : listeners) {
            if (l.isActive(Operation.INFERENCE)) {
                activeListeners.add(l);
            }
        }
        // Validate BEFORE notifying listeners (as outputHelper/evaluateHelper do). Otherwise a validation
        // failure would leave listeners with an operationStart that is never paired with operationEnd.
        this.validateListenerActivations(activeListeners, Operation.INFERENCE);
        for (Listener l : activeListeners) {
            l.operationStart(this, Operation.INFERENCE);
        }
        Map<String, INDArray> ret = this.directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.emptyList(), activeListeners, outputs);
        for (Listener l : activeListeners) {
            l.operationEnd(this, Operation.INFERENCE);
        }
        return ret;
    }

    /**
     * Executes the graph for the requested outputs using this thread's {@link InferenceSession}, creating that
     * session on first use by the calling thread.
     *
     * @param placeholders        placeholder values; if null and {@code inputs()} is non-null, the values stored in
     *                            {@code placeholdersPerThread} for the current thread are used instead
     * @param at                  position info; {@code At.defaultAt()} is used when null
     * @param batch               the batch being processed (may be null)
     * @param requiredActivations additional activations passed through to the session
     * @param activeListeners     listeners passed through to the session
     * @param outputs             names of the variables to compute; must be non-null and non-empty
     * @return the session's output map
     * @throws IllegalStateException (via Preconditions) if no outputs were specified
     */
    protected Map<String, INDArray> directExecHelper(Map<String, INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String ... outputs) {
        At position = at == null ? At.defaultAt() : at;
        Preconditions.checkState(outputs != null && outputs.length > 0, (String)"No outputs were specified");

        long threadId = Thread.currentThread().getId();
        boolean hasSession = this.sessions.containsKey(threadId);
        if (!hasSession) {
            log.info("Creating new InferenceSession for thread {}", (Object)threadId);
            this.sessions.put(threadId, new InferenceSession(this));
        }

        List<String> phNames = this.inputs();
        Map<String, INDArray> feed = placeholders;
        if (feed == null && phNames != null) {
            feed = this.placeholdersPerThread.get(threadId);
        }

        InferenceSession session = this.sessions.get(threadId);
        // outputs is guaranteed non-null by the precondition above
        List<String> requested = Arrays.asList(outputs);
        return session.output(requested, feed, batch, requiredActivations, activeListeners, position);
    }

    /**
     * Creates a variable of the given shape filled with ones, using {@code Nd4j.defaultFloatingPointType()}.
     */
    public SDVariable one(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.one(name, defaultType, shape);
    }

    /**
     * Creates a variable of the given shape filled with ones, using {@code Nd4j.defaultFloatingPointType()}.
     */
    public SDVariable one(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.one(name, defaultType, shape);
    }

    /**
     * Creates a variable of the given type and (int) shape, initialized with a constant 1.0 scheme in 'f' order.
     */
    public SDVariable one(String name, DataType dataType, int ... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        ConstantInitScheme onesInit = new ConstantInitScheme('f', 1.0);
        return this.var(name, onesInit, dataType, longShape);
    }

    /**
     * Creates a variable of the given type and shape, initialized with a constant 1.0 scheme in 'f' order.
     */
    public SDVariable one(String name, DataType dataType, long ... shape) {
        ConstantInitScheme onesInit = new ConstantInitScheme('f', 1.0);
        return this.var(name, onesInit, dataType, shape);
    }

    /**
     * Creates a zero-initialized variable of the given shape, using {@code Nd4j.defaultFloatingPointType()}.
     */
    public SDVariable zero(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.zero(name, defaultType, shape);
    }

    /**
     * Creates a zero-initialized variable of the given shape, using {@code Nd4j.defaultFloatingPointType()}.
     */
    public SDVariable zero(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.zero(name, defaultType, shape);
    }

    /**
     * Creates a variable of the given datatype initialised with a ZeroInitScheme.
     *
     * @param name     variable name
     * @param dataType variable datatype
     * @param shape    variable shape
     * @return the new variable
     */
    public SDVariable zero(String name, DataType dataType, long ... shape) {
        ZeroInitScheme zeros = new ZeroInitScheme();
        return var(name, zeros, dataType, shape);
    }

    /**
     * Creates a variable of the given datatype initialised with a ZeroInitScheme.
     *
     * @param name     variable name
     * @param dataType variable datatype
     * @param shape    variable shape (converted to long[])
     * @return the new variable
     */
    public SDVariable zero(String name, DataType dataType, int ... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return var(name, new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Creates a constant holding the given array, under an auto-generated name.
     *
     * @param constant the array value; must not be null
     * @return the new constant variable
     * @throws NullPointerException if {@code constant} is null
     */
    public SDVariable constant(@NonNull INDArray constant) {
        if (constant != null) {
            String generatedName = getNewVarName();
            return constant(generatedName, constant);
        }
        throw new NullPointerException("constant is marked @NonNull but is null");
    }

    /**
     * Creates a CONSTANT variable backed by the given array.
     * A view array is first duplicated with workspaces scoped out. The array is then stored in the
     * constant-array map under the variable's final name.
     *
     * @param name     desired name; null or empty means an auto-generated name is used
     * @param constant the array value; must not be null
     * @return the new constant variable
     * @throws NullPointerException  if {@code constant} is null
     * @throws IllegalStateException if a variable named {@code name} already exists
     */
    public SDVariable constant(String name, @NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        // The duplicate-name check is made against the requested name, before any defaulting
        boolean nameTaken = this.variables.containsKey(name);
        Preconditions.checkState(!nameTaken, "Variable with name \"%s\" already exists", (Object) name);
        String requestedName = (name == null || name.isEmpty()) ? getNewVarName() : name;

        INDArray stored = constant;
        if (stored.isView()) {
            // Store a standalone copy rather than a view of some other array
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                stored = stored.dup();
            }
        }

        SDVariable created = new SDVariable(requestedName, VariableType.CONSTANT, this, stored.shape(), stored.dataType(), null);
        String finalName = created.getVarName();
        this.variables.put(finalName, Variable.builder().name(finalName).variable(created).build());
        this.constantArrays.put(finalName, new DeviceLocalNDArray(stored, true));
        return created;
    }

    /**
     * Creates a constant derived from {@code value} with the given shape, under an auto-generated name.
     *
     * @deprecated retained for backward compatibility
     */
    @Deprecated
    public SDVariable constant(SDVariable value, long ... shape) {
        String unnamed = null;
        return constant(unnamed, value, shape);
    }

    /**
     * Creates a constant op from {@code value} and {@code shape} through the function factory,
     * then renames the result to {@code name}.
     *
     * @deprecated retained for backward compatibility
     */
    @Deprecated
    public SDVariable constant(String name, SDVariable value, long ... shape) {
        SDVariable created = f().constant(value, shape);
        return updateVariableNameAndReference(created, name);
    }

    /**
     * Declares a PLACEHOLDER variable with the given datatype and shape.
     *
     * @param name     placeholder name; must not be null and must not already be in use
     * @param dataType placeholder datatype
     * @param shape    placeholder shape
     * @return the new placeholder variable
     * @throws NullPointerException  if {@code name} is null
     * @throws IllegalStateException if a variable with this name already exists
     */
    public SDVariable placeHolder(@NonNull String name, DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        boolean exists = variables.containsKey(name);
        Preconditions.checkState(!exists, "Variable already exists with name %s", (Object) name);
        SDVariable placeholder = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType, null);
        Variable record = Variable.builder().name(name).variable(placeholder).build();
        variables.put(name, record);
        return placeholder;
    }

    /**
     * Creates a VARIABLE with the given init scheme, datatype and shape.
     *
     * @throws NullPointerException if any argument is null (checked in parameter order)
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, long ... shape) {
        // Report the first null argument, in declaration order
        String missing = name == null ? "name"
                : weightInitScheme == null ? "weightInitScheme"
                : dataType == null ? "dataType"
                : shape == null ? "shape"
                : null;
        if (missing != null) {
            throw new NullPointerException(missing + " is marked @NonNull but is null");
        }
        return var(name, VariableType.VARIABLE, weightInitScheme, dataType, shape);
    }

    /**
     * Creates and registers a new variable of the given type.
     * An empty name is replaced by an auto-generated one. Any other name is passed through
     * {@code generateNewVarName(name, 0)} (which presumably applies the current name scope; NOTE(review): confirm).
     * For PLACEHOLDER variables, the shape is also recorded as the original placeholder shape.
     *
     * @param name             requested name; must not be null
     * @param variableType     type of the variable; must not be null
     * @param weightInitScheme init scheme, may be null
     * @param dataType         datatype of the variable
     * @param shape            shape of the variable
     * @return the new variable
     * @throws NullPointerException     if {@code name} or {@code variableType} is null
     * @throws IllegalArgumentException if the resolved name is already in use
     */
    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked @NonNull but is null");
        }
        if (variableType == null) {
            throw new NullPointerException("variableType is marked @NonNull but is null");
        }
        name = name.length() < 1 ? this.getNewVarName() : this.generateNewVarName(name, 0);
        if (this.variables.containsKey(name)) {
            // Mention the name scope only when one is actually active
            if (this.nameScopes.isEmpty()) {
                throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
            }
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + this.currentNameScope() + "\")");
        }
        SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme);
        this.addVariable(ret);
        if (variableType == VariableType.PLACEHOLDER) {
            this.setOriginalPlaceHolderShape(name, shape);
            this.putShapeForVarName(name, shape);
        }
        return ret;
    }

    /**
     * Creates a VARIABLE whose datatype and shape come from the given shape descriptor.
     *
     * @throws NullPointerException if {@code name} or {@code shape} is null
     */
    public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme) {
        String missing = name == null ? "name" : shape == null ? "shape" : null;
        if (missing != null) {
            throw new NullPointerException(missing + " is marked @NonNull but is null");
        }
        DataType descriptorType = shape.dataType();
        long[] dims = shape.getShape();
        return var(name, weightInitScheme, descriptorType, dims);
    }

    /**
     * Creates a zero-initialised VARIABLE. If {@code shape} is a placeholder shape (per
     * {@code Shape.isPlaceholderShape}), a placeholder is declared instead.
     *
     * @param name     variable name
     * @param dataType variable datatype
     * @param shape    variable shape; must not be null
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, long ... shape) {
        // Null-check the array itself: a boxed "shape != null" value is never null, so the check could never fire
        Preconditions.checkNotNull((Object) shape, "Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(shape)) {
            return this.placeHolder(name, dataType, shape);
        }
        return this.var(name, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a zero-initialised VARIABLE described by {@code shapeDesc}.
     *
     * @param name      variable name
     * @param shapeDesc shape and datatype descriptor; must not be null
     * @return the new variable
     */
    public SDVariable var(String name, LongShapeDescriptor shapeDesc) {
        // Null-check the descriptor itself: a boxed "shapeDesc != null" value is never null
        Preconditions.checkNotNull((Object) shapeDesc, "Invalid shape: shape may not be null");
        return this.var(name, shapeDesc, new ZeroInitScheme());
    }

    /** Creates a variable with the default floating point datatype and the given shape. */
    public SDVariable var(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /** Creates a variable with the default floating point datatype and the given shape. */
    public SDVariable var(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /**
     * Creates a zero-initialised VARIABLE from an int shape. A placeholder shape declares a
     * placeholder instead.
     *
     * @param shape variable shape; must not be null
     */
    public SDVariable var(String name, DataType dataType, int ... shape) {
        Preconditions.checkNotNull((Object) shape, "Invalid shape: shape may not be null");
        boolean placeholderShape = Shape.isPlaceholderShape(shape);
        long[] longShape = ArrayUtil.toLongArray(shape);
        return placeholderShape
                ? placeHolder(name, dataType, longShape)
                : var(name, new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Registers a copy of an existing SDVariable in this SameDiff instance.
     * If a variable with the same name is already registered and has an array, that registered
     * variable is returned unchanged. Otherwise a new variable is created according to
     * {@code v}'s type:
     * <ul>
     *   <li>VARIABLE: a new VARIABLE whose init scheme supplies {@code v}'s current array</li>
     *   <li>ARRAY: a new ARRAY-type variable with no init scheme</li>
     *   <li>CONSTANT: a new constant holding {@code v}'s array</li>
     *   <li>PLACEHOLDER: a new placeholder with {@code v}'s datatype and placeholder shape</li>
     * </ul>
     *
     * @param v the variable to register; must not be null and must have a non-empty name
     * @return the registered (or pre-existing) variable
     * @throws NullPointerException     if {@code v} is null
     * @throws IllegalArgumentException if {@code v} has no name
     * @throws RuntimeException         for any other variable type
     */
    public SDVariable var(@NonNull SDVariable v) {
        if (v == null) {
            throw new NullPointerException("v is marked @NonNull but is null");
        }
        // Already registered with an array: reuse the existing instance
        if (this.variables.containsKey(v.getVarName()) && this.variables.get(v.getVarName()).getVariable().getArr() != null) {
            return this.variables.get(v.getVarName()).getVariable();
        }
        if (v.getVarName() == null || v.getVarName().length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        VariableType vt = v.getVariableType();
        NDArraySupplierInitScheme s = null;
        switch (vt) {
            case VARIABLE: {
                s = new NDArraySupplierInitScheme(v.getArr());
                // Intentional fall-through: VARIABLE shares the ARRAY construction path, with s set
            }
            case ARRAY: {
                SDVariable ret = new SDVariable(v.getVarName(), v.getVariableType(), this, v.getShape(), v.dataType(), s);
                return this.addVariable(ret);
            }
            case CONSTANT: {
                return this.constant(v.getVarName(), v.getArr());
            }
            case PLACEHOLDER: {
                return this.placeHolder(v.getVarName(), v.dataType(), v.placeholderShape());
            }
        }
        throw new RuntimeException("Unknown/not supported variable type: " + (Object)((Object)vt));
    }

    /** Generates a fresh variable name derived from the base name "sd_var". */
    private String getNewVarName() {
        final String base = "sd_var";
        return generateNewVarName(base, 0, false);
    }

    /** Creates a variable with an auto-generated name and the given datatype and shape. */
    public SDVariable var(DataType dataType, int ... shape) {
        String generated = getNewVarName();
        return var(generated, dataType, shape);
    }

    /** Creates a variable with an auto-generated name and the given datatype and shape. */
    public SDVariable var(DataType dataType, long ... shape) {
        String generated = getNewVarName();
        return var(generated, dataType, shape);
    }

    /** Creates a variable with an auto-generated name, using the given init scheme, datatype and shape. */
    public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        String generated = getNewVarName();
        return var(generated, weightInitScheme, dataType, shape);
    }

    /** Creates a variable with an auto-generated name, backed by the given array. */
    public SDVariable var(INDArray arr) {
        String generated = getNewVarName();
        return var(generated, arr);
    }

    /**
     * Creates a VARIABLE backed by the given array.
     * The array must be floating point. Before it is stored, the array is detached if it is attached,
     * and duplicated if it is a view. If neither applies and the same array instance already backs
     * another variable, it is also duplicated, so two variables never share one array object.
     *
     * @param name desired name; null or empty means an auto-generated name is used
     * @param arr  the initial value; must not be null
     * @return the new variable
     * @throws NullPointerException     if {@code arr} is null
     * @throws IllegalArgumentException if a variable with this name exists and already has an array
     * @throws IllegalStateException    if {@code arr} is not a floating point type
     */
    public SDVariable var(String name, @NonNull INDArray arr) {
        if (arr == null) {
            throw new NullPointerException("arr is marked @NonNull but is null");
        }
        // Only an existing name that already has an array is rejected here
        if (this.variables.containsKey(name) && this.variables.get(name).getVariable().getArr() != null) {
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
        }
        Preconditions.checkState((boolean)arr.dataType().isFPType(), (String)"Cannot create variable with non-floating point type: provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\nFor non floating point types, these should be created as placeholders or constants instead.", (Object)arr.dataType());
        if (name == null || name.length() < 1) {
            name = this.getNewVarName();
        }
        // Tracks whether arr has already been replaced by a new instance
        boolean duped = false;
        if (arr.isAttached()) {
            arr = arr.detach();
            duped = true;
        }
        if (arr.isView()) {
            arr = arr.dup();
            duped = true;
        }
        if (!duped) {
            // Identity (not equals) comparison: copy only when the exact same instance already backs another variable
            for (DeviceLocalNDArray otherArr : this.variablesArrays.values()) {
                if (otherArr.get() != arr) continue;
                arr = arr.dup();
                break;
            }
        }
        SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr));
        this.associateArrayWithVariable(arr, ret);
        this.addVariable(ret);
        // Record the shape only if none has been recorded for this name yet
        if (this.getShapeForVarName(name) == null) {
            this.putShapeForVarName(name, arr.shape());
        }
        return ret;
    }

    /**
     * Converts a single variable to a CONSTANT in place.
     *
     * @param variable the variable to convert; must not be null
     * @return the same (now constant) variable instance
     */
    public SDVariable convertToConstant(@NonNull SDVariable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        List<SDVariable> single = Collections.singletonList(variable);
        convertToConstants(single);
        return variable;
    }

    /**
     * Converts the given variables to CONSTANT type in place.
     * Each variable's current array is moved from the variable-array map to the constant-array map
     * and removed from every thread's placeholder map. Cached sessions and the gradient function are
     * discarded. If training has been initialised, each variable's updater is removed (closing its
     * closeable state arrays), and the variable's name is removed from the training-config dataset
     * mappings.
     *
     * @param variables the variables to convert; ARRAY-type variables are rejected
     * @throws NullPointerException  if {@code variables} is null
     * @throws IllegalStateException if any variable is ARRAY type, or has no array
     */
    public void convertToConstants(List<SDVariable> variables) {
        if (variables == null) {
            throw new NullPointerException("variables may not be null");
        }
        if (variables.isEmpty()) {
            return;
        }
        boolean allConst = true;
        for (SDVariable variable : variables) {
            if (variable.getVariableType() == VariableType.CONSTANT) {
                continue;
            }
            allConst = false;
            Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", (Object) variable);
        }
        if (allConst) {
            return;
        }
        // The graph changes, so cached execution state and the gradient function are stale
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
        for (SDVariable variable : variables) {
            String n = variable.getVarName();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object) arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object) variable);
            this.constantArrays.put(n, new DeviceLocalNDArray(arr, true));
            this.variablesArrays.remove(n);
            for (Map<String, INDArray> threadPlaceholders : this.placeholdersPerThread.values()) {
                threadPlaceholders.remove(n);
            }
            variable.setVariableType(VariableType.CONSTANT);
        }
        if (this.trainingConfig == null || !this.initializedTraining) {
            return;
        }
        for (SDVariable v : variables) {
            String n = v.getVarName();
            // Constants are not trained: drop the updater and release its state arrays
            GradientUpdater gu = this.updaterMap.remove(n);
            Map<String, INDArray> state = gu == null ? null : gu.getState();
            if (state != null) {
                for (INDArray stateArr : state.values()) {
                    if (stateArr.closeable()) {
                        stateArr.close();
                    }
                }
            }
            // Each mapping is replaced with a copy minus the first occurrence of this name
            List<String> fm = this.trainingConfig.getDataSetFeatureMapping();
            if (fm != null && fm.contains(n)) {
                ArrayList<String> newFM = new ArrayList<String>(fm);
                newFM.remove(n);
                this.trainingConfig.setDataSetFeatureMapping(newFM);
            }
            List<String> lm = this.trainingConfig.getDataSetLabelMapping();
            if (lm != null && lm.contains(n)) {
                ArrayList<String> newLM = new ArrayList<String>(lm);
                newLM.remove(n);
                this.trainingConfig.setDataSetLabelMapping(newLM);
            }
            List<String> fmm = this.trainingConfig.getDataSetFeatureMaskMapping();
            if (fmm != null && fmm.contains(n)) {
                ArrayList<String> newFMM = new ArrayList<String>(fmm);
                newFMM.remove(n);
                this.trainingConfig.setDataSetFeatureMaskMapping(newFMM);
            }
            List<String> lmm = this.trainingConfig.getDataSetLabelMaskMapping();
            if (lmm != null && lmm.contains(n)) {
                ArrayList<String> newLMM = new ArrayList<String>(lmm);
                newLMM.remove(n);
                this.trainingConfig.setDataSetLabelMaskMapping(newLMM);
            }
        }
    }

    /**
     * Converts a single floating point SDVariable to VARIABLE type in place.
     *
     * @param constant the variable to convert; must not be null and must be floating point
     * @return the same (now VARIABLE type) instance
     */
    public SDVariable convertToVariable(@NonNull SDVariable constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked @NonNull but is null");
        }
        boolean floatingPoint = constant.dataType().isFPType();
        Preconditions.checkState(floatingPoint, "Only floating point SDVariables can be converted to variables, datatype of %s is %s", (Object) constant.getVarName(), (Object) constant.dataType());
        List<SDVariable> single = Collections.singletonList(constant);
        convertToVariables(single);
        return constant;
    }

    /**
     * Converts the given SDVariables to VARIABLE type in place.
     * Each one's current array is moved to the variable-array map, removed from the constant-array
     * map and removed from every thread's placeholder map. Cached sessions and the gradient function
     * are discarded. If training has been initialised, an updater is created for every converted
     * variable that does not already have one.
     *
     * @param constants the variables to convert; must not be null
     * @throws NullPointerException  if {@code constants} is null
     * @throws IllegalStateException if any entry is ARRAY type, if a non-VARIABLE entry is not
     *                               floating point, or if an entry has no array
     */
    public void convertToVariables(@NonNull List<SDVariable> constants) {
        if (constants == null) {
            throw new NullPointerException("constants is marked @NonNull but is null");
        }
        if (constants.isEmpty()) {
            return;
        }
        boolean allAlreadyVariables = true;
        for (SDVariable variable : constants) {
            Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", (Object) variable);
            if (variable.getVariableType() != VariableType.VARIABLE) {
                allAlreadyVariables = false;
                // Apply the floating-point rule from convertToVariable to the list form as well
                Preconditions.checkState(variable.dataType().isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", (Object) variable.getVarName(), (Object) variable.dataType());
            }
        }
        if (allAlreadyVariables) {
            return;
        }
        // The graph changes, so cached execution state and the gradient function are stale
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
        for (SDVariable variable : constants) {
            String n = variable.getVarName();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object) arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object) variable);
            this.variablesArrays.put(n, new DeviceLocalNDArray(arr, true));
            this.constantArrays.remove(n);
            for (Map<String, INDArray> threadPlaceholders : this.placeholdersPerThread.values()) {
                threadPlaceholders.remove(n);
            }
            variable.setVariableType(VariableType.VARIABLE);
        }
        if (this.trainingConfig == null || !this.initializedTraining) {
            return;
        }
        for (SDVariable v : constants) {
            String n = v.getVarName();
            if (this.updaterMap.containsKey(n)) {
                continue;
            }
            INDArray arr = v.getArr();
            long stateSize = this.trainingConfig.getUpdater().stateSize(arr.length());
            // Stateless updaters get a null state array
            INDArray stateArr = stateSize > 0L ? Nd4j.create(arr.dataType(), 1L, stateSize) : null;
            GradientUpdater u = this.trainingConfig.getUpdater().instantiate(stateArr, true);
            this.updaterMap.put(n, u);
        }
    }

    /**
     * Changes the datatype of the named variables.
     * All entries are validated before anything is modified. ARRAY-type variables cannot be
     * converted. Apart from placeholders, a conversion may not cross between numerical and
     * non-numerical types. Stored arrays are cast to the new type: variable and constant arrays,
     * plus the current thread's placeholder array if one is set. If anything changed, cached
     * sessions are cleared and output datatypes are recalculated.
     *
     * @param dataTypeMap variable name to target datatype; must not be null
     * @throws NullPointerException  if {@code dataTypeMap} is null
     * @throws IllegalStateException if a name is unknown or a conversion is not allowed
     */
    public void convertDataTypes(@NonNull Map<String, DataType> dataTypeMap) {
        if (dataTypeMap == null) {
            throw new NullPointerException("dataTypeMap is marked @NonNull but is null");
        }
        if (dataTypeMap.isEmpty()) {
            return;
        }
        for (Map.Entry<String, DataType> e : dataTypeMap.entrySet()) {
            String s = e.getKey();
            Preconditions.checkState(this.variables.containsKey(s), "Cannot change datatype of variable \"%s\": No variable with this name exists", (Object) s);
            SDVariable v = this.variables.get(s).getVariable();
            Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot change datatype of ARRAY type variable \"%s\": datatype of ARRAY type variables is determined by the datatypes of their inputs plus corresponding op type", (Object) s);
            if (v.getVariableType() == VariableType.PLACEHOLDER) {
                continue;
            }
            Preconditions.checkState(v.dataType().isNumerical() == e.getValue().isNumerical(), "Cannot convert variables between numerical and non-numerical types: attempting to convert variable \"%s\" from %s to %s", (Object) e.getKey(), (Object) v.dataType(), (Object) e.getValue());
        }
        boolean anyChanged = false;
        for (Map.Entry<String, DataType> e : dataTypeMap.entrySet()) {
            String s = e.getKey();
            DataType d = e.getValue();
            SDVariable v = this.variables.get(s).getVariable();
            if (v.dataType() == d) {
                continue;
            }
            v.setDataType(d);
            switch (v.getVariableType()) {
                case VARIABLE: {
                    // No stored array yet: nothing to cast (any later array presumably uses the updated
                    // datatype; NOTE(review): confirm against lazy array creation)
                    DeviceLocalNDArray dl = this.variablesArrays.remove(s);
                    if (dl != null) {
                        this.variablesArrays.put(s, new DeviceLocalNDArray(dl.get().castTo(d), true));
                    }
                    break;
                }
                case CONSTANT: {
                    DeviceLocalNDArray dl = this.constantArrays.remove(s);
                    if (dl != null) {
                        this.constantArrays.put(s, new DeviceLocalNDArray(dl.get().castTo(d), true));
                    }
                    break;
                }
                case PLACEHOLDER: {
                    // Only the calling thread's placeholder value (if any) is cast
                    Map<String, INDArray> m = this.placeholdersPerThread.get(Thread.currentThread().getId());
                    if (m != null && m.containsKey(s)) {
                        m.put(s, m.get(s).castTo(d));
                    }
                    break;
                }
                default: {
                    throw new IllegalStateException("Cannot convert array type variable");
                }
            }
            anyChanged = true;
        }
        if (anyChanged) {
            this.sessions.clear();
            this.calculateOutputDataTypes(true);
        }
    }

    /**
     * Renames a variable. The following references are updated:
     * <ul>
     *   <li>the name inside the Variable record and the SDVariable;</li>
     *   <li>op inputs, op outputs and control dependencies that refer to it;</li>
     *   <li>arrays and updater state keyed by the old name;</li>
     *   <li>training-config mappings;</li>
     *   <li>child SameDiff function instances that contain a variable of the same name.</li>
     * </ul>
     * Cached inference sessions are cleared, because they may refer to the old name.
     *
     * @param from current name; must exist
     * @param to   new name; must not already exist
     * @throws IllegalStateException if {@code from} does not exist or {@code to} already exists
     */
    public void renameVariable(String from, String to) {
        Preconditions.checkState(this.variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", (Object) from);
        Preconditions.checkState(!this.variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", (Object) from, (Object) to, (Object) to);
        Variable v = this.variables.get(from);
        v.setName(to);
        v.getVariable().setVarName(to);
        if (v.getInputsForOp() != null) {
            for (String opName : v.getInputsForOp()) {
                SameDiffOp op = this.ops.get(opName);
                op.setInputsToOp(renamedCopy(op.getInputsToOp(), from, to));
            }
        }
        if (v.getControlDepsForOp() != null) {
            for (String opName : v.getControlDepsForOp()) {
                SameDiffOp op = this.ops.get(opName);
                op.setControlDeps(renamedCopy(op.getControlDeps(), from, to));
            }
        }
        if (v.getControlDepsForVar() != null) {
            // Variables that list this one as a control dependency
            for (String varName : v.getControlDepsForVar()) {
                Variable dependent = this.variables.get(varName);
                dependent.setControlDeps(renamedCopy(dependent.getControlDeps(), from, to));
            }
        }
        if (v.getControlDeps() != null) {
            // Variables this one depends on keep a reverse link back to it
            for (String varName : v.getControlDeps()) {
                Variable dependency = this.variables.get(varName);
                dependency.setControlDepsForVar(renamedCopy(dependency.getControlDepsForVar(), from, to));
            }
        }
        if (v.getOutputOfOp() != null) {
            SameDiffOp producer = this.ops.get(v.getOutputOfOp());
            producer.setOutputsOfOp(renamedCopy(producer.getOutputsOfOp(), from, to));
        }
        this.variables.remove(from);
        this.variables.put(to, v);

        // Arrays and updater state are keyed by name; without re-keying, the renamed variable loses them
        rekeyIfPresent(this.variablesArrays, from, to);
        rekeyIfPresent(this.constantArrays, from, to);
        for (Map<String, INDArray> threadPlaceholders : this.placeholdersPerThread.values()) {
            rekeyIfPresent(threadPlaceholders, from, to);
        }
        rekeyIfPresent(this.updaterMap, from, to);
        this.sessions.clear();

        if (this.trainingConfig != null) {
            List<String> l = this.trainingConfig.getDataSetFeatureMapping();
            if (l != null && l.contains(from)) {
                this.trainingConfig.setDataSetFeatureMapping(renamedCopy(l, from, to));
            }
            l = this.trainingConfig.getDataSetLabelMapping();
            if (l != null && l.contains(from)) {
                this.trainingConfig.setDataSetLabelMapping(renamedCopy(l, from, to));
            }
            l = this.trainingConfig.getDataSetFeatureMaskMapping();
            if (l != null && l.contains(from)) {
                this.trainingConfig.setDataSetFeatureMaskMapping(renamedCopy(l, from, to));
            }
            l = this.trainingConfig.getDataSetLabelMaskMapping();
            if (l != null && l.contains(from)) {
                this.trainingConfig.setDataSetLabelMaskMapping(renamedCopy(l, from, to));
            }
            l = this.trainingConfig.getLossVariables();
            if (l != null && l.contains(from)) {
                this.trainingConfig.setLossVariables(renamedCopy(l, from, to));
            }
        }
        for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
            if (sd.hasVariable(from)) {
                sd.renameVariable(from, to);
            }
        }
    }

    /** Returns a mutable copy of {@code names} with every occurrence of {@code from} replaced by {@code to}. */
    private static ArrayList<String> renamedCopy(List<String> names, String from, String to) {
        ArrayList<String> copy = new ArrayList<String>(names);
        Collections.replaceAll(copy, from, to);
        return copy;
    }

    /** Moves the entry stored under {@code from} to {@code to}, if the map exists and has such an entry. */
    private static <V> void rekeyIfPresent(Map<String, V> map, String from, String to) {
        if (map != null && map.containsKey(from)) {
            map.put(to, map.remove(from));
        }
    }

    /**
     * Removes {@code varName} from the recorded inputs of {@code function}'s op.
     * The op is left untouched unless one of the function's args has that name. When it does,
     * every occurrence of the name within the first {@code args.length} recorded inputs is dropped.
     *
     * @param varName  the input name to remove
     * @param function the function whose op inputs are rewritten
     */
    public void removeArgFromOp(String varName, DifferentialFunction function) {
        SDVariable[] args = function.args();
        boolean referenced = false;
        for (SDVariable candidate : args) {
            if (candidate.getVarName().equals(varName)) {
                referenced = true;
                break;
            }
        }
        if (!referenced) {
            return;
        }
        SameDiffOp op = this.ops.get(function.getOwnName());
        List<String> currentInputs = op.getInputsToOp();
        ArrayList<String> retained = new ArrayList<String>(args.length - 1);
        for (int idx = 0; idx < args.length; idx++) {
            String input = currentInputs.get(idx);
            if (!input.equals(varName)) {
                retained.add(input);
            }
        }
        op.setInputsToOp(retained);
    }

    /**
     * Looks up a variable by name.
     *
     * @param name variable name
     * @return the variable, or null if no variable has this name
     */
    public SDVariable getVariable(String name) {
        Variable entry = variables.get(name);
        if (entry == null) {
            return null;
        }
        return entry.getVariable();
    }

    /**
     * @param name variable name
     * @return true if a variable with this name is registered
     */
    public boolean hasVariable(String name) {
        final boolean known = variables.containsKey(name);
        return known;
    }

    /**
     * Returns the gradient variable for {@code varName}.
     * The gradient recorded on this instance is checked first, then the one recorded in the
     * gradient function (if that function exists and contains the variable).
     *
     * @param varName name of a floating point variable
     * @return the gradient variable, or null if none is recorded
     * @throws IllegalStateException if no such variable exists or it is not floating point
     */
    public SDVariable getGradForVariable(String varName) {
        Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists", (Object) varName);
        SDVariable v = this.getVariable(varName);
        // Format arguments ordered as (datatype, name) to match the message template
        Preconditions.checkState(v.dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating point variables have gradients", (Object) v.dataType(), (Object) varName);
        SDVariable localGrad = this.variables.get(varName).getGradient();
        if (localGrad != null) {
            return localGrad;
        }
        SameDiff gradFn = this.sameDiffFunctionInstances.get(GRAD_FN_KEY);
        if (gradFn != null && gradFn.variables.containsKey(varName)) {
            return gradFn.variables.get(varName).getGradient();
        }
        return null;
    }

    /**
     * Reports whether a gradient variable exists for {@code varName}. Non-floating-point variables
     * and constants always return false.
     *
     * @throws IllegalStateException if no such variable exists
     */
    public boolean variableHasGradient(String varName) {
        Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", (Object) varName);
        SDVariable candidate = getVariable(varName);
        boolean trainable = candidate.dataType().isFPType() && !candidate.isConstant();
        return trainable && getGradForVariable(varName) != null;
    }

    /**
     * Records {@code variable} as the gradient of the named variable.
     *
     * @throws IllegalStateException     if the named variable does not exist
     * @throws ND4JIllegalStateException if {@code variable} is null
     */
    public void setGradientForVariableName(String variableName, SDVariable variable) {
        Preconditions.checkState(variables.containsKey(variableName), "No variable exists with name \"%s\"", (Object) variableName);
        if (variable != null) {
            variables.get(variableName).setGradient(variable);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName);
    }

    /** Associates {@code forwardVariable} with {@code varName} in the forward-variable-for-gradient map. */
    public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) {
        forwardVarForGrad.put(varName, forwardVariable);
    }

    /**
     * Returns the gradient of {@code varName}, as recorded in the gradient function.
     *
     * @param varName variable name
     * @return the gradient variable (may be null if none is recorded)
     * @throws IllegalStateException if the gradient function has not been created, or if it does not
     *                               contain a variable with this name
     */
    public SDVariable grad(String varName) {
        if (!this.sameDiffFunctionInstances.containsKey(GRAD_FN_KEY)) {
            throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
        }
        SameDiff gradFn = this.getFunction(GRAD_FN_KEY);
        SDVariable var = gradFn.getVariable(varName);
        // Fail with a descriptive error instead of an NPE when the name is unknown
        if (var == null) {
            throw new IllegalStateException("Unable to obtain gradient: no variable with name \"" + varName + "\" exists in the gradient function");
        }
        return gradFn.getGradForVariable(var.getVarName());
    }

    /**
     * Creates a scalar VARIABLE holding {@code value}.
     * The scalar array is created with workspaces scoped out (presumably so it is not tied to a
     * workspace the caller has open).
     */
    public SDVariable scalar(String name, double value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a scalar VARIABLE holding {@code value}.
     * The scalar array is created with workspaces scoped out.
     */
    public SDVariable scalar(String name, float value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a scalar VARIABLE holding {@code value}.
     * NOTE(review): var(String, INDArray) rejects non-floating-point arrays. If Nd4j.scalar(int)
     * yields an integer-typed array, this overload will always throw; consider constant(...) instead.
     */
    public SDVariable scalar(String name, int value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a scalar VARIABLE holding {@code value}.
     * NOTE(review): var(String, INDArray) rejects non-floating-point arrays. If Nd4j.scalar(long)
     * yields an integer-typed array, this overload will always throw; consider constant(...) instead.
     */
    public SDVariable scalar(String name, long value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a scalar VARIABLE of the given datatype holding {@code value}.
     * var(String, INDArray) requires a floating point datatype.
     */
    public SDVariable scalar(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(name, Nd4j.scalar(dataType, value));
        }
    }

    /** Creates a scalar constant with an auto-generated name. */
    public SDVariable constant(double value) {
        String unnamed = null;
        return constant(unnamed, value);
    }

    /** Creates a scalar constant; the array is created with workspaces scoped out. */
    public SDVariable constant(String name, double value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(name, Nd4j.scalar(value));
        }
    }

    /** Creates a scalar constant with an auto-generated name. */
    public SDVariable constant(float value) {
        String unnamed = null;
        return constant(unnamed, value);
    }

    /** Creates a scalar constant; the array is created with workspaces scoped out. */
    public SDVariable constant(String name, float value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(name, Nd4j.scalar(value));
        }
    }

    /** Creates a scalar constant with an auto-generated name. */
    public SDVariable constant(int value) {
        String unnamed = null;
        return constant(unnamed, value);
    }

    /** Creates a scalar constant; the array is created with workspaces scoped out. */
    public SDVariable constant(String name, int value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(name, Nd4j.scalar(value));
        }
    }

    /** Creates a scalar constant with an auto-generated name. */
    public SDVariable constant(long value) {
        String unnamed = null;
        return constant(unnamed, value);
    }

    /** Creates a scalar constant; the array is created with workspaces scoped out. */
    public SDVariable constant(String name, long value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(name, Nd4j.scalar(value));
        }
    }

    /** Creates a scalar constant of the given datatype; the array is created with workspaces scoped out. */
    public SDVariable constant(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(name, Nd4j.scalar(dataType, value));
        }
    }

    /**
     * Registers {@code variable} under its name.
     * Any existing entry for that name is replaced by a fresh Variable record that holds only the
     * name and the SDVariable.
     *
     * @param variable the variable to register; must belong to this SameDiff instance
     * @return {@code variable}
     * @throws IllegalStateException    if the variable belongs to a different SameDiff instance
     * @throws IllegalArgumentException if a different variable is already registered under the same name
     */
    public SDVariable addVariable(SDVariable variable) {
        // One ownership check is enough; the duplicate check after the name test added nothing
        Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same.");
        String name = variable.getVarName();
        Variable existing = this.variables.get(name);
        if (existing != null && !existing.getVariable().equals(variable)) {
            throw new IllegalArgumentException("Variable with name \"" + name + "\" already exists");
        }
        this.variables.put(name, Variable.builder().name(name).variable(variable).build());
        return variable;
    }

    /**
     * Looks up or creates the output variables of the given op, and registers them as the op's outputs
     * if none are registered yet.
     * <p>
     * Base name resolution: if {@code baseName} is null, or empty while
     * {@code getBaseNameForFunction(function)} is non-null, the latter is used. If the name is still null,
     * the op's own name is used, then its {@code opName()}.
     * <p>
     * Existing variables named {@code baseName} (output 0) or {@code baseName + ":" + i} (output i) are
     * reused. Missing ones are created as ARRAY variables named by {@code generateNewVarName(baseName, i)}.
     * Their datatype comes from {@code function.calculateOutputDataTypes(...)}, or is null when importing.
     *
     * @param function the op whose outputs are required
     * @param baseName preferred base name for the outputs (may be null/empty)
     * @param isImport true during graph import; output datatypes are then not calculated
     * @return the op's output variables
     * @throws ND4UnresolvedOutputVariables if the number of outputs of a CustomOp cannot be determined
     * @throws RuntimeException             if the op is neither a CustomOp nor a BaseOp
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) {
        if (baseName == null || (baseName.isEmpty() && this.getBaseNameForFunction(function) != null)) {
            baseName = this.getBaseNameForFunction(function);
        }
        if (baseName == null) {
            baseName = function.getOwnName();
        }
        if (baseName == null) {
            baseName = function.opName();
        }

        // Output datatypes are only inferred outside of import; during import they stay null.
        List<DataType> outputDataTypes = null;
        if (!isImport) {
            List<DataType> inputDataTypes = new ArrayList<DataType>();
            List<String> fnInputs = this.ops.get(function.getOwnName()).getInputsToOp();
            if (fnInputs != null) {
                for (String var : fnInputs) {
                    inputDataTypes.add(this.variables.get(var).getVariable().dataType());
                }
            }
            outputDataTypes = function.calculateOutputDataTypes(inputDataTypes);
        }

        if (function instanceof CustomOp) {
            CustomOp customOp = (CustomOp) function;
            int numOutputs = function.getNumOutputs();
            if (numOutputs <= 0) {
                CustomOpDescriptor descriptor = customOp.getDescriptor();
                if (descriptor != null) {
                    numOutputs = descriptor.getNumOutputs();
                }
                if (numOutputs <= 0) {
                    throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override getNumOutputs() to specify number of outputs if required");
                }
            }
            SDVariable[] ret = new SDVariable[numOutputs];
            Preconditions.checkState(isImport || (outputDataTypes != null && outputDataTypes.size() == numOutputs),
                    "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)",
                    (Object) (outputDataTypes == null ? null : Integer.valueOf(outputDataTypes.size())), (Object) numOutputs,
                    outputDataTypes, (Object) function.getClass().getSimpleName());
            for (int i = 0; i < ret.length; ++i) {
                SDVariable var = i == 0 ? this.getVariable(baseName) : this.getVariable(baseName + ":" + i);
                if (var == null) {
                    DataType dataType = isImport ? null : outputDataTypes.get(i);
                    var = this.var(this.generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[]) null);
                }
                var.setCreator(function);
                ret[i] = var;
            }
            if (this.getOutputsForOp(function) == null) {
                this.addOutgoingFor(ret, function);
            }
            return ret;
        }

        if (function instanceof BaseOp) {
            SDVariable[] ret = new SDVariable[1];
            SDVariable checkGet = this.getVariable(baseName);
            // NOTE(review): 'ordering' is never read. The lookup is retained only because getArr() may
            // have side effects (e.g. lazily materialising an array). Confirm that it does not before removing.
            char ordering = 'c';
            SDVariable[] args = function.args();
            if (args != null && args.length > 0 && function.args()[0].getArr() != null) {
                ordering = function.args()[0].getArr().ordering();
            }
            if (checkGet == null) {
                // Fix: outputDataTypes is null during import. Mirror the CustomOp branch and use a null
                // datatype instead of throwing a NullPointerException.
                DataType dataType = isImport ? null : outputDataTypes.get(0);
                checkGet = this.var(baseName, VariableType.ARRAY, null, dataType, (long[]) null);
            }
            checkGet.setCreator(function);
            ret[0] = checkGet;
            if (this.getOutputsForOp(function) == null) {
                this.addOutgoingFor(ret, function);
            }
            return ret;
        }
        throw new RuntimeException("Unknown op type: " + function.getClass());
    }

    /**
     * Looks up or creates the output variables of the given op (non-import mode).
     * <p>
     * The op's own name is used as the base name, falling back to its {@code opName()} when that is null.
     *
     * @param function the op whose outputs are required
     * @return the op's output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
        String ownName = function.getOwnName();
        String baseName = (ownName != null) ? ownName : function.opName();
        return this.generateOutputVariableForOp(function, baseName, false);
    }

    /**
     * Returns the sub-graph (function) registered under the given name.
     *
     * @param functionName the function's name (e.g. {@code GRAD_FN_KEY})
     * @return the registered SameDiff instance, or null if none is registered under that name
     */
    public SameDiff getFunction(String functionName) {
        SameDiff registered = this.sameDiffFunctionInstances.get(functionName);
        return registered;
    }

    /**
     * Builds a legacy {@link While} control-flow block with a random {@code "while-<uuid>"} block name.
     *
     * @param sameDiffConditional predicate used by the loop
     * @param conditionBody       definition of the condition sub-graph
     * @param loopBody            definition of the loop body sub-graph
     * @param inputVars           loop input variables
     * @return the built While op
     * @deprecated legacy control-flow API
     */
    @Deprecated
    public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition loopBody, SDVariable[] inputVars) {
        String blockName = "while-" + UUID.randomUUID().toString();
        return While.builder()
                .inputVars(inputVars)
                .condition(conditionBody)
                .predicate(sameDiffConditional)
                .trueBody(loopBody)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /**
     * Builds a legacy {@link If} control-flow block with a random {@code "if-<uuid>"} block name.
     *
     * @param conditional   predicate selecting the branch
     * @param conditionBody definition of the condition sub-graph
     * @param trueBody      definition of the branch taken when the predicate holds
     * @param falseBody     definition of the branch taken otherwise
     * @param inputVars     input variables
     * @return the built If op
     * @deprecated legacy control-flow API
     */
    @Deprecated
    public If ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody, SDVariable[] inputVars) {
        String blockName = "if-" + UUID.randomUUID().toString();
        return If.builder()
                .conditionBody(conditionBody)
                .falseBody(falseBody)
                .trueBody(trueBody)
                .predicate(conditional)
                .inputVars(inputVars)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /**
     * Creates a new TensorArray op of the given datatype in this graph.
     *
     * @param dataType element datatype of the tensor array
     * @return the new TensorArray
     */
    public TensorArray tensorArray(DataType dataType) {
        TensorArray tensorArray = new TensorArray(this, dataType);
        // The result is intentionally discarded. Presumably outputVariables() is called so the op's
        // output variables get created/registered eagerly. Confirm before removing.
        tensorArray.outputVariables();
        return tensorArray;
    }

    /**
     * Invokes the named sub-graph (function) on another SameDiff instance via {@code invokeGraphOn}.
     *
     * @param functionName name of a function previously registered with {@code defineFunction}
     * @param with         the SameDiff instance to invoke the function's graph on
     * @return the variable returned by {@code invokeGraphOn}
     * @throws IllegalStateException if no function with the given name has been defined
     *                               (previously this surfaced as a bare NullPointerException)
     */
    public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
        SameDiff instance = this.sameDiffFunctionInstances.get(functionName);
        Preconditions.checkState(instance != null, "No function with name \"%s\" has been defined", (Object) functionName);
        return instance.invokeGraphOn(with);
    }

    /**
     * Defines (if not already defined) a named sub-graph whose inputs are copies of the given variables,
     * and returns it.
     * <p>
     * While the sub-graph is being defined, {@code this.child} points to it and its {@code parent} is set
     * to this instance. {@code child} is always reset to null afterwards, even if the definition throws.
     *
     * @param function           name under which the sub-graph is registered
     * @param functionDefinition callback that populates the sub-graph
     * @param variables          variables to re-create (via {@code sub.var(...)}) as the sub-graph's inputs
     * @return the (new or previously registered) sub-graph for {@code function}
     */
    public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) {
        try {
            if (!this.sameDiffFunctionInstances.containsKey(function)) {
                SameDiff sub = SameDiff.create();
                this.child = sub;
                sub.parent = this;
                SDVariable[] subInputs = new SDVariable[variables.length];
                for (int i = 0; i < subInputs.length; ++i) {
                    subInputs[i] = sub.var(variables[i]);
                }
                functionDefinition.define(sub, null, subInputs);
                this.sameDiffFunctionInstances.put(function, sub);
            }
        } finally {
            // Fix: previously 'child' stayed set if define(...) threw, leaving this instance in a stale state.
            this.child = null;
        }
        return this.sameDiffFunctionInstances.get(function);
    }

    /**
     * Defines (if not already defined) a named sub-graph with no input arrays.
     *
     * @param function           name under which the sub-graph is registered
     * @param functionDefinition callback that populates the sub-graph
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) {
        Map<String, INDArray> noInputs = new LinkedHashMap<>();
        this.defineFunction(function, functionDefinition, noInputs);
    }

    /**
     * Defines a named sub-graph from a map of input arrays, unless one is already registered under that name.
     *
     * @param function           name under which the sub-graph is registered
     * @param functionDefinition callback that populates the sub-graph
     * @param inputs             input arrays passed to the definition callback
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String, INDArray> inputs) {
        if (this.sameDiffFunctionInstances.containsKey(function)) {
            return;
        }
        SameDiff subGraph = SameDiff.create();
        functionDefinition.define(subGraph, inputs, null);
        this.sameDiffFunctionInstances.put(function, subGraph);
    }

    /**
     * Executes the graph and returns the array of its single output.
     * <p>
     * The placeholders used are those stored for the calling thread in {@code placeholdersPerThread}.
     *
     * @return the array of the graph's only output
     * @deprecated legacy execution API
     */
    @Deprecated
    public INDArray execAndEndResult() {
        List<String> graphOutputs = this.outputs();
        Preconditions.checkState(graphOutputs.size() == 1, "Method can only be used with SameDiff instances with a single output");
        long threadId = Thread.currentThread().getId();
        Map<String, INDArray> threadPlaceholders = this.placeholdersPerThread.get(threadId);
        String onlyOutput = graphOutputs.get(0);
        return this.execSingle(threadPlaceholders, onlyOutput);
    }

    /**
     * Runs the backward pass for all VARIABLE-type gradients, with no batch, no extra activations and
     * no listeners.
     *
     * @param placeholders placeholder arrays
     * @param op           the operation type reported to listeners
     */
    public void execBackwards(Map<String, INDArray> placeholders, Operation op) {
        Collection<String> noActivations = Collections.emptyList();
        List<Listener> noListeners = Collections.emptyList();
        this.execBackwards(placeholders, op, null, noActivations, noListeners);
    }

    /**
     * Runs the backward pass for all VARIABLE-type gradients using {@link Operation#INFERENCE}.
     *
     * @param placeholders placeholder arrays
     */
    public void execBackwards(Map<String, INDArray> placeholders) {
        Operation defaultOperation = Operation.INFERENCE;
        this.execBackwards(placeholders, defaultOperation);
    }

    /**
     * Runs the backward pass for the gradients of every VARIABLE-type variable that has a gradient.
     * <p>
     * The loss variables are also requested when listeners are registered. The gradient function is
     * created first if it does not exist yet.
     *
     * @param placeholders        placeholder arrays
     * @param op                  the operation type reported to listeners
     * @param batch               the current batch (may be null)
     * @param requiredActivations additional activations to compute
     * @param activeListeners     listeners active for this execution
     */
    protected void execBackwards(Map<String, INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
        if (this.getFunction(GRAD_FN_KEY) == null) {
            this.createGradFunction();
        }
        Set<String> gradNames = new HashSet<>();
        for (Variable entry : this.variables.values()) {
            SDVariable sdv = entry.getVariable();
            if (sdv.getVariableType() == VariableType.VARIABLE) {
                SDVariable grad = sdv.gradient();
                if (grad != null) {
                    gradNames.add(grad.getVarName());
                }
            }
        }
        if (!this.listeners.isEmpty()) {
            // Listeners typically want the loss value, so request the loss variables too.
            gradNames.addAll(this.lossVariables);
        }
        if (gradNames.isEmpty()) {
            log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\nIf gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead");
        }
        List<String> gradNameList = new ArrayList<>(gradNames);
        this.execBackwards(placeholders, gradNameList, op, batch, requiredActivations, activeListeners);
    }

    /**
     * Runs the backward pass and returns the requested gradient arrays.
     *
     * @param placeholders          placeholder arrays
     * @param op                    the operation type reported to listeners
     * @param variableGradNamesList names of the gradient variables to return
     * @return map from gradient variable name to array
     */
    public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, Operation op, String ... variableGradNamesList) {
        List<String> gradNames = Arrays.asList(variableGradNamesList);
        Collection<String> noActivations = Collections.emptyList();
        List<Listener> noListeners = Collections.emptyList();
        return this.execBackwards(placeholders, gradNames, op, null, noActivations, noListeners);
    }

    /**
     * Runs the backward pass with {@link Operation#INFERENCE} and returns the requested gradient arrays.
     *
     * @param placeholders          placeholder arrays
     * @param variableGradNamesList names of the gradient variables to return
     * @return map from gradient variable name to array
     */
    public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, String ... variableGradNamesList) {
        Operation defaultOperation = Operation.INFERENCE;
        return this.execBackwards(placeholders, defaultOperation, variableGradNamesList);
    }

    /**
     * Runs the backward pass and returns the requested gradient arrays, with no batch, no extra
     * activations and no listeners.
     *
     * @param placeholders          placeholder arrays
     * @param variableGradNamesList names of the gradient variables to return
     * @param operation             the operation type reported to listeners
     * @return map from gradient variable name to array
     */
    public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation) {
        Collection<String> noActivations = Collections.emptyList();
        List<Listener> noListeners = Collections.emptyList();
        return this.execBackwards(placeholders, variableGradNamesList, operation, null, noActivations, noListeners);
    }

    /**
     * Runs the backward pass with {@link Operation#INFERENCE} and returns the requested gradient arrays.
     *
     * @param placeholders          placeholder arrays
     * @param variableGradNamesList names of the gradient variables to return
     * @return map from gradient variable name to array
     */
    public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) {
        Operation defaultOperation = Operation.INFERENCE;
        return this.execBackwards(placeholders, variableGradNamesList, defaultOperation);
    }

    /**
     * Executes the gradient function and returns the requested gradient arrays.
     * <p>
     * The gradient function is created first if it does not exist. Its listener list is replaced by
     * {@code activeListeners}. An empty request short-circuits with a warning and an empty map.
     *
     * @param placeholders          placeholder arrays
     * @param variableGradNamesList names of the gradient variables to return
     * @param operation             the operation type placed into the {@link At} passed to execution
     * @param batch                 the current batch (may be null)
     * @param requiredActivations   additional activations to compute
     * @param activeListeners       listeners active for this execution
     * @return map from gradient variable name to array (empty if nothing was requested)
     */
    protected Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
        if (this.getFunction(GRAD_FN_KEY) == null) {
            this.createGradFunction();
        }
        log.trace("About to execute backward function");
        if (variableGradNamesList.isEmpty()) {
            log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)");
            return Collections.emptyMap();
        }
        SameDiff gradFn = this.sameDiffFunctionInstances.get(GRAD_FN_KEY);
        // Replace listeners left over from any previous call with the currently active ones.
        gradFn.listeners.clear();
        gradFn.listeners.addAll(activeListeners);
        long threadId = Thread.currentThread().getId();
        At at = new At(0, 0, 0, threadId, operation);
        if (this.trainingConfig != null) {
            at.setIteration(this.trainingConfig.getIterationCount());
            at.setEpoch(this.trainingConfig.getEpochCount());
        }
        String[] outputNames = variableGradNamesList.toArray(new String[0]);
        return gradFn.directExecHelper(placeholders, at, batch, requiredActivations, activeListeners, outputNames);
    }

    /**
     * Reports whether the gradient function has already been created.
     *
     * @return true if a function is registered under {@code GRAD_FN_KEY}
     */
    public boolean hasGradientFunction() {
        boolean registered = this.sameDiffFunctionInstances.containsKey(GRAD_FN_KEY);
        return registered;
    }

    /**
     * Creates the gradient function without forcing gradients for any additional variables.
     */
    public void createGradFunction() {
        this.createGradFunction((String[]) null);
    }

    /**
     * Creates the gradient function, registered under {@code GRAD_FN_KEY}.
     * <p>
     * Loss variables are resolved first. The existing list is used if non-empty; otherwise the training
     * configuration's loss variables are used; otherwise, if the graph has exactly one output, that output
     * is used. The gradient graph is then built as follows:
     * <ol>
     *   <li>this graph is re-created in a new SameDiff via {@code invokeGraphOn};</li>
     *   <li>each loss variable is summed and its gradient is seeded with a scalar 1 ("one-var");</li>
     *   <li>the floating point variables the losses depend on are collected;</li>
     *   <li>branches fed only by constants/placeholders, unless the user requested them, are pruned;</li>
     *   <li>ops are differentiated backwards once gradients for all required outputs exist.</li>
     * </ol>
     *
     * @param variablesRequiringGradients optional names of variables whose gradients must be calculated even
     *                                    when they would otherwise be pruned (may be null or empty); each must
     *                                    exist and be floating point
     */
    public void createGradFunction(final String ... variablesRequiringGradients) {
        if (this.lossVariables.isEmpty()) {
            if (this.trainingConfig != null && this.trainingConfig.getLossVariables() != null && !this.trainingConfig.getLossVariables().isEmpty()) {
                this.lossVariables.addAll(this.trainingConfig.getLossVariables());
            } else {
                // Fix: outputs() returns a List<String>; the previous String[] declaration did not compile.
                List<String> outputs = this.outputs();
                if (outputs.size() == 1) {
                    String outName = outputs.get(0);
                    String opName = this.variables.get(outName).getOutputOfOp();
                    if (opName == null || !(this.ops.get(opName).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", outName);
                    }
                    this.lossVariables.add(outName);
                }
            }
        }
        Preconditions.checkState(!this.lossVariables.isEmpty(), "Cannot create gradient function: No loss variables (variables to minimize) have been specified. Loss variables are the variables that represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()");
        if (log.isTraceEnabled()) {
            log.trace("Defining function \"grad\"");
        }
        if (variablesRequiringGradients != null && variablesRequiringGradients.length > 0) {
            for (String s : variablesRequiringGradients) {
                Preconditions.checkArgument(this.variables.containsKey(s), "Cannot ensure gradient exists for variable: no variable with name \"%s\" exists", (Object) s);
                DataType dt = this.variables.get(s).getVariable().dataType();
                Preconditions.checkState(dt.isFPType(), "Cannot ensure gradient exists for variable \"%s\": variable is not a floating point SDVariable. Only floating point SDVariables have gradients defined - variable has type %s", (Object) s, (Object) dt);
            }
        }
        final SameDiff outer = this;
        this.defineFunction(GRAD_FN_KEY, new SameDiffFunctionDefinition() {

            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                if (SameDiff.this.debugMode) {
                    sameDiff.enableDebugMode();
                }
                outer.invokeGraphOn(sameDiff);
                if (SameDiff.this.debugMode) {
                    Preconditions.checkState(sameDiff.ops.keySet().equals(SameDiff.this.ops.keySet()), "ops keysets not equal");
                }
                List<SameDiffOp> allFunctions = new ArrayList<>(sameDiff.ops.values());
                if (allFunctions.isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                }
                // Point every op, and its argument/output variables, at the new gradient graph.
                for (SameDiffOp op : allFunctions) {
                    DifferentialFunction func = op.getOp();
                    for (SDVariable arg : func.args()) {
                        arg.setSameDiff(sameDiff);
                    }
                    for (SDVariable output : func.outputVariables()) {
                        output.setSameDiff(sameDiff);
                    }
                    func.setSameDiff(sameDiff);
                }

                // Seed the backward pass: loss = sum(lossVar), d(loss)/d(loss) = 1.
                List<SDVariable> finalOutputs = new ArrayList<>(SameDiff.this.lossVariables.size());
                SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0f));
                Set<String> seenLossNames = new HashSet<>();
                for (String s : SameDiff.this.lossVariables) {
                    Preconditions.checkNotNull((Object) s, "Encountered null value in loss variables. Null loss variables are not allowed. Use SameDiff.setLossVariables with non-null array names to fix");
                    Preconditions.checkState(SameDiff.this.variables.containsKey(s), "Specified loss function variable \"%s\" does not exist", (Object) s);
                    // Fix: deduplicate by name. The previous check compared freshly created sum() variables,
                    // so it never matched, and a repeated loss was summed (and its gradient accumulated) twice.
                    if (!seenLossNames.add(s)) {
                        log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", (Object) s);
                        continue;
                    }
                    SDVariable v = SameDiff.this.variables.get(s).getVariable();
                    Preconditions.checkState(v.dataType().isFPType(), "Specified loss function variable \"%s\" is not a floatingpoint variable (datatype: %s). Only floating point variables may be used as loss function variable", (Object) s, (Object) v.dataType());
                    v = v.sum();
                    if (v.dataType() == initialGrad.dataType()) {
                        sameDiff.setGradientForVariableName(v.getVarName(), initialGrad);
                    } else {
                        sameDiff.setGradientForVariableName(v.getVarName(), initialGrad.castTo(v.dataType()));
                    }
                    finalOutputs.add(v);
                }
                if (log.isTraceEnabled()) {
                    Object[] initialOutputsStr = allFunctions.get(allFunctions.size() - 1).getOp().outputVariablesNames();
                    String s = initialOutputsStr == null ? "null" : Arrays.toString(initialOutputsStr);
                    log.trace("Defining backward function: initial outputs {}", (Object) s);
                }

                // Collect every floating point variable the loss variables (transitively) depend on.
                Set<String> allFpVarsConnectedToLoss = new HashSet<>();
                LinkedList<String> toProcess = new LinkedList<>();
                for (String s : SameDiff.this.lossVariables) {
                    if (!toProcess.contains(s)) {
                        toProcess.add(s);
                    }
                }
                while (!toProcess.isEmpty()) {
                    String next = toProcess.remove();
                    if (allFpVarsConnectedToLoss.contains(next)) {
                        continue;
                    }
                    Variable v = SameDiff.this.variables.get(next);
                    if (!v.getVariable().dataType().isFPType()) {
                        continue;
                    }
                    allFpVarsConnectedToLoss.add(v.getName());
                    String opName = v.getOutputOfOp();
                    if (opName == null) {
                        continue;
                    }
                    List<String> opInputs = SameDiff.this.ops.get(opName).getInputsToOp();
                    if (opInputs == null) {
                        continue;
                    }
                    for (String s : opInputs) {
                        if (SameDiff.this.variables.get(s).getVariable().dataType().isFPType()) {
                            toProcess.add(s);
                        }
                    }
                }

                // Identify leaves that cannot need gradients: ARRAY vars with no FP-connected inputs, and
                // constants/placeholders not explicitly requested by the user.
                Set<String> minimalSubgraphVars = new HashSet<>(allFpVarsConnectedToLoss);
                LinkedList<String> leafFPVars = new LinkedList<>();
                for (String s : allFpVarsConnectedToLoss) {
                    Variable v = SameDiff.this.variables.get(s);
                    if (v.getVariable().getVariableType() == VariableType.ARRAY) {
                        List<String> inputsToOp = SameDiff.this.ops.get(v.getOutputOfOp()).getInputsToOp();
                        boolean anyInputsInSubgraph = false;
                        if (inputsToOp != null) {
                            for (String in : inputsToOp) {
                                if (allFpVarsConnectedToLoss.contains(in)) {
                                    anyInputsInSubgraph = true;
                                    break;
                                }
                            }
                        }
                        if (!anyInputsInSubgraph) {
                            leafFPVars.add(s);
                        }
                    }
                    VariableType vt = v.getVariable().getVariableType();
                    boolean isUserRequested = variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s);
                    if ((vt == VariableType.CONSTANT || vt == VariableType.PLACEHOLDER) && !isUserRequested) {
                        leafFPVars.add(s);
                    }
                }

                // Prune forward from those leaves: an op whose inputs are all pruned (and not user-requested)
                // has its outputs pruned too.
                while (!leafFPVars.isEmpty()) {
                    String nextLeaf = leafFPVars.remove();
                    Variable v = SameDiff.this.variables.get(nextLeaf);
                    minimalSubgraphVars.remove(nextLeaf);
                    List<String> inputsTo = v.getInputsForOp();
                    if (inputsTo == null || inputsTo.isEmpty()) {
                        continue;
                    }
                    for (String opName : inputsTo) {
                        SameDiffOp op = SameDiff.this.ops.get(opName);
                        boolean anyPresent = false;
                        for (String in : op.getInputsToOp()) {
                            if (minimalSubgraphVars.contains(in) || (variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, in))) {
                                anyPresent = true;
                                break;
                            }
                        }
                        if (anyPresent) {
                            continue;
                        }
                        List<String> opOutputs = op.getOutputsOfOp();
                        if (opOutputs == null) {
                            continue;
                        }
                        for (String out : opOutputs) {
                            if (!leafFPVars.contains(out)) {
                                leafFPVars.add(out);
                            }
                        }
                    }
                }
                Preconditions.checkState(!minimalSubgraphVars.isEmpty(), "Cannot differentiate graph relative to the specified loss function variables %s: graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", (Object) SameDiff.this.lossVariables);

                // Differentiation starts from the ops that produce the (summed) loss variables.
                LinkedList<String> availableForDiff = new LinkedList<>();
                for (SDVariable lossVar : finalOutputs) {
                    Variable v = sameDiff.variables.get(lossVar.getVarName());
                    String opName = v.getOutputOfOp();
                    if (opName != null) {
                        availableForDiff.add(opName);
                    }
                }

                // For each subgraph variable, the ops consuming it that must be differentiated before its
                // own gradient is complete.
                Map<String, List<String>> prerequisites = new HashMap<>();
                for (String var : minimalSubgraphVars) {
                    Variable variable = SameDiff.this.variables.get(var);
                    List<String> inputsForOp = variable.getInputsForOp();
                    if (inputsForOp == null) {
                        continue;
                    }
                    List<String> req = new ArrayList<>();
                    for (String consumerOp : inputsForOp) {
                        List<String> opOutputs = SameDiff.this.ops.get(consumerOp).getOutputsOfOp();
                        boolean anyOpOutputsRequired = false;
                        if (opOutputs != null) {
                            for (String out : opOutputs) {
                                if (minimalSubgraphVars.contains(out)) {
                                    anyOpOutputsRequired = true;
                                    break;
                                }
                            }
                        }
                        if (anyOpOutputsRequired) {
                            req.add(consumerOp);
                        }
                    }
                    prerequisites.put(variable.getName(), req);
                }

                Set<String> differentiatedOps = new HashSet<>();
                while (!availableForDiff.isEmpty()) {
                    String dfName = availableForDiff.remove();
                    DifferentialFunction df = sameDiff.ops.get(dfName).getOp();
                    SameDiffOp dfOp = sameDiff.ops.get(df.getOwnName());
                    List<String> inputsToOp = dfOp.getInputsToOp();
                    List<String> outputsOfOp = (df instanceof GradientBackwardsMarker) ? Collections.<String>emptyList() : dfOp.getOutputsOfOp();

                    // Gradients w.r.t. each output: existing gradient, else zeros (FP) or null (non-FP).
                    List<SDVariable> grads = new ArrayList<>();
                    for (String outName : outputsOfOp) {
                        SDVariable v = sameDiff.getVariable(outName);
                        SDVariable g = v.hasGradient() ? v.gradient() : null;
                        if (g != null) {
                            grads.add(g);
                        } else if (v.dataType().isFPType()) {
                            grads.add(sameDiff.zerosLike(v));
                        } else {
                            grads.add(null);
                        }
                    }
                    df.diff(grads);
                    differentiatedOps.add(df.getOwnName());

                    // Queue producer ops of this op's inputs once all of their required gradients exist.
                    for (String s : inputsToOp) {
                        Variable v = sameDiff.variables.get(s);
                        String opName = v.getOutputOfOp();
                        if (opName == null || differentiatedOps.contains(opName)) {
                            continue;
                        }
                        SameDiffOp op = SameDiff.this.ops.get(opName);
                        boolean isRequiredOp = false;
                        List<String> opInputs = op.getInputsToOp();
                        if (opInputs != null) {
                            boolean anyInputsRequired = false;
                            for (String in : opInputs) {
                                if (minimalSubgraphVars.contains(in)) {
                                    anyInputsRequired = true;
                                    break;
                                }
                            }
                            if (anyInputsRequired && !differentiatedOps.contains(op.getName())) {
                                isRequiredOp = true;
                            }
                        }
                        if (!isRequiredOp) {
                            continue;
                        }
                        boolean allAvailable = true;
                        SameDiffOp o = sameDiff.ops.get(opName);
                        for (String opOutput : o.getOutputsOfOp()) {
                            Variable outVar = SameDiff.this.variables.get(opOutput);
                            if (!outVar.getVariable().dataType().isFPType() || !minimalSubgraphVars.contains(outVar.getName())) {
                                continue;
                            }
                            if (outVar.getVariable().gradient() == null) {
                                allAvailable = false;
                                break;
                            }
                            List<String> prereqs = prerequisites.get(outVar.getName());
                            if (prereqs != null && !differentiatedOps.containsAll(prereqs)) {
                                allAvailable = false;
                                break;
                            }
                        }
                        String candidate = o.getOp().getOwnName();
                        if (allAvailable && !availableForDiff.contains(candidate)) {
                            availableForDiff.add(candidate);
                        }
                    }
                }

                // Sanity check: every required non-loss variable must now have a gradient.
                for (String s : minimalSubgraphVars) {
                    if (SameDiff.this.lossVariables.contains(s)) {
                        continue;
                    }
                    if (SameDiff.this.variables.get(s).getVariable().gradient() == null) {
                        throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + s + "\" was calculated");
                    }
                }
                return new SDVariable[]{sameDiff.var(SameDiff.GRAD_FN_KEY, DataType.FLOAT, 1)};
            }
        });
        this.associateSameDiffWithOpsAndVariables();
    }

    /**
     * Records the original (declared) shape of a placeholder variable.
     * <p>
     * Re-recording the same shape is a no-op. Recording a different shape for the same placeholder fails.
     *
     * @param variableName name of an existing placeholder variable
     * @param shape        the declared shape; must not be null
     * @throws ND4JIllegalStateException if the variable is not a placeholder, the shape is null, or a different
     *                                   shape was already recorded
     */
    public void setOriginalPlaceHolderShape(String variableName, long[] shape) {
        boolean placeholder = this.isPlaceHolder(variableName);
        if (!placeholder) {
            throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
        }
        if (shape == null) {
            throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed");
        }
        boolean conflicting = this.placeHolderOriginalShapes.containsKey(variableName)
                && !Arrays.equals(this.placeHolderOriginalShapes.get(variableName), shape);
        if (conflicting) {
            throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName);
        }
        this.placeHolderOriginalShapes.put(variableName, shape);
    }

    /**
     * Returns the shape recorded for a placeholder via {@code setOriginalPlaceHolderShape}.
     *
     * @param varName the placeholder's name
     * @return the recorded shape, or null if none was recorded
     * @deprecated legacy placeholder-shape API
     */
    @Deprecated
    public long[] getOriginalShapeForPlaceHolder(String varName) {
        long[] recorded = this.placeHolderOriginalShapes.get(varName);
        return recorded;
    }

    /**
     * Reports whether the named variable is a placeholder.
     *
     * @param varName name of a variable in this instance; a missing name fails a precondition state check
     * @return true if the variable is a placeholder
     */
    public boolean isPlaceHolder(String varName) {
        boolean exists = this.variables.containsKey(varName);
        Preconditions.checkState(exists, "No variable present in SameDiff instance with name \"%s\"", (Object) varName);
        SDVariable sdv = this.variables.get(varName).getVariable();
        return sdv.isPlaceHolder();
    }

    /**
     * Binds the given arrays to their placeholder variables, then marks variables as resolved.
     * <p>
     * The work is done in two passes. The first pass checks that every key names an existing variable and
     * that each placeholder's rank matches its array. The second pass checks that every key is a
     * placeholder, checks the array against any recorded original shape, and associates the array.
     *
     * @param arrays map from placeholder name to array
     * @throws ND4JIllegalStateException on unknown names, non-placeholder variables or incompatible shapes
     */
    public void resolveVariablesWith(Map<String, INDArray> arrays) {
        // Pass 1: existence and rank validation.
        for (Map.Entry<String, INDArray> candidate : arrays.entrySet()) {
            String key = candidate.getKey();
            SDVariable target = this.getVariable(key);
            if (target == null) {
                throw new ND4JIllegalStateException("A placeholder array was provided for variable with name \"" + key + "\" but no variable with this name exists");
            }
            if (target.getVariableType() == VariableType.PLACEHOLDER) {
                long[] declared = target.placeholderShape();
                long[] actual = candidate.getValue().shape();
                Preconditions.checkState(declared.length == actual.length, "Placeholder shape not compatible (mismatched rank): placeholder \"%s\" shape %s, got incompatible shape %s", (Object) key, (Object) declared, (Object) actual);
            }
        }
        // Pass 2: placeholder/shape checks and array association.
        for (Map.Entry<String, INDArray> candidate : arrays.entrySet()) {
            String key = candidate.getKey();
            INDArray array = candidate.getValue();
            if (!this.variables.get(key).getVariable().isPlaceHolder()) {
                throw new ND4JIllegalStateException("Illegal variable " + key + " passed in. Variable found not to be a place holder variable");
            }
            long[] recorded = this.getOriginalShapeForPlaceHolder(key);
            if (!Shape.isPlaceholderShape(recorded) && !Shape.shapeEquals(recorded, array.shape())) {
                throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(recorded) + " but array shape was " + Arrays.toString(array.shape()));
            }
            this.associateArrayWithVariable(array, this.getVariable(key));
            this.setArrayForVariable(key, array);
        }
        this.resolvedVariables = true;
    }

    /**
     * Renames a variable, applying the current name scope, and updates all references to it.
     * <p>
     * If {@code newVarName} is given, it is prefixed with the current name scope unless it already starts
     * with that scope, and it must not belong to a different variable. If it is null and the variable's
     * current name is taken by another variable, a fresh name is generated.
     *
     * @param varToUpdate the variable to rename; must not be null
     * @param newVarName  requested new name (may be null)
     * @return {@code varToUpdate}, possibly renamed
     * @throws NullPointerException  if {@code varToUpdate} is null
     * @throws IllegalStateException if the requested name already belongs to a different variable
     */
    @Override
    public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
        if (varToUpdate == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        if (newVarName != null) {
            String scope = this.currentNameScope();
            if (scope != null && !newVarName.startsWith(scope + "/")) {
                newVarName = scope + "/" + newVarName;
            }
            if (this.variables.containsKey(newVarName) && varToUpdate != this.variables.get(newVarName).getVariable()) {
                throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable");
            }
        } else {
            String current = varToUpdate.getVarName();
            if (this.variables.containsKey(current) && this.variables.get(current).getVariable() != varToUpdate) {
                newVarName = this.generateNewVarName(current, 0);
            }
        }
        final String oldVarName = varToUpdate.getVarName();
        if (newVarName == null || oldVarName.equals(newVarName)) {
            return varToUpdate;
        }
        varToUpdate.setVarName(newVarName);
        this.updateVariableName(oldVarName, newVarName);
        return varToUpdate;
    }

    /**
     * Returns the SameDiff instance associated with this object; for SameDiff itself this is
     * simply {@code this}.
     *
     * @return this instance
     */
    @Override
    protected SameDiff sd() {
        final SameDiff self = this;
        return self;
    }

    /**
     * Bulk form of {@code updateVariableNameAndReference}: for every index {@code i}, renames
     * {@code variablesToUpdate[i]} to {@code newVariableNames[i]}.
     *
     * @param variablesToUpdate variables to rename
     * @param newVariableNames  names aligned by index with {@code variablesToUpdate}, or null to
     *                          pass a null name for every variable
     * @return a new array with the (possibly renamed) variables in input order
     */
    @Override
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames) {
        SDVariable[] result = new SDVariable[variablesToUpdate.length];
        int idx = 0;
        while (idx < result.length) {
            String requested = (newVariableNames != null) ? newVariableNames[idx] : null;
            result[idx] = this.updateVariableNameAndReference(variablesToUpdate[idx], requested);
            idx++;
        }
        return result;
    }

    /**
     * Points every variable in {@code variableMap()}, every op in {@code ops}, and each op's
     * argument and output variables back at this SameDiff instance via {@code setSameDiff(this)}.
     */
    protected void associateSameDiffWithOpsAndVariables() {
        for (SDVariable variable : this.variableMap().values()) {
            variable.setSameDiff(this);
        }
        for (SameDiffOp sdOp : this.ops.values()) {
            DifferentialFunction fn = sdOp.getOp();
            fn.setSameDiff(this);
            SDVariable[] inputs = fn.args();
            if (inputs != null) {
                for (SDVariable in : inputs) {
                    in.setSameDiff(this);
                }
            }
            SDVariable[] results = fn.outputVariables();
            if (results != null) {
                for (SDVariable res : results) {
                    res.setSameDiff(this);
                }
            }
        }
    }

    /**
     * Writes a FlatNode entry standing in for a nested SameDiff scope and returns its offset.
     * <p>
     * Only {@code name} is serialized (its string offset is passed for the first two node fields);
     * {@code scope} is validated but otherwise unused here. The op type {@code (byte) 119} and op
     * number {@code 10L} are hard-coded — NOTE(review): their meaning is not visible in this file.
     *
     * @param name          scope name to record
     * @param scope         nested SameDiff instance; must not be null
     * @param bufferBuilder builder to write into; must not be null
     * @return offset of the created FlatNode
     */
    protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
        if (scope == null) {
            throw new NullPointerException("scope is marked @NonNull but is null");
        }
        if (bufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked @NonNull but is null");
        }
        final int nameOffset = bufferBuilder.createString((CharSequence)name);
        final byte opType = (byte)119;
        final long opNum = 10L;
        return FlatNode.createFlatNode(bufferBuilder, nameOffset, nameOffset, opType, opNum, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Splits a variable reference of the form {@code "<name>:<index>"} into its base name and
     * output index.
     * <p>
     * Only the text after the <em>last</em> colon is treated as the index, so names that contain
     * colons themselves are kept intact (e.g. {@code "a:b:2"} gives {@code ("a:b", 2)}). A name
     * with no colon, or whose text after the last colon is not an integer (e.g. {@code "a:b"} or
     * {@code "a:"}), is returned unchanged with index 0.
     *
     * @param varName variable reference to parse; must not be null
     * @return pair of (base name, output index)
     * @throws NullPointerException if {@code varName} is null
     */
    public static Pair<String, Integer> parseVariable(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked @NonNull but is null");
        }
        int sep = varName.lastIndexOf(':');
        if (sep < 0) {
            return Pair.pairOf(varName, 0);
        }
        int index;
        try {
            index = Integer.parseInt(varName.substring(sep + 1));
        } catch (NumberFormatException notAnIndex) {
            // The suffix is not an output index (e.g. "scope:name" or a trailing colon):
            // treat the whole string as a plain variable name instead of failing.
            return Pair.pairOf(varName, 0);
        }
        return Pair.pairOf(varName.substring(0, sep), index);
    }

    /**
     * Serializes this graph to FlatBuffers using graph id {@code 0}.
     *
     * @param configuration       executor configuration embedded in the graph; must not be null
     * @param includeUpdaterState whether to include updater state
     * @return buffer holding the serialized graph
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        final long defaultGraphId = 0L;
        return this.asFlatBuffers(defaultGraphId, configuration, includeUpdaterState);
    }

    /**
     * Serializes this graph into a FlatBuffers {@code FlatGraph}.
     * <p>
     * The output contains:
     * <ul>
     * <li>variables, with arrays for constants, placeholders and VARIABLE-type variables;</li>
     * <li>ops;</li>
     * <li>nested function scopes (except {@code GRAD_FN_KEY});</li>
     * <li>placeholder names and loss variable names;</li>
     * <li>the training configuration JSON;</li>
     * <li>optionally, per-parameter updater state.</li>
     * </ul>
     * As a side effect, the integer id assigned to each exported variable name is written back
     * via {@code Variable.setVariableIndex} for names present in this graph.
     *
     * @param graphId             id stored in the FlatGraph
     * @param configuration       executor configuration embedded in the graph; must not be null
     * @param includeUpdaterState whether to serialize updater state (only if any exists)
     * @return the builder's data buffer, positioned at the start of the serialized graph
     */
    public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(1024);
        AtomicInteger idCounter = new AtomicInteger(0);
        ArrayList<Integer> flatVariables = new ArrayList<Integer>();
        // Never populated: serialized as an empty "outputs" vector
        ArrayList<Integer> flatOffsets = new ArrayList<Integer>();
        ArrayList<Integer> flatNodes = new ArrayList<Integer>();
        ArrayList<SDVariable> variableList = new ArrayList<SDVariable>(this.variables());
        LinkedHashMap<String, Integer> reverseMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> forwardMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> framesMap = new LinkedHashMap<String, Integer>();
        int idx = 0;
        // All outputs of one op share the op's id and are distinguished by output number
        IdentityHashMap<DifferentialFunction, Integer> idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();
        List<SDVariable> allVars = this.variables();
        for (SDVariable variable : allVars) {
            INDArray arr = variable.getArr();
            String varName = variable.getVarName();
            log.trace("Exporting variable: [{}]", (Object)varName);
            int varIdx;
            int outputNum;
            String producingOp = this.variables.get(varName).getOutputOfOp();
            if (producingOp != null) {
                DifferentialFunction df = this.ops.get(producingOp).getOp();
                Integer knownIdx = idxForOps.get(df);
                if (knownIdx == null) {
                    varIdx = idCounter.incrementAndGet();
                    idxForOps.put(df, varIdx);
                } else {
                    varIdx = knownIdx;
                }
                Object[] outNames = df.outputVariablesNames();
                outputNum = ArrayUtils.indexOf(outNames, (Object)varName);
                Preconditions.checkState(outputNum >= 0, "Variable name \"%s\" not found in list of outputs: %s", (Object)varName, (Object)outNames);
            } else {
                varIdx = idCounter.incrementAndGet();
                outputNum = 0;
            }
            reverseMap.put(varName, varIdx);
            log.trace("Adding [{}] as [{}]", (Object)varName, (Object)varIdx);
            // FlatBuffers children must be created before the table that references them; keep order
            int name = bufferBuilder.createString((CharSequence)varName);
            int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum);
            byte varType = (byte)variable.getVariableType().ordinal();
            int array = 0;
            if (variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() == VariableType.VARIABLE) {
                array = arr == null ? 0 : arr.toFlatArray(bufferBuilder);
            }
            int shape = 0;
            if (variable.getVariableType() == VariableType.PLACEHOLDER) {
                long[] shp = variable.getShape();
                // A placeholder may have no known shape; createShapeVector cannot accept null
                if (shp != null) {
                    shape = FlatVariable.createShapeVector(bufferBuilder, shp);
                }
            }
            int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType);
            flatVariables.add(flatVariable);
        }
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction func = sameDiffOp.getOp();
            Integer fnId = idxForOps.get(func);
            flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
        }
        for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
            String scopeName = entry.getKey();
            if (scopeName.equalsIgnoreCase(GRAD_FN_KEY)) {
                continue;
            }
            SameDiff scope = entry.getValue();
            flatNodes.add(this.asFlatNode(scopeName, scope, bufferBuilder));
            ArrayList<SDVariable> currVarList = new ArrayList<SDVariable>(scope.variables());
            for (SDVariable node : scope.variables()) {
                INDArray arr = node.getArr();
                if (arr == null) {
                    continue;
                }
                int name = bufferBuilder.createString((CharSequence)node.getVarName());
                int array = arr.toFlatArray(bufferBuilder);
                // NOTE(review): idx counts from 0 independently of idCounter, so these ids may
                // coincide with ids given to this graph's own variables above — confirm intent.
                int id = IntPair.createIntPair(bufferBuilder, ++idx, 0);
                Pair<String, Integer> pair = SameDiff.parseVariable(node.getVarName());
                reverseMap.put(pair.getFirst(), idx);
                log.trace("Adding [{}] as [{}]", pair.getFirst(), (Object)idx);
                byte varType = (byte)node.getVariableType().ordinal();
                int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType);
                flatVariables.add(flatVariable);
            }
            for (SameDiffOp op : scope.ops.values()) {
                flatNodes.add(FlatBuffersMapper.asFlatNode(this, op.getOp(), bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null));
            }
        }
        int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets));
        int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
        int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));
        List<String> placeholderNames = new ArrayList<String>();
        for (SDVariable v : this.variables()) {
            if (v.isPlaceHolder()) {
                placeholderNames.add(v.getVarName());
            }
        }
        int[] placeholderOffsets = new int[placeholderNames.size()];
        for (int i = 0; i < placeholderOffsets.length; ++i) {
            placeholderOffsets[i] = bufferBuilder.createString((CharSequence)placeholderNames.get(i));
        }
        int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, placeholderOffsets);
        List<String> lossVars = this.getLossVariables();
        int[] lossVarOffsets = new int[lossVars == null ? 0 : lossVars.size()];
        for (int i = 0; i < lossVarOffsets.length; ++i) {
            lossVarOffsets[i] = bufferBuilder.createString((CharSequence)lossVars.get(i));
        }
        int lossVarOffset = FlatGraph.createLossVariablesVector(bufferBuilder, lossVarOffsets);
        int trainingConfigOffset = 0;
        int updaterStateOffset = 0;
        if (this.trainingConfig != null) {
            String json = this.trainingConfig.toJson();
            trainingConfigOffset = bufferBuilder.createString((CharSequence)json);
        }
        if (includeUpdaterState && this.updaterMap != null && !this.updaterMap.isEmpty()) {
            int[] updaterOffsets = new int[this.updaterMap.size()];
            int updaterNum = 0;
            for (Map.Entry<String, GradientUpdater> entry : this.updaterMap.entrySet()) {
                int paramNameOffset = bufferBuilder.createString((CharSequence)entry.getKey());
                int stateKeyOffset = 0;
                int stateValuesOffset = 0;
                Map<String, INDArray> state = entry.getValue().getState();
                if (state != null && !state.isEmpty()) {
                    int[] keysOffsets = new int[state.size()];
                    int[] valuesOffsets = new int[state.size()];
                    int i = 0;
                    for (Map.Entry<String, INDArray> e : state.entrySet()) {
                        keysOffsets[i] = bufferBuilder.createString((CharSequence)e.getKey());
                        valuesOffsets[i] = e.getValue().toFlatArray(bufferBuilder);
                        ++i;
                    }
                    stateKeyOffset = UpdaterState.createUpdaterStateKeysVector(bufferBuilder, keysOffsets);
                    stateValuesOffset = UpdaterState.createUpdaterStateValuesVector(bufferBuilder, valuesOffsets);
                }
                updaterOffsets[updaterNum++] = UpdaterState.createUpdaterState(bufferBuilder, paramNameOffset, stateKeyOffset, stateValuesOffset);
            }
            updaterStateOffset = FlatGraph.createUpdaterStateVector(bufferBuilder, updaterOffsets);
        }
        int fg = FlatGraph.createFlatGraph(bufferBuilder, graphId, variablesOffset, nodesOffset, outputsOffset, configuration.getFlatConfiguration(bufferBuilder), placeholdersOffset, lossVarOffset, trainingConfigOffset, updaterStateOffset);
        bufferBuilder.finish(fg);
        synchronized (this) {
            for (Map.Entry<String, Integer> entry : reverseMap.entrySet()) {
                Variable v = this.variables.get(entry.getKey());
                // Names exported from nested function scopes need not exist in this graph
                if (v != null) {
                    v.setVariableIndex(entry.getValue());
                }
            }
        }
        return bufferBuilder.dataBuffer();
    }

    /**
     * Serializes this graph with {@code asFlatBuffers(includeUpdaterState)} and returns a
     * {@code FlatGraph} view over the resulting buffer.
     *
     * @param includeUpdaterState whether to include updater state
     * @return FlatGraph view of the serialized graph
     */
    public FlatGraph asFlatGraph(boolean includeUpdaterState) {
        ByteBuffer serialized = this.asFlatBuffers(includeUpdaterState);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph with the given id and configuration and returns a {@code FlatGraph}
     * view over the resulting buffer.
     *
     * @param graphId             id stored in the FlatGraph
     * @param configuration       executor configuration to embed
     * @param includeUpdaterState whether to include updater state
     * @return FlatGraph view of the serialized graph
     */
    public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState) {
        ByteBuffer serialized = this.asFlatBuffers(graphId, configuration, includeUpdaterState);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph using a default executor configuration.
     * <p>
     * The configuration uses VARIABLE_SPACE output, SEQUENTIAL execution, profiling disabled and
     * timing gathering enabled.
     *
     * @param includeUpdaterState whether to include updater state
     * @return buffer holding the serialized graph
     */
    public ByteBuffer asFlatBuffers(boolean includeUpdaterState) {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return this.asFlatBuffers(defaults, includeUpdaterState);
    }

    /**
     * Saves this instance to {@code file} via {@code asFlatFile(file, saveUpdaterState)}.
     *
     * @param file             destination file; must not be null
     * @param saveUpdaterState whether to include updater state
     * @throws RuntimeException wrapping any IOException raised while writing
     */
    public void save(@NonNull File file, boolean saveUpdaterState) {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        try {
            this.asFlatFile(file, saveUpdaterState);
        } catch (IOException ioe) {
            throw new RuntimeException("Error saving SameDiff instance to file", ioe);
        }
    }

    /**
     * Saves this instance to {@code outputStream}.
     * <p>
     * The graph is first written to a temporary file via {@code save(File, boolean)}. The file's
     * bytes are then copied to the stream, which is wrapped in a {@link BufferedOutputStream} if it
     * is not one already. Once the copy starts, the stream is closed whether it succeeds or fails.
     * The temporary file is always deleted (best effort; the result of {@code delete()} is ignored).
     *
     * @param outputStream destination stream; must not be null
     * @param saveUpdater  whether to include updater state
     * @throws RuntimeException wrapping any IOException from the copy
     */
    public void save(@NonNull OutputStream outputStream, boolean saveUpdater) {
        if (outputStream == null) {
            throw new NullPointerException("outputStream is marked @NonNull but is null");
        }
        File tempFile = ND4JFileUtils.createTempFile((String)"SameDiffFile", (String)"temp");
        try {
            this.save(tempFile, saveUpdater);
            OutputStream target = (outputStream instanceof BufferedOutputStream)
                    ? outputStream
                    : new BufferedOutputStream(outputStream);
            try (OutputStream out = target;
                 InputStream in = new BufferedInputStream(new FileInputStream(tempFile))) {
                IOUtils.copy(in, out);
            } catch (IOException e) {
                throw new RuntimeException("Error writing to output stream (or reading from temp file)", e);
            }
        } finally {
            tempFile.delete();
        }
    }

    /**
     * Loads a SameDiff instance from {@code file} via {@code fromFlatFile(file, loadUpdaterState)}.
     *
     * @param file             serialized graph file; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws RuntimeException wrapping any IOException raised while reading
     */
    public static SameDiff load(@NonNull File file, boolean loadUpdaterState) {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        SameDiff loaded;
        try {
            loaded = SameDiff.fromFlatFile(file, loadUpdaterState);
        } catch (IOException ioe) {
            throw new RuntimeException("Error loading SameDiff instance from file", ioe);
        }
        return loaded;
    }

    /**
     * Loads a SameDiff instance from {@code is}.
     * <p>
     * The stream is copied to a temporary file, which is then parsed with
     * {@code fromFlatFile(File, boolean)}. The stream is read until end-of-stream but is not
     * closed here. The temporary file is deleted before returning (best effort).
     *
     * @param is               stream containing a serialized graph; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws RuntimeException wrapping any IOException from copying or parsing
     */
    public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState) {
        if (is == null) {
            throw new NullPointerException("is is marked @NonNull but is null");
        }
        File tempFile = ND4JFileUtils.createTempFile((String)"SameDiffFile", (String)"temp");
        try {
            // Close (and flush) the temp file before parsing it
            try (OutputStream os = new BufferedOutputStream(new FileOutputStream(tempFile))) {
                IOUtils.copy(is, os);
            }
            return SameDiff.fromFlatFile(tempFile, loadUpdaterState);
        } catch (IOException e) {
            throw new RuntimeException("Error loading SameDiff instance from file", e);
        } finally {
            tempFile.delete();
        }
    }

    /**
     * Writes this graph to {@code file} in FlatBuffers format, including updater state.
     *
     * @param file destination file; must not be null
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        final boolean includeUpdaterState = true;
        this.asFlatFile(file, includeUpdaterState);
    }

    /**
     * Writes this graph to {@code file} in FlatBuffers format.
     * <p>
     * The file receives the bytes from the buffer's position to the end of its backing array.
     *
     * @param file             destination file; must not be null
     * @param withUpdaterState whether to include updater state
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        ByteBuffer serialized = this.asFlatBuffers(withUpdaterState);
        int start = serialized.position();
        byte[] backing = serialized.array();
        int length = backing.length - start;
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            out.write(backing, start, length);
        }
    }

    /**
     * Writes this graph to {@code file} in FlatBuffers format using the given executor
     * configuration.
     * <p>
     * The file receives the bytes from the buffer's position to the end of its backing array.
     *
     * @param file                destination file; must not be null
     * @param configuration       executor configuration to embed; must not be null
     * @param includeUpdaterState whether to include updater state
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        ByteBuffer serialized = this.asFlatBuffers(configuration, includeUpdaterState);
        int start = serialized.position();
        byte[] backing = serialized.array();
        int length = backing.length - start;
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            out.write(backing, start, length);
        }
    }

    /**
     * Reads a SameDiff instance from a FlatBuffers file, restoring updater state.
     *
     * @param file serialized graph file; must not be null
     * @return the loaded instance
     * @throws IOException if reading fails
     */
    public static SameDiff fromFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        final boolean loadUpdaterState = true;
        return SameDiff.fromFlatFile(file, loadUpdaterState);
    }

    /**
     * Reads the whole of {@code file} into memory and deserializes it with
     * {@code fromFlatBuffers(ByteBuffer, boolean)}.
     *
     * @param file             serialized graph file; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws IOException if reading fails
     */
    public static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked @NonNull but is null");
        }
        byte[] content;
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            content = IOUtils.toByteArray(in);
        }
        return SameDiff.fromFlatBuffers(ByteBuffer.wrap(content), loadUpdaterState);
    }

    /**
     * Deserializes a SameDiff instance from FlatBuffers, restoring updater state.
     *
     * @param bbIn buffer holding a serialized FlatGraph
     * @return the reconstructed instance
     * @throws IOException as declared by {@code fromFlatBuffers(ByteBuffer, boolean)}
     */
    public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException {
        final boolean loadUpdaterState = true;
        return SameDiff.fromFlatBuffers(bbIn, loadUpdaterState);
    }

    /**
     * Reconstructs a SameDiff instance from a serialized FlatBuffers {@code FlatGraph}.
     * <p>
     * The following are restored:
     * <ul>
     * <li>variables (arrays are loaded for every variable type except ARRAY);</li>
     * <li>ops and their input/output wiring;</li>
     * <li>loss variables and the training configuration;</li>
     * <li>if requested, per-parameter updater state.</li>
     * </ul>
     *
     * @param bbIn             buffer holding a serialized FlatGraph
     * @param loadUpdaterState whether to restore updater state when present
     * @return the reconstructed instance
     * @throws IOException declared by the signature (no statement here throws it directly)
     * @throws ND4JIllegalStateException if an op input refers to a (node id, output index) pair
     *         that matches no serialized variable or previously processed op output
     * @throws IllegalStateException if updater state is requested and present but the graph
     *         carries no training configuration to instantiate the updaters from
     */
    public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException {
        FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn);
        int numOps = fg.nodesLength();
        int numVars = fg.variablesLength();
        List<FlatNode> ops = new ArrayList<FlatNode>(numOps);
        for (int i = 0; i < numOps; ++i) {
            ops.add(fg.nodes(i));
        }
        List<FlatVariable> vars = new ArrayList<FlatVariable>(numVars);
        for (int i = 0; i < numVars; ++i) {
            vars.add(fg.variables(i));
        }
        SameDiff sd = SameDiff.create();
        // (node id, output index) -> variable; used to resolve op inputs
        Map<Pair<Integer, Integer>, SDVariable> variablesByNodeAndOutNum = new HashMap<Pair<Integer, Integer>, SDVariable>();
        Map<String, List<SDVariable>> variablesByName = new HashMap<String, List<SDVariable>>();
        for (FlatVariable v : vars) {
            int shapeLength = v.shapeLength();
            long[] shape = new long[shapeLength];
            for (int i = 0; i < shapeLength; ++i) {
                shape[i] = v.shape(i);
            }
            String n = v.name();
            DataType dtype = FlatBuffersMapper.getDataTypeFromByte(v.dtype());
            VariableType vt = VariableType.values()[v.variabletype()];
            SDVariable sdVar = new SDVariable(n, vt, sd, shape, dtype, null);
            sd.variables.put(n, Variable.builder().name(n).variable(sdVar).build());
            sd.variableNameToShape.put(n, shape);
            FlatArray fa = v.ndarray();
            if (fa != null && vt != VariableType.ARRAY) {
                INDArray arr;
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    arr = Nd4j.createFromFlatArray(fa);
                }
                sd.setArrayForVariable(n, arr);
            }
            IntPair id = v.id();
            variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(id.first(), id.second()), sdVar);
            List<SDVariable> sameName = variablesByName.get(n);
            if (sameName == null) {
                sameName = new ArrayList<SDVariable>();
                variablesByName.put(n, sameName);
            }
            sameName.add(sdVar);
        }
        for (FlatNode fn : ops) {
            DifferentialFunction df = FlatBuffersMapper.fromFlatNode(fn);
            String name = fn.name();
            df.setSameDiff(sd);
            df.setOwnName(name);
            if (sd.ops.containsKey(name)) {
                sd.ops.get(name).setOp(df);
            } else {
                sd.ops.put(name, SameDiffOp.builder().name(name).op(df).build());
            }
            int opId = fn.id();
            int numInputs = fn.inputPairedLength();
            String[] inputNames = new String[numInputs];
            for (int i = 0; i < numInputs; ++i) {
                IntPair in = fn.inputPaired(i);
                int nodeId = in.first();
                int nodeOutNum = in.second();
                SDVariable varIn = variablesByNodeAndOutNum.get(new Pair<Integer, Integer>(nodeId, nodeOutNum));
                if (varIn == null) {
                    throw new ND4JIllegalStateException("Could not resolve input " + i + " of op \"" + name + "\": no variable or op output found for node id " + nodeId + ", output " + nodeOutNum);
                }
                inputNames[i] = varIn.getVarName();
            }
            sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames));
            for (String inName : inputNames) {
                Variable v = sd.getVariables().get(inName);
                if (v.getInputsForOp() == null) {
                    v.setInputsForOp(new ArrayList<String>());
                }
                if (!v.getInputsForOp().contains(df.getOwnName())) {
                    v.getInputsForOp().add(df.getOwnName());
                }
            }
            List<SDVariable> varsForOp = variablesByName.get(name);
            int numOutputs = df.getNumOutputs();
            if (numOutputs <= 0) {
                numOutputs = fn.outputLength();
            }
            String[] varNames;
            if (varsForOp != null && varsForOp.size() == numOutputs) {
                // Outputs were serialized as variables named after the op
                varNames = new String[varsForOp.size()];
                for (int i = 0; i < varNames.length; ++i) {
                    varNames[i] = varsForOp.get(i).getVarName();
                    sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName());
                }
            } else {
                // Fall back to the node's recorded output names, creating missing variables
                int outputNamesLength = fn.outputNamesLength();
                varNames = new String[outputNamesLength];
                for (int i = 0; i < outputNamesLength; ++i) {
                    String outName = fn.outputNames(i);
                    varNames[i] = outName;
                    if (!sd.variables.containsKey(outName)) {
                        SDVariable created = new SDVariable(outName, VariableType.VARIABLE, sd, null, null, null);
                        sd.variables.put(outName, Variable.builder().name(outName).variable(created).build());
                        variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(opId, i), created);
                    }
                    sd.getVariables().get(outName).setOutputOfOp(df.getOwnName());
                }
            }
            sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames));
            for (int i = 0; i < varNames.length; ++i) {
                Pair<Integer, Integer> key = new Pair<Integer, Integer>(opId, i);
                if (!variablesByNodeAndOutNum.containsKey(key)) {
                    variablesByNodeAndOutNum.put(key, sd.getVariable(varNames[i]));
                }
            }
        }
        for (int i = 0; i < fg.lossVariablesLength(); ++i) {
            sd.addLossVariable(fg.lossVariables(i));
        }
        String tc = fg.trainingConfig();
        if (tc != null) {
            sd.trainingConfig = TrainingConfig.fromJson(tc);
        }
        if (loadUpdaterState && fg.updaterStateLength() > 0) {
            if (sd.trainingConfig == null) {
                throw new IllegalStateException("Serialized graph contains updater state but no training configuration: cannot instantiate updaters");
            }
            sd.updaterMap = new HashMap<String, GradientUpdater>();
            int n = fg.updaterStateLength();
            for (int i = 0; i < n; ++i) {
                UpdaterState us = fg.updaterState(i);
                String paramName = us.paramName();
                int nKeys = us.updaterStateKeysLength();
                HashMap<String, INDArray> m = new HashMap<String, INDArray>();
                for (int j = 0; j < nKeys; ++j) {
                    String key = us.updaterStateKeys(j);
                    FlatArray fa = us.updaterStateValues(j);
                    m.put(key, Nd4j.createFromFlatArray(fa));
                }
                GradientUpdater gu = sd.trainingConfig.getUpdater().instantiate(m, false);
                sd.updaterMap.put(paramName, gu);
            }
            sd.initializedTraining = true;
        }
        return sd;
    }

    /**
     * Builds a human-readable dump of this graph's FlatBuffers serialization, without updater
     * state.
     * <p>
     * Variables are listed first: id, name, shape info and up to 50 values each. Nodes follow: id,
     * name, op type, op number (or the custom op's name) and paired inputs.
     *
     * @return the formatted description
     */
    public String asFlatPrint() {
        StringBuilder sb = new StringBuilder();
        ByteBuffer fb = this.asFlatBuffers(false);
        FlatGraph graph = FlatGraph.getRootAsFlatGraph(fb);
        sb.append("\nExternal variables:\n\n");
        for (int e = 0; e < graph.variablesLength(); ++e) {
            FlatVariable var = graph.variables(e);
            INDArray ndarray = null;
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                FlatArray fa = var.ndarray();
                if (fa != null) {
                    ndarray = Nd4j.createFromFlatArray(fa);
                }
            }
            sb.append(var.id().first()).append(":<").append(var.name()).append("> ");
            if (ndarray == null) {
                sb.append("<no array>").append("; Values: ").append("<no array>").append(";\n");
                continue;
            }
            sb.append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: ");
            if (ndarray.data() == null) {
                sb.append("<empty array>");
            } else if (ndarray.dataType() == DataType.UTF8) {
                sb.append("<string array>");
            } else if (ndarray.length() < 50L) {
                sb.append(Arrays.toString(ndarray.data().asFloat()).replaceAll(" ", ""));
            } else {
                sb.append("[");
                for (int i = 0; i < 50; ++i) {
                    if (i > 0) {
                        sb.append(",");
                    }
                    sb.append(ndarray.data().getFloat((long)i));
                }
                sb.append("]");
            }
            sb.append(";\n");
        }
        // Index custom op names by hash once rather than scanning every descriptor per node.
        // Later entries overwrite earlier ones, so the last matching descriptor wins, as before.
        Map<String, CustomOpDescriptor> map = Nd4j.getExecutioner().getCustomOperations();
        Map<Long, String> customOpNamesByHash = new HashMap<Long, String>();
        if (map != null) {
            for (Map.Entry<String, CustomOpDescriptor> entry : map.entrySet()) {
                customOpNamesByHash.put((long)entry.getValue().getHash(), entry.getKey());
            }
        }
        sb.append("\nOps sequence:\n\n");
        for (int e = 0; e < graph.nodesLength(); ++e) {
            FlatNode node = graph.nodes(e);
            // trace level, consistent with the per-item logging in the export path
            log.trace("{}:<{}>", (Object)node.id(), (Object)node.name());
            Op.Type opType = FlatBuffersMapper.getTypeFromByte(node.opType());
            sb.append(node.id()).append(":<").append(node.name()).append("> ").append((Object)opType);
            if (opType != Op.Type.CUSTOM) {
                sb.append(": ").append(node.opNum());
            } else {
                String opName = customOpNamesByHash.get((long)node.opNum());
                sb.append(": ").append(opName == null ? "unknown" : opName);
            }
            sb.append("; Inputs: {");
            int numInputs = node.inputPairedLength();
            for (int i = 0; i < numInputs; ++i) {
                IntPair pair = node.inputPaired(i);
                sb.append("[").append(pair.first()).append(":").append(pair.second()).append("]");
                if (i < numInputs - 1) {
                    sb.append(", ");
                }
            }
            sb.append("};");
            sb.append(" OpNum: {").append(node.opNum()).append("};");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Builds a human-readable, multi-section text report describing this SameDiff instance.
     *
     * <p>Sections, in order:
     * <ol>
     *   <li>"Summary": number of variables (and how many currently have an array), number of
     *       functions, number of SameDiff function instances, and the loss variables.</li>
     *   <li>"Variables": one row per variable with its name, array shape (the placeholder shape
     *       when there is no array but a placeholder shape is known, otherwise "-"), variable
     *       type, data type, the function that outputs it ("&lt;none&gt;" if none) and the
     *       variable's inputs-for-op list.</li>
     *   <li>"Functions": one row per op with its index, name (own name, falling back to the op
     *       name), op class simple name, input names and output names.</li>
     *   <li>"SameDiff Defined Functions": present only when function instances exist; one row
     *       per instance with its variable, op and defined-function counts.</li>
     * </ol>
     * Some columns are sized to fit their longest entry (plus two characters of padding); the
     * remaining columns use fixed widths.
     *
     * @return the formatted summary text
     */
    public String summary() {
        Map<String, SDVariable> varMap = this.variableMap();
        DifferentialFunction[] functions = this.ops();

        int withArrays = 0;
        for (String name : varMap.keySet()) {
            if (this.getArrForVarName(name) != null) {
                withArrays++;
            }
        }

        StringBuilder sb = new StringBuilder();

        // ----- overview -----
        String overviewFmt = "%-25s%-20s";
        sb.append("--- Summary ---\n");
        sb.append(String.format(overviewFmt, "Variables:", varMap.size()))
                .append(" (").append(withArrays).append(" with arrays)").append("\n");
        sb.append(String.format(overviewFmt, "Functions:", functions.length)).append("\n");
        sb.append(String.format(overviewFmt, "SameDiff Function Defs:", this.sameDiffFunctionInstances.size())).append("\n");
        sb.append("Loss function variables: ").append(this.getLossVariables()).append("\n\n");

        // ----- variables -----
        sb.append("--- Variables ---\n");
        Map<String, String> producedBy = new HashMap<String, String>();
        int nameWidth = 8;
        int producerWidth = 22;
        for (String name : varMap.keySet()) {
            // Find the first op (in ops-map iteration order) that lists this variable as an output.
            String producerOpName = null;
            for (SameDiffOp op : this.ops.values()) {
                List<String> opOutputs = op.getOutputsOfOp();
                if (opOutputs != null && opOutputs.contains(name)) {
                    producerOpName = op.getName();
                    break;
                }
            }
            String producer;
            if (producerOpName == null) {
                producer = "<none>";
            } else {
                DifferentialFunction fn = this.getOpById(producerOpName);
                producer = fn.getOwnName() + "(" + fn.opName() + ")";
            }
            producedBy.put(name, producer);
            producerWidth = Math.max(producerWidth, producer.length());
            nameWidth = Math.max(nameWidth, name.length());
        }
        nameWidth += 2;
        producerWidth += 2;
        String varRowFmt = "%-" + nameWidth + "s%-20s%-20s%-20s%-" + producerWidth + "s%-20s";
        sb.append(String.format(varRowFmt, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-", "- Output Of Function -", "- Inputs To Functions -")).append("\n");
        for (String name : varMap.keySet()) {
            INDArray arr = this.getArrForVarName(name);
            String shapeText;
            if (arr != null) {
                shapeText = Arrays.toString(arr.shape());
            } else {
                SDVariable sdv = varMap.get(name);
                long[] placeholderShape = sdv.isPlaceHolder() ? sdv.placeholderShape() : null;
                shapeText = placeholderShape != null ? Arrays.toString(placeholderShape) : "-";
            }
            SDVariable current = this.getVariable(name);
            String varType = current.getVariableType().toString();
            String dtype = current.dataType().toString();
            List<String> consumers = this.variables.get(name).getInputsForOp();
            String consumersText = consumers == null ? "" : consumers.toString();
            sb.append(String.format(varRowFmt, name, shapeText, varType, dtype, producedBy.get(name), consumersText)).append("\n");
        }

        // ----- functions -----
        sb.append("\n\n--- Functions ---\n");
        List<String> inputTexts = new ArrayList<String>();
        List<String> outputTexts = new ArrayList<String>();
        int inWidth = 10;
        int outWidth = 11;
        int fnNameWidth = 17;
        int classWidth = 10;
        for (DifferentialFunction df : functions) {
            String inText = Arrays.toString((Object[]) df.argNames());
            String outText = Arrays.toString((Object[]) df.outputVariablesNames());
            inWidth = Math.max(inWidth, inText.length());
            outWidth = Math.max(outWidth, outText.length());
            inputTexts.add(inText);
            outputTexts.add(outText);
            String fnName = df.getOwnName() == null ? df.opName() : df.getOwnName();
            fnNameWidth = Math.max(fnNameWidth, fnName.length());
            classWidth = Math.max(classWidth, df.getClass().getSimpleName().length());
        }
        String fnRowFmt = "%-5s%-" + (fnNameWidth + 2) + "s%-" + (classWidth + 2) + "s%-" + (inWidth + 2) + "s%-" + (outWidth + 2) + "s";
        sb.append(String.format(fnRowFmt, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n");
        for (int idx = 0; idx < functions.length; idx++) {
            DifferentialFunction df = functions[idx];
            String fnName = df.getOwnName() == null ? df.opName() : df.getOwnName();
            sb.append(String.format(fnRowFmt, String.valueOf(idx), fnName, df.getClass().getSimpleName(), inputTexts.get(idx), outputTexts.get(idx))).append("\n");
        }

        // ----- nested SameDiff function instances -----
        if (!this.sameDiffFunctionInstances.isEmpty()) {
            sb.append("\n\n--- SameDiff Defined Functions ---\n");
            String defRowFmt = "%-20s%-15s%-15s%-15s";
            sb.append(String.format(defRowFmt, "- Name -", "- Variables -", "- Functions -", "- Fn Defs -")).append("\n");
            for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sub = entry.getValue();
                int varCount = sub.variableMap().size();
                DifferentialFunction[] subOps = sub.ops();
                int fnCount = subOps == null ? 0 : subOps.length;
                int defCount = sub.definedFunctionNames().size();
                sb.append(String.format(defRowFmt, entry.getKey(), String.valueOf(varCount), String.valueOf(fnCount), String.valueOf(defCount))).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Infers the data type of every variable in the graph without executing any ops.
     * Equivalent to {@code calculateOutputDataTypes(false)}.
     *
     * @return map from variable name to inferred data type
     */
    public Map<String, DataType> calculateOutputDataTypes() {
        final boolean dynamicUpdate = false;
        return this.calculateOutputDataTypes(dynamicUpdate);
    }

    /**
     * Infers the data type of every variable in the graph using a {@code DataTypesSession}.
     * Placeholder data types are taken from the placeholder variables themselves and seeded
     * into the session as its placeholder values; the session is run in inference mode.
     *
     * @param dynamicUpdate passed through to the {@code DataTypesSession} constructor
     * @return map from variable name to inferred data type, as returned by the session
     * @throws RuntimeException (via {@code Preconditions.checkNotNull}) if a placeholder has a
     *         null data type
     */
    public Map<String, DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
        List<String> requested = new ArrayList<String>(this.variables.keySet());
        DataTypesSession dtSession = new DataTypesSession(this, dynamicUpdate);

        Map<String, DataType> placeholderTypes = new HashMap<String, DataType>();
        for (Variable entry : this.variables.values()) {
            SDVariable sdv = entry.getVariable();
            if (sdv.isPlaceHolder()) {
                DataType phType = sdv.dataType();
                Preconditions.checkNotNull((Object) phType, (String) "Placeholder variable %s has null datatype", (Object) entry.getName());
                placeholderTypes.put(entry.getName(), phType);
            }
        }

        return dtSession.output(requested, placeholderTypes, null, Collections.emptyList(), Collections.emptyList(), At.defaultAt(Operation.INFERENCE));
    }

    /**
     * Reserves and returns a block name unique among the names already in {@code blockNames}.
     * The base name itself is used if free; otherwise the first free {@code baseName_N}
     * (N = 1, 2, ...) is used. The chosen name is added to {@code blockNames}.
     *
     * @param baseName desired name; may be null
     * @return the reserved unique name, or null if {@code baseName} is null
     */
    public String newBlockName(String baseName) {
        if (baseName == null) {
            return null;
        }
        String candidate = baseName;
        int suffix = 1;
        while (this.blockNames.contains(candidate)) {
            candidate = baseName + "_" + suffix;
            suffix++;
        }
        this.blockNames.add(candidate);
        return candidate;
    }

    /**
     * Imports a frozen TensorFlow graph from a file via the {@code TFGraphMapper} singleton.
     *
     * @param graphFile file containing the frozen graph
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(File graphFile) {
        TFGraphMapper mapper = TFGraphMapper.getInstance();
        return mapper.importGraph(graphFile);
    }

    /**
     * Imports a frozen TensorFlow graph from an already-parsed {@code GraphDef} via the
     * {@code TFGraphMapper} singleton.
     *
     * @param graphDef the graph definition
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(GraphDef graphDef) {
        TFGraphMapper mapper = TFGraphMapper.getInstance();
        return mapper.importGraph(graphDef);
    }

    /**
     * Imports a frozen TensorFlow graph from a stream via the {@code TFGraphMapper} singleton.
     *
     * @param graph stream containing the frozen graph
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(InputStream graph) {
        TFGraphMapper mapper = TFGraphMapper.getInstance();
        return mapper.importGraph(graph);
    }

    /** A name ending in {@code _<digits>}, matched against the ENTIRE name (see getOpName). */
    private static final Pattern OP_NAME_NUMERIC_SUFFIX = Pattern.compile("(.*)_(\\d+)");

    /**
     * Returns an op name derived from {@code base} (after applying the current name scope).
     *
     * <p>If {@code force} is true, the scoped name is returned unchanged, or an exception is
     * thrown if an op with that name already exists (only ops are checked in this mode).
     *
     * <p>Otherwise a unique name is generated. If the scoped name ends in {@code _<digits>}, that
     * suffix is split off and its value becomes the starting counter. The candidates tried are
     * {@code base}, {@code base_start}, {@code base_(start+1)}, ... until one is used neither by
     * an op nor by a variable. A variable blocks a candidate if its name equals the candidate
     * or starts with {@code candidate + ":"}.
     *
     * <p>NOTE(review): the un-suffixed base is tried first, so asking for {@code "op_2"} can
     * yield {@code "op"} when that name is free; kept as-is for compatibility.
     *
     * @param base  requested name (before scoping)
     * @param force if true, use the scoped name exactly instead of generating a unique one
     * @return the op name
     * @throws IllegalArgumentException if {@code force} is true and an op with the scoped name
     *         already exists
     */
    public String getOpName(String base, boolean force) {
        base = this.nameWithScope(base);
        if (force) {
            if (this.ops.containsKey(base)) {
                throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
            }
            return base;
        }

        int start = 1;
        if (base.contains("_")) {
            Matcher num = OP_NAME_NUMERIC_SUFFIX.matcher(base);
            // matches() rather than find(): find() is unanchored, so for "conv_1x1" it would match
            // the prefix "conv_1" and silently drop the "x1" remainder of the name.
            if (num.matches()) {
                try {
                    start = Integer.parseInt(num.group(2));
                    base = num.group(1);
                } catch (NumberFormatException ignored) {
                    // Digit suffix too large for an int (e.g. a timestamp): treat the whole name
                    // as having no numeric suffix instead of failing.
                }
            }
        }

        String name = base;
        int next = start;
        while (true) {
            String outputPrefix = name + ":";
            boolean takenByVariable = false;
            for (String varName : this.variables.keySet()) {
                if (varName.equals(name) || varName.startsWith(outputPrefix)) {
                    takenByVariable = true;
                    break;
                }
            }
            if (!takenByVariable && !this.ops.containsKey(name)) {
                return name;
            }
            name = base + "_" + next;
            next++;
        }
    }

    /**
     * Returns a unique op name derived from {@code base}; equivalent to
     * {@code getOpName(base, false)}.
     *
     * @param base requested name (before scoping)
     * @return a scoped op name not used by any existing op or variable
     */
    public String getOpName(String base) {
        final boolean force = false;
        return this.getOpName(base, force);
    }

    /** A name ending in {@code :<digits>}, matched against the ENTIRE name (see generateNewVarName). */
    private static final Pattern VAR_NAME_OUTPUT_INDEX = Pattern.compile("(.*):(\\d+)");

    /**
     * Generates a name for a new variable, typically an op output.
     *
     * <p>Steps:
     * <ol>
     *   <li>The current name scope is applied to {@code base}.</li>
     *   <li>If {@code argIndex > 0} and the name ends in {@code :<digits>}, that suffix is
     *       removed and {@code argIndex} becomes {@code digits + 1}.</li>
     *   <li>If {@code existingOp} is false, the base is made unique via {@code getOpName}.</li>
     *   <li>If {@code argIndex > 0}, {@code ":" + argIndex} is appended.</li>
     * </ol>
     *
     * @param base       requested base name (before scoping)
     * @param argIndex   output index; 0 means no {@code :N} suffix is appended
     * @param existingOp true if {@code base} already names an existing op (so it is not uniquified)
     * @return the generated variable name
     * @throws IllegalArgumentException if a variable with the generated name already exists
     */
    public String generateNewVarName(String base, int argIndex, boolean existingOp) {
        base = this.nameWithScope(base);
        if (argIndex > 0 && base.contains(":")) {
            Matcher num = VAR_NAME_OUTPUT_INDEX.matcher(base);
            // matches() rather than find(): find() is unanchored, so for "a:1b" it would match
            // "a:1", strip the name to "a" and lose the "b".
            if (num.matches()) {
                try {
                    argIndex = Integer.parseInt(num.group(2)) + 1;
                    base = num.group(1);
                } catch (NumberFormatException ignored) {
                    // Index too large for an int: leave the name and argIndex untouched.
                }
            }
        }
        if (!existingOp) {
            base = this.getOpName(base);
        }
        if (argIndex > 0) {
            base = base + ":" + argIndex;
        }
        if (this.variables.containsKey(base)) {
            throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists");
        }
        return base;
    }

    /**
     * Generates a new variable name for an output of an already-existing op; equivalent to
     * {@code generateNewVarName(base, argIndex, true)}.
     *
     * @param base     requested base name (before scoping)
     * @param argIndex output index; 0 means no {@code :N} suffix is appended
     * @return the generated variable name
     */
    @Override
    public String generateNewVarName(String base, int argIndex) {
        final boolean opAlreadyExists = true;
        return this.generateNewVarName(base, argIndex, opAlreadyExists);
    }

    /**
     * Returns {@code base} if no variable has that name, otherwise the first
     * {@code base_N} (N = 1, 2, ...) not used as a variable name. Only variable names are
     * checked, and nothing is reserved by this call.
     *
     * @param base desired variable name
     * @return a variable name not currently present in {@code variables}
     */
    public String generateDistinctCustomVariableName(String base) {
        String candidate = base;
        for (int n = 1; this.variables.containsKey(candidate); n++) {
            candidate = base + "_" + n;
        }
        return candidate;
    }

    /**
     * Creates a new, empty {@link SameDiffBuilder}.
     *
     * @return a fresh builder
     */
    public static SameDiffBuilder builder() {
        SameDiffBuilder fresh = new SameDiffBuilder();
        return fresh;
    }

    /**
     * All-arguments constructor used by {@link SameDiffBuilder#build()}; every argument is stored
     * directly into the field of the same name.
     *
     * <p>NOTE(review): because each argument is assigned unconditionally, a null argument (e.g.
     * from an unset builder property) replaces whatever value a field initializer may have
     * produced — confirm callers always supply the collections they need.
     */
    public SameDiff(TrainingConfig trainingConfig, boolean initializedTraining, Map<String, GradientUpdater> updaterMap, Map<String, String> baseNameForFunctionInstanceId, DifferentialFunctionFactory functionFactory, Map<String, long[]> variableNameToShape, Map<String, SDVariable> forwardVarForGrad, int variableId, Map<String, List<String>> propertiesToResolve, Map<String, Map<String, Object>> propertiesForFunction, Map<String, long[]> placeHolderOriginalShapes, Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap, Map<String, SameDiff> sameDiffFunctionInstances, Set<String> placeHolderFunctions, Table<String, String, String> fieldVariableResolutionMapping, AtomicBoolean wasRegistered, boolean debugMode, Map<int[], Op> opsForResult, boolean resolvedVariables, Stack<ArgumentInterceptor> argumentInterceptors, Set<ArgumentInterceptor> pausedArgumentInterceptors, Set<String> blockNames, boolean logExecution, SameDiff parent, SameDiff child) {
        // graph relationships
        this.parent = parent;
        this.child = child;
        this.sameDiffFunctionDefinitionMap = sameDiffFunctionDefinitionMap;
        this.sameDiffFunctionInstances = sameDiffFunctionInstances;

        // training state
        this.trainingConfig = trainingConfig;
        this.initializedTraining = initializedTraining;
        this.updaterMap = updaterMap;

        // naming / function bookkeeping
        this.baseNameForFunctionInstanceId = baseNameForFunctionInstanceId;
        this.functionFactory = functionFactory;
        this.blockNames = blockNames;
        this.variableId = variableId;

        // variable / property resolution
        this.variableNameToShape = variableNameToShape;
        this.forwardVarForGrad = forwardVarForGrad;
        this.propertiesToResolve = propertiesToResolve;
        this.propertiesForFunction = propertiesForFunction;
        this.placeHolderOriginalShapes = placeHolderOriginalShapes;
        this.placeHolderFunctions = placeHolderFunctions;
        this.fieldVariableResolutionMapping = fieldVariableResolutionMapping;
        this.resolvedVariables = resolvedVariables;
        this.opsForResult = opsForResult;
        this.wasRegistered = wasRegistered;

        // argument interception
        this.argumentInterceptors = argumentInterceptors;
        this.pausedArgumentInterceptors = pausedArgumentInterceptors;

        // diagnostics
        this.debugMode = debugMode;
        this.logExecution = logExecution;
    }

    /** @return the live (not copied) map from variable name to {@code Variable} metadata. */
    public Map<String, Variable> getVariables() {
        return variables;
    }

    /** @return the live (not copied) map from op name to {@code SameDiffOp}. */
    public Map<String, SameDiffOp> getOps() {
        return ops;
    }

    /** @return the live (not copied) map of inference sessions, keyed by a {@code Long} id. */
    public Map<Long, InferenceSession> getSessions() {
        return sessions;
    }

    /** @return the training configuration, possibly null. */
    public TrainingConfig getTrainingConfig() {
        return trainingConfig;
    }

    /** @return the current value of the {@code initializedTraining} flag. */
    public boolean isInitializedTraining() {
        return initializedTraining;
    }

    /** @return the live (not copied) map of gradient updaters. */
    public Map<String, GradientUpdater> getUpdaterMap() {
        return updaterMap;
    }

    /** @return the current value of the {@code debugMode} flag. */
    public boolean isDebugMode() {
        return debugMode;
    }

    /** @return the live (not copied) stack of argument interceptors. */
    public Stack<ArgumentInterceptor> getArgumentInterceptors() {
        return argumentInterceptors;
    }

    /** @return the live (not copied) set of paused argument interceptors. */
    public Set<ArgumentInterceptor> getPausedArgumentInterceptors() {
        return pausedArgumentInterceptors;
    }

    /** @return the current value of the {@code logExecution} flag. */
    public boolean isLogExecution() {
        return logExecution;
    }

    /** @param logExecution new value for the {@code logExecution} flag */
    public void setLogExecution(boolean logExecution) {
        this.logExecution = logExecution;
    }

    /** @return the parent SameDiff instance, possibly null. */
    public SameDiff getParent() {
        return parent;
    }

    /** @return the child SameDiff instance, possibly null. */
    public SameDiff getChild() {
        return child;
    }

    /*
     * Class initialization: creates the logger and indexes, by method name, every method
     * declared directly on SameDiff whose return type is exactly SDVariable. Overloads share a
     * name, so only one of them survives per key; which one depends on the (unspecified)
     * order returned by getDeclaredMethods().
     */
    static {
        log = LoggerFactory.getLogger(SameDiff.class);
        opMethods = new HashMap<String, Method>();
        for (Method candidate : SameDiff.class.getDeclaredMethods()) {
            if (candidate.getReturnType().equals(SDVariable.class)) {
                opMethods.put(candidate.getName(), candidate);
            }
        }
    }

    /**
     * Fluent builder for {@link SameDiff}. Each setter stores its argument and returns this
     * builder; {@link #build()} passes every stored value to the all-arguments constructor.
     * Unset reference properties are passed as null, and unset primitives as false/0.
     */
    public static class SameDiffBuilder {
        private TrainingConfig trainingConfig;
        private boolean initializedTraining;
        private Map<String, GradientUpdater> updaterMap;
        private Map<String, String> baseNameForFunctionInstanceId;
        private DifferentialFunctionFactory functionFactory;
        private Map<String, long[]> variableNameToShape;
        private Map<String, SDVariable> forwardVarForGrad;
        private int variableId;
        private Map<String, List<String>> propertiesToResolve;
        private Map<String, Map<String, Object>> propertiesForFunction;
        private Map<String, long[]> placeHolderOriginalShapes;
        private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
        private Map<String, SameDiff> sameDiffFunctionInstances;
        private Set<String> placeHolderFunctions;
        private Table<String, String, String> fieldVariableResolutionMapping;
        private AtomicBoolean wasRegistered;
        private boolean debugMode;
        private Map<int[], Op> opsForResult;
        private boolean resolvedVariables;
        private Stack<ArgumentInterceptor> argumentInterceptors;
        private Set<ArgumentInterceptor> pausedArgumentInterceptors;
        private Set<String> blockNames;
        private boolean logExecution;
        private SameDiff parent;
        private SameDiff child;

        /** Package-private; obtain instances via {@code SameDiff.builder()}. */
        SameDiffBuilder() {
        }

        /** Sets the training configuration. */
        public SameDiffBuilder trainingConfig(TrainingConfig trainingConfig) {
            this.trainingConfig = trainingConfig;
            return this;
        }

        /** Sets the {@code initializedTraining} flag. */
        public SameDiffBuilder initializedTraining(boolean initializedTraining) {
            this.initializedTraining = initializedTraining;
            return this;
        }

        /** Sets the gradient updater map. */
        public SameDiffBuilder updaterMap(Map<String, GradientUpdater> updaterMap) {
            this.updaterMap = updaterMap;
            return this;
        }

        /** Sets the base-name-for-function-instance-id map. */
        public SameDiffBuilder baseNameForFunctionInstanceId(Map<String, String> baseNameForFunctionInstanceId) {
            this.baseNameForFunctionInstanceId = baseNameForFunctionInstanceId;
            return this;
        }

        /** Sets the function factory. */
        public SameDiffBuilder functionFactory(DifferentialFunctionFactory functionFactory) {
            this.functionFactory = functionFactory;
            return this;
        }

        /** Sets the variable-name-to-shape map. */
        @Deprecated
        public SameDiffBuilder variableNameToShape(Map<String, long[]> variableNameToShape) {
            this.variableNameToShape = variableNameToShape;
            return this;
        }

        /** Sets the forward-variable-for-gradient map. */
        @Deprecated
        public SameDiffBuilder forwardVarForGrad(Map<String, SDVariable> forwardVarForGrad) {
            this.forwardVarForGrad = forwardVarForGrad;
            return this;
        }

        /** Sets the variable id counter value. */
        public SameDiffBuilder variableId(int variableId) {
            this.variableId = variableId;
            return this;
        }

        /** Sets the properties-to-resolve map. */
        public SameDiffBuilder propertiesToResolve(Map<String, List<String>> propertiesToResolve) {
            this.propertiesToResolve = propertiesToResolve;
            return this;
        }

        /** Sets the properties-for-function map. */
        public SameDiffBuilder propertiesForFunction(Map<String, Map<String, Object>> propertiesForFunction) {
            this.propertiesForFunction = propertiesForFunction;
            return this;
        }

        /** Sets the placeholder-original-shapes map. */
        @Deprecated
        public SameDiffBuilder placeHolderOriginalShapes(Map<String, long[]> placeHolderOriginalShapes) {
            this.placeHolderOriginalShapes = placeHolderOriginalShapes;
            return this;
        }

        /** Sets the SameDiff function definition map. */
        public SameDiffBuilder sameDiffFunctionDefinitionMap(Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap) {
            this.sameDiffFunctionDefinitionMap = sameDiffFunctionDefinitionMap;
            return this;
        }

        /** Sets the SameDiff function instance map. */
        public SameDiffBuilder sameDiffFunctionInstances(Map<String, SameDiff> sameDiffFunctionInstances) {
            this.sameDiffFunctionInstances = sameDiffFunctionInstances;
            return this;
        }

        /** Sets the placeholder function name set. */
        public SameDiffBuilder placeHolderFunctions(Set<String> placeHolderFunctions) {
            this.placeHolderFunctions = placeHolderFunctions;
            return this;
        }

        /** Sets the field-variable resolution mapping table. */
        public SameDiffBuilder fieldVariableResolutionMapping(Table<String, String, String> fieldVariableResolutionMapping) {
            this.fieldVariableResolutionMapping = fieldVariableResolutionMapping;
            return this;
        }

        /** Sets the {@code wasRegistered} holder. */
        public SameDiffBuilder wasRegistered(AtomicBoolean wasRegistered) {
            this.wasRegistered = wasRegistered;
            return this;
        }

        /** Sets the {@code debugMode} flag. */
        public SameDiffBuilder debugMode(boolean debugMode) {
            this.debugMode = debugMode;
            return this;
        }

        /** Sets the ops-for-result map. */
        public SameDiffBuilder opsForResult(Map<int[], Op> opsForResult) {
            this.opsForResult = opsForResult;
            return this;
        }

        /** Sets the {@code resolvedVariables} flag. */
        public SameDiffBuilder resolvedVariables(boolean resolvedVariables) {
            this.resolvedVariables = resolvedVariables;
            return this;
        }

        /** Sets the argument interceptor stack. */
        public SameDiffBuilder argumentInterceptors(Stack<ArgumentInterceptor> argumentInterceptors) {
            this.argumentInterceptors = argumentInterceptors;
            return this;
        }

        /** Sets the paused argument interceptor set. */
        public SameDiffBuilder pausedArgumentInterceptors(Set<ArgumentInterceptor> pausedArgumentInterceptors) {
            this.pausedArgumentInterceptors = pausedArgumentInterceptors;
            return this;
        }

        /** Sets the block name set. */
        public SameDiffBuilder blockNames(Set<String> blockNames) {
            this.blockNames = blockNames;
            return this;
        }

        /** Sets the {@code logExecution} flag. */
        public SameDiffBuilder logExecution(boolean logExecution) {
            this.logExecution = logExecution;
            return this;
        }

        /** Sets the parent SameDiff instance. */
        public SameDiffBuilder parent(SameDiff parent) {
            this.parent = parent;
            return this;
        }

        /** Sets the child SameDiff instance. */
        public SameDiffBuilder child(SameDiff child) {
            this.child = child;
            return this;
        }

        /**
         * Creates a {@link SameDiff} from the values currently held by this builder.
         *
         * @return a new SameDiff instance
         */
        public SameDiff build() {
            return new SameDiff(trainingConfig, initializedTraining, updaterMap, baseNameForFunctionInstanceId,
                    functionFactory, variableNameToShape, forwardVarForGrad, variableId, propertiesToResolve,
                    propertiesForFunction, placeHolderOriginalShapes, sameDiffFunctionDefinitionMap,
                    sameDiffFunctionInstances, placeHolderFunctions, fieldVariableResolutionMapping, wasRegistered,
                    debugMode, opsForResult, resolvedVariables, argumentInterceptors, pausedArgumentInterceptors,
                    blockNames, logExecution, parent, child);
        }

        /** @return a debug string listing every builder property and its current value. */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder("SameDiff.SameDiffBuilder(");
            out.append("trainingConfig=").append(trainingConfig)
                    .append(", initializedTraining=").append(initializedTraining)
                    .append(", updaterMap=").append(updaterMap)
                    .append(", baseNameForFunctionInstanceId=").append(baseNameForFunctionInstanceId)
                    .append(", functionFactory=").append(functionFactory)
                    .append(", variableNameToShape=").append(variableNameToShape)
                    .append(", forwardVarForGrad=").append(forwardVarForGrad)
                    .append(", variableId=").append(variableId)
                    .append(", propertiesToResolve=").append(propertiesToResolve)
                    .append(", propertiesForFunction=").append(propertiesForFunction)
                    .append(", placeHolderOriginalShapes=").append(placeHolderOriginalShapes)
                    .append(", sameDiffFunctionDefinitionMap=").append(sameDiffFunctionDefinitionMap)
                    .append(", sameDiffFunctionInstances=").append(sameDiffFunctionInstances)
                    .append(", placeHolderFunctions=").append(placeHolderFunctions)
                    .append(", fieldVariableResolutionMapping=").append(fieldVariableResolutionMapping)
                    .append(", wasRegistered=").append(wasRegistered)
                    .append(", debugMode=").append(debugMode)
                    .append(", opsForResult=").append(opsForResult)
                    .append(", resolvedVariables=").append(resolvedVariables)
                    .append(", argumentInterceptors=").append(argumentInterceptors)
                    .append(", pausedArgumentInterceptors=").append(pausedArgumentInterceptors)
                    .append(", blockNames=").append(blockNames)
                    .append(", logExecution=").append(logExecution)
                    .append(", parent=").append(parent)
                    .append(", child=").append(child)
                    .append(')');
            return out.toString();
        }
    }
}

