/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceSession
extends AbstractSession<INDArray, DifferentialFunction> {
    private static final Logger log = LoggerFactory.getLogger(InferenceSession.class);
    // Appended to the workspace-validation errors raised in preprocessPlaceholders to tell the user how to fix them
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";

    /**
     * Creates an inference session that executes the given SameDiff instance.
     *
     * @param sameDiff the SameDiff instance to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null (checked only after the superclass constructor has run)
     */
    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        // Explicit form of the lombok @NonNull contract; it cannot precede the super(...) call
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked @NonNull but is null");
        }
    }

    /**
     * Validates user-supplied placeholder arrays and casts them to the datatype declared for each placeholder.
     * <p>
     * For every entry this checks that a variable of that name exists, that the array is not null, and - if the
     * array is attached to a non-circular workspace - that the workspace is still open and has not been recycled
     * since the array was created. Arrays whose datatype differs from the variable's declared datatype are
     * replaced by a cast copy.
     *
     * @param placeholders placeholder name to array; may be null or empty, in which case it is returned unchanged
     * @return a new map with the validated (and possibly cast) arrays
     * @throws NullPointerException      if a placeholder array is null
     * @throws ND4JIllegalStateException if an array refers to a closed or outdated workspace
     */
    @Override
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> placeholders) {
        if (placeholders == null || placeholders.isEmpty()) {
            return placeholders;
        }
        Map<String, INDArray> out = new HashMap<>();
        for (Map.Entry<String, INDArray> e : placeholders.entrySet()) {
            String name = e.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", name);
            INDArray arr = e.getValue();
            // Fail with a descriptive message instead of an anonymous NPE on arr.isAttached() below
            Preconditions.checkNotNull(arr, "Null array provided for placeholder \"%s\"", name);

            // First: reject arrays that point into a closed or recycled (non-circular) workspace
            if (arr.isAttached()) {
                MemoryWorkspace ws = arr.data() == null ? null : arr.data().getParentWorkspace();
                if (ws != null && ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {
                    if (!ws.isScopeActive()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace [" + ws.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                    if (ws.getGenerationId() != arr.data().getGenerationId()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace [" + ws.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: " + arr.data().getGenerationId() + ". Workspace current iteration: " + ws.getGenerationId() + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                }
            }

            // Second: cast to the placeholder's declared datatype so execution sees the expected type
            DataType dt = this.sameDiff.getVariable(name).dataType();
            if (arr.dataType() != dt) {
                arr = arr.castTo(dt);
            }
            out.put(name, arr);
        }
        return out;
    }

    /**
     * Executes {@code op} (via {@link #getOutputsHelper}) and notifies listeners around the execution.
     * <p>
     * Each listener for which {@code isActive(at.operation())} is true receives {@code preOpExecution} before the
     * op runs, then {@code opExecution} and one {@code activationAvailable} call per output variable afterwards.
     *
     * @param op               the op to execute
     * @param outputFrameIter  frame/iteration the outputs belong to
     * @param opInputs         non-constant inputs for this frame/iteration; may be null
     * @param allIterInputs    additional inputs searched after {@code opInputs}; may be null
     * @param constAndPhInputs names of constant/placeholder inputs; may be null
     * @param listeners        listeners to notify; may be null or empty
     * @param at               current execution position, passed through to listeners
     * @param batch            current batch, passed through to listeners
     * @return the op's output arrays
     */
    public INDArray[] getOutputs(DifferentialFunction op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
        boolean hasListeners = listeners != null && !listeners.isEmpty();
        // Looked up once: the op metadata is needed both before and after execution
        SameDiffOp sdOp = hasListeners ? this.sameDiff.getOps().get(op.getOwnName()) : null;
        if (hasListeners) {
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    l.preOpExecution(this.sameDiff, at, sdOp);
                }
            }
        }

        INDArray[] out = this.getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs);

        if (hasListeners) {
            // Output-name -> array view, built only once the first active listener needs it
            Map<String, INDArray> namedOuts = null;
            for (Listener l : listeners) {
                if (!l.isActive(at.operation())) {
                    continue;
                }
                if (namedOuts == null) {
                    Map<String, INDArray> builder = new HashMap<>();
                    for (int i = 0; i < out.length; ++i) {
                        builder.put(sdOp.outputsOfOp.get(i), out[i]);
                    }
                    namedOuts = Collections.unmodifiableMap(builder);
                }
                l.opExecution(this.sameDiff, at, batch, sdOp, out);
                for (Map.Entry<String, INDArray> entry : namedOuts.entrySet()) {
                    l.activationAvailable(this.sameDiff, at, batch, sdOp, entry.getKey(), entry.getValue());
                }
            }
        }
        return out;
    }

    /**
     * Executes {@code op} and returns its output arrays, without notifying listeners.
     * <p>
     * Control-flow ops (Identity, Switch, Enter, Exit, NextIteration, Merge, LoopCond) are interpreted directly by
     * forwarding arrays already present in {@code nodeOutputs}. TensorArray ops are handled by
     * {@link #execTensorArrayOp}. Any other {@link CustomOp} or {@link Op} is run through
     * {@code Nd4j.getExecutioner()}.
     *
     * @param op               the op to execute
     * @param outputFrameIter  frame/iteration the outputs belong to
     * @param opInputs         non-constant inputs for this frame/iteration; may be null
     * @param allIterInputs    additional inputs searched after {@code opInputs}; may be null
     * @param constAndPhInputs names of constant/placeholder inputs; may be null
     * @return the op's outputs; for Switch exactly one of the two entries is null (the branch not taken)
     * @throws UnsupportedOperationException for If ops and op types with no execution support
     * @throws IllegalStateException         for Merge with no available input and unsupported tensor ops
     */
    public INDArray[] getOutputsHelper(DifferentialFunction op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs) {
        int totalInputs = (opInputs == null ? 0 : opInputs.size())
                + (constAndPhInputs == null ? 0 : constAndPhInputs.size())
                + (allIterInputs == null ? 0 : allIterInputs.size());
        // True when there are no frame-specific inputs, i.e. only constants/placeholders feed the op
        boolean constPhInput = (opInputs == null || opInputs.isEmpty())
                && (allIterInputs == null || allIterInputs.isEmpty());

        if (op instanceof Identity) {
            String[] argNames = op.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", Arrays.toString(argNames));
            AbstractSession.VarId vid = this.newVarId(argNames[0], outputFrameIter);
            return new INDArray[]{this.nodeOutputs.get(vid)};
        }
        if (op instanceof Switch) {
            String[] argNames = op.argNames();
            AbstractSession.VarId vidPredicate = this.newVarId(argNames[1], outputFrameIter);
            INDArray predicate = this.nodeOutputs.get(vidPredicate);
            Preconditions.checkNotNull(predicate, "Could not find predicate array %s for Switch op \"%s\"", vidPredicate, op.getOwnName());
            Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate);
            INDArray input = this.nodeOutputs.get(this.newVarId(argNames[0], outputFrameIter));
            // Output 0 is used when the predicate is false, output 1 when it is true
            if (predicate.getDouble(0L) == 0.0) {
                return new INDArray[]{input, null};
            }
            return new INDArray[]{null, input};
        }
        if (op instanceof Enter) {
            String[] argNames = op.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name for enter op: got %s", Arrays.toString(argNames));
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", op.getOwnName(), opInputs, constAndPhInputs);
            AbstractSession.VarId inputVarId = this.singleInputVarId(constPhInput, opInputs, allIterInputs, constAndPhInputs);
            INDArray enterInput = this.nodeOutputs.get(inputVarId);
            Preconditions.checkNotNull(enterInput, "Could not get enter op \"%s\" input: output variable %s - %s", op.getOwnName(), Arrays.toString(op.outputVariablesNames()), outputFrameIter);
            return new INDArray[]{enterInput};
        }
        if (op instanceof Exit) {
            AbstractSession.VarId inputVarId = this.singleInputVarId(constPhInput, opInputs, allIterInputs, constAndPhInputs);
            return new INDArray[]{this.nodeOutputs.get(inputVarId)};
        }
        if (op instanceof NextIteration) {
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", opInputs, constAndPhInputs);
            AbstractSession.VarId in = allIterInputs != null && !allIterInputs.isEmpty() ? allIterInputs.iterator().next() : opInputs.iterator().next();
            Preconditions.checkState(outputFrameIter.getFrame().equals(in.getFrame()), "Expected same frame for NextIteration input vs. output: got input %s, output %s", in, outputFrameIter);
            Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration() + 1, "Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", in, outputFrameIter);
            return new INDArray[]{this.nodeOutputs.get(in)};
        }
        if (op instanceof If) {
            throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
        }
        if (op instanceof Merge) {
            String[] inputs = this.sameDiff.getInputsForOp(op);
            // Forward whichever input has already been produced for this frame/iteration
            for (String inputName : inputs) {
                AbstractSession.VarId vid = this.newVarId(inputName, outputFrameIter);
                if (this.nodeOutputs.containsKey(vid)) {
                    log.trace("Returning input \"{}\" for merge node \"{}\"", inputName, op.getOwnName());
                    return new INDArray[]{this.nodeOutputs.get(vid)};
                }
            }
            throw new IllegalStateException("Merge node " + op.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(inputs) + ") - should not be executed at this point");
        }
        if (op instanceof LoopCond) {
            String[] argNames = op.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", Arrays.toString(argNames));
            INDArray arr = this.nodeOutputs.get(this.newVarId(argNames[0], outputFrameIter));
            Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null");
            Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndSInfo", arr);
            return new INDArray[]{arr};
        }
        if (op instanceof BaseTensorOp) {
            return this.execTensorArrayOp(op, outputFrameIter, opInputs, allIterInputs);
        }
        if (op instanceof GradientBackwardsMarker) {
            return new INDArray[]{Nd4j.scalar(1.0f)};
        }
        if (op instanceof CustomOp) {
            CustomOp c = (CustomOp) op;
            Nd4j.getExecutioner().exec(c);
            return c.outputArguments();
        }
        if (op instanceof Op) {
            Op o = (Op) op;
            Nd4j.getExecutioner().exec(o);
            return new INDArray[]{o.z()};
        }
        throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Picks the single input VarId of an Enter/Exit op. Constant/placeholder inputs are addressed in frame "main"
     * at iteration 0; otherwise the first all-iteration input is preferred over the first frame-specific input.
     */
    private AbstractSession.VarId singleInputVarId(boolean constPhInput, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs) {
        if (constPhInput) {
            return new AbstractSession.VarId(constAndPhInputs.iterator().next(), "main", 0, null);
        }
        if (allIterInputs != null && !allIterInputs.isEmpty()) {
            return allIterInputs.iterator().next();
        }
        return opInputs.iterator().next();
    }

    /**
     * Executes a TensorArray op (create, read, write, size, concat, gather, scatter, split) against the session's
     * {@code tensorArrays} map.
     */
    private INDArray[] execTensorArrayOp(DifferentialFunction op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        if (op instanceof TensorArray) {
            AbstractSession.VarId vid = this.newVarId(op.outputVariable().getVarName(), outputFrameIter);
            Preconditions.checkState(!this.tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid);
            this.tensorArrays.put(vid, new ArrayList<INDArray>());
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)};
            }
        }
        if (op instanceof TensorArrayRead) {
            INDArray idxArr = this.getArray(op.arg(1), opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr);
            int i = idxArr.getInt(0);
            AbstractSession.VarId v = this.findTensorArrayThroughEnter(op.arg(0), opInputs, allIterInputs);
            List<INDArray> list = this.tensorArrays.get(v);
            Preconditions.checkState(list != null, "Could not find TensorList for %s", v);
            Preconditions.checkState(i >= 0 && list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", i, list.size(), v);
            return new INDArray[]{list.get(i)};
        }
        if (op instanceof TensorArrayWrite) {
            AbstractSession.VarId tArr = this.findTensorArrayThroughEnter(op.arg(0), opInputs, allIterInputs);
            INDArray idxArr = this.getArray(this.sameDiff.getVariable(op.arg(1).getVarName()), opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr);
            int n = idxArr.getInt(0);
            Preconditions.checkState(n >= 0, "Index for TensorArrayWrite must be >= 0, got %s", n);
            String inName = op.arg(2).getVarName();
            INDArray arr = this.getArray(this.sameDiff.getVariable(inName), opInputs, allIterInputs);
            Preconditions.checkState(arr != null, "Could not find array for %s", inName);
            List<INDArray> l = this.tensorArrays.get(tArr);
            Preconditions.checkState(l != null, "Tensor array does not exist for %s", tArr);
            while (l.size() <= n) {
                l.add(null);
            }
            l.set(n, arr);
            return tensorArrayDummyOutput();
        }
        if (op instanceof TensorArraySize) {
            List<INDArray> l = this.requireTensorArrayList(op, opInputs, allIterInputs);
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                return new INDArray[]{Nd4j.scalar(DataType.INT, l.size())};
            }
        }
        if (op instanceof TensorArrayConcat) {
            List<INDArray> l = this.requireTensorArrayList(op, opInputs, allIterInputs);
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                return new INDArray[]{Nd4j.concat(0, l.toArray(new INDArray[0]))};
            }
        }
        if (op instanceof TensorArrayGather) {
            List<INDArray> l = this.requireTensorArrayList(op, opInputs, allIterInputs);
            String indicesName = op.arg(1).getVarName();
            INDArray indices = this.getArray(this.sameDiff.getVariable(indicesName), opInputs, allIterInputs);
            Preconditions.checkState(indices.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", indices, indicesName);
            Preconditions.checkState(indices.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", indices.dataType(), indicesName);
            int[] idxArrInt = indices.toIntVector();
            List<INDArray> gathered;
            if (idxArrInt.length == 1 && idxArrInt[0] == -1) {
                // A single -1 index means "gather every element"
                gathered = new ArrayList<INDArray>(l);
            } else {
                gathered = new ArrayList<INDArray>(idxArrInt.length);
                for (int id : idxArrInt) {
                    Preconditions.checkState(id >= 0, "Index for TensorArrayGather must be >= 0, got %s", id);
                    Preconditions.checkState(id < l.size(), "Index %s for TensorArrayGather is out of range for TensorArray of size %s", id, l.size());
                    gathered.add(l.get(id));
                }
            }
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                return new INDArray[]{Nd4j.pile(gathered)};
            }
        }
        if (op instanceof TensorArrayScatter) {
            List<INDArray> l = this.requireTensorArrayList(op, opInputs, allIterInputs);
            String indicesName = op.arg(1).getVarName();
            INDArray idxArr = this.getArray(this.sameDiff.getVariable(indicesName), opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName);
            int[] idxs = idxArr.toIntVector();
            String valuesName = op.arg(2).getVarName();
            INDArray valuesArr = this.getArray(this.sameDiff.getVariable(valuesName), opInputs, allIterInputs);
            Preconditions.checkState(valuesArr != null, "Could not find array for %s", valuesName);
            if (idxs.length == 1 && idxs[0] == -1) {
                // A single -1 index means "scatter row i of the values into element i"
                idxs = ArrayUtil.range(0, (int) valuesArr.size(0));
            }
            INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            for (int i = 0; i < idxs.length; ++i) {
                int outIdx = idxs[i];
                Preconditions.checkState(outIdx >= 0, "Index for TensorArrayScatter must be >= 0, got %s", outIdx);
                idx[0] = NDArrayIndex.point(i);
                INDArray get = valuesArr.get(idx).dup();
                if (valuesArr.rank() == 2 && get.rank() == 2) {
                    get = get.reshape(get.length());
                }
                if (valuesArr.rank() == 1 && get.rank() > 0) {
                    get = get.reshape(new long[0]);
                }
                // Grow to the actual target index (not the number of indices) so set() is always in range
                while (l.size() <= outIdx) {
                    l.add(null);
                }
                l.set(outIdx, get);
            }
            return tensorArrayDummyOutput();
        }
        if (op instanceof TensorArraySplit) {
            List<INDArray> l = this.requireTensorArrayList(op, opInputs, allIterInputs);
            String splitName = op.arg(1).getVarName();
            INDArray splitArr = this.getArray(this.sameDiff.getVariable(splitName), opInputs, allIterInputs);
            Preconditions.checkState(splitArr != null, "Could not find array for %s", splitName);
            String sizeName = op.arg(2).getVarName();
            INDArray sizeArr = this.getArray(this.sameDiff.getVariable(sizeName), opInputs, allIterInputs);
            Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName);
            Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName);
            int[] sizes = sizeArr.toIntVector();
            // Exactly one slot per split; a larger bound would leave a trailing null that TensorArraySize would
            // count and TensorArrayConcat would try to concatenate
            while (l.size() < sizes.length) {
                l.add(null);
            }
            INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            int soFar = 0;
            for (int i = 0; i < sizes.length; ++i) {
                idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]);
                l.set(i, splitArr.get(idx).dup());
                soFar += sizes[i];
            }
            return tensorArrayDummyOutput();
        }
        throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Looks up the VarId of TensorArray variable {@code name}, searching {@code opInputs} first and then
     * {@code allIterInputs}. Presumably returns null when not found (lookup is called with its flag set to false).
     */
    private AbstractSession.VarId lookupTensorArrayVarId(String name, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId varId = opInputs == null ? null : InferenceSession.lookup(name, opInputs, false);
        if (varId == null && allIterInputs != null) {
            varId = InferenceSession.lookup(name, allIterInputs, false);
        }
        return varId;
    }

    /**
     * Resolves the VarId of a TensorArray input. While the variable is produced by an Enter op, it follows the
     * Enter op's argument and re-keys the id into the parent frame, so the tensorArrays lookup uses the
     * variable that fed the Enter chain.
     */
    private AbstractSession.VarId findTensorArrayThroughEnter(SDVariable inTensorArray, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId varId = this.lookupTensorArrayVarId(inTensorArray.getVarName(), opInputs, allIterInputs);
        Preconditions.checkState(varId != null, "Could not find input %s", inTensorArray.getVarName());
        SDVariable current = inTensorArray;
        while (this.sameDiff.getVariableOutputOp(current.getVarName()) instanceof Enter) {
            current = this.sameDiff.getVariableOutputOp(current.getVarName()).arg();
            varId = this.newVarId(current.getVarName(), varId.getParentFrame());
        }
        return varId;
    }

    /** Returns the list backing the TensorArray given as arg 0 of {@code op}, failing if it does not exist. */
    private List<INDArray> requireTensorArrayList(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId tArr = this.lookupTensorArrayVarId(op.arg(0).getVarName(), opInputs, allIterInputs);
        List<INDArray> l = this.tensorArrays.get(tArr);
        Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);
        return l;
    }

    /** Creates the placeholder scalar output of mutating TensorArray ops, outside of any workspace. */
    private static INDArray[] tensorArrayDummyOutput() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.scalar(0.0f)};
        }
    }

    /**
     * Returns the array SameDiff holds for a constant or VARIABLE-type variable.
     *
     * @param variableName name of the variable
     * @return the array returned by {@code sameDiff.getArrForVarName(variableName)}
     * @throws NullPointerException  if no variable with that name is found
     * @throws IllegalStateException if the variable is neither a constant nor of type VARIABLE
     */
    @Override
    public INDArray getConstantOrVariable(String variableName) {
        SDVariable v = this.sameDiff.getVariable(variableName);
        Preconditions.checkNotNull(v, "No variable with name %s exists", variableName);
        Preconditions.checkState(v.isConstant() || v.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant or variable: type is %s", variableName, v.getVariableType());
        return this.sameDiff.getArrForVarName(variableName);
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public DifferentialFunction getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues) {
        boolean isLoop;
        DifferentialFunction df = this.sameDiff.getOpById(opName);
        Preconditions.checkNotNull((Object)df, (String)"No differential function fond with name %s", (Object)opName);
        if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While || df instanceof BaseTensorOp) {
            return df;
        }
        String[] argNames = df.argNames();
        int numArgs = argNames == null ? 0 : argNames.length;
        int numNonConstIns = opInputs == null ? 0 : opInputs.size();
        int numNonConstInsAllIters = allIterInputs == null ? 0 : allIterInputs.size();
        int numConstPhIns = constAndPhInputs == null ? 0 : constAndPhInputs.size();
        HashSet<String> constEnterInputs = null;
        if (numArgs != numNonConstIns + numConstPhIns + numNonConstInsAllIters) {
            SDVariable[] args;
            boolean anyConstEnterInputs = false;
            for (SDVariable v : args = df.args()) {
                DifferentialFunction differentialFunction;
                Variable var = this.sameDiff.getVariables().get(v.getVarName());
                DifferentialFunction differentialFunction2 = differentialFunction = var.getOutputOfOp() == null ? null : this.sameDiff.getOps().get(var.getOutputOfOp()).getOp();
                if (!(differentialFunction instanceof Enter) || !((Enter)differentialFunction).isConstant()) continue;
                anyConstEnterInputs = true;
                if (constEnterInputs == null) {
                    constEnterInputs = new HashSet<String>();
                }
                constEnterInputs.add(v.getVarName());
            }
            int constEnterInputCount = 0;
            if (anyConstEnterInputs) {
                for (String s : constEnterInputs) {
                    if (constAndPhInputs != null && constAndPhInputs.contains(s)) continue;
                    boolean found = false;
                    if (allIterInputs != null) {
                        for (AbstractSession.VarId varId : allIterInputs) {
                            if (!s.equals(varId.getVariable())) continue;
                            found = true;
                            break;
                        }
                    }
                    if (found) continue;
                    ++constEnterInputCount;
                }
            }
            if (numArgs > 1) {
                HashSet uniqueArgNames = new HashSet();
                Collections.addAll(uniqueArgNames, argNames);
                Preconditions.checkState((uniqueArgNames.size() == numNonConstIns + numConstPhIns + numNonConstInsAllIters + constEnterInputCount ? 1 : 0) != 0, (String)"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", (Object)df.getClass().getSimpleName(), (Object)opName, uniqueArgNames, opInputs, constAndPhInputs);
            } else {
                Preconditions.checkState((numArgs == numNonConstIns + numConstPhIns + constEnterInputCount ? 1 : 0) != 0, (String)"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", (Object)df.getClass().getSimpleName(), (Object)opName, (Object)argNames, opInputs, constAndPhInputs);
            }
        }
        INDArray[] args = null;
        if (argNames != null && argNames.length > 0) {
            args = new INDArray[argNames.length];
            int i = 0;
            for (String s : argNames) {
                SDVariable v = this.sameDiff.getVariable(s);
                if (v.isConstant()) {
                    args[i] = v.getArr();
                } else if (v.isPlaceHolder()) {
                    Preconditions.checkState((placeholderValues != null && placeholderValues.containsKey(s) ? 1 : 0) != 0, (String)"No array provided for placeholder %s", (Object)s);
                    args[i] = placeholderValues.get(s);
                } else if (constEnterInputs != null && constEnterInputs.contains(s)) {
                    INDArray arr;
                    AbstractSession.VarId varId = this.newVarId(s, frameIter.clone());
                    varId.setIteration(0);
                    for (AbstractSession.FrameIter toZero = varId.getParentFrame(); toZero != null; toZero = toZero.getParentFrame()) {
                        toZero.setIteration(0);
                    }
                    args[i] = arr = (INDArray)this.nodeOutputs.get(varId);
                } else {
                    if (opInputs != null) {
                        for (AbstractSession.VarId vid : opInputs) {
                            if (!vid.getVariable().equals(s)) continue;
                            args[i] = (INDArray)this.nodeOutputs.get(vid);
                            break;
                        }
                    }
                    if (args[i] == null && allIterInputs != null) {
                        for (AbstractSession.VarId vid : allIterInputs) {
                            if (!vid.getVariable().equals(s)) continue;
                            args[i] = (INDArray)this.nodeOutputs.get(vid);
                            break;
                        }
                    }
                }
                Preconditions.checkNotNull((Object)args[i], (String)"Could not parameterize op %s: array %s (variable %s) is null", (Object)opName, (Object)i, (Object)v.getVarName());
                ++i;
            }
        }
        boolean bl = isLoop = !frameIter.getFrame().equals("main") && frameIter.getIteration() > 0;
        if (df instanceof CustomOp) {
            DynamicCustomOp customOp = (DynamicCustomOp)df;
            if (args != null) {
                customOp.setInputArguments(args);
            }
            df.resolvePropertiesFromSameDiffBeforeExecution();
            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
            Preconditions.checkState((outShape != null && outShape.size() > 0 ? 1 : 0) != 0, (String)"Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", (Object)customOp.opName(), (Object)customOp.getOwnName());
            String[] outNames = df.outputVariablesNames();
            Preconditions.checkState((outNames.length == outShape.size() ? 1 : 0) != 0, (String)"Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", (Object)df.opName(), (Object)outShape.size(), (Object)outNames.length);
            for (int i = 0; i < outShape.size(); ++i) {
                INDArray out;
                void var21_52;
                DataType currDT;
                INDArray currOutput = customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i);
                LongShapeDescriptor longShapeDescriptor = outShape.get(i);
                DataType dt = this.sameDiff.getVariable(outNames[i]).dataType();
                if (dt != (currDT = longShapeDescriptor.dataType())) {
                    LongShapeDescriptor longShapeDescriptor2 = longShapeDescriptor.asDataType(dt);
                }
                if (currOutput != null && currOutput.shapeDescriptor().equals(var21_52) && currOutput.isEmpty() == var21_52.isEmpty() && !isLoop) continue;
                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                    out = Nd4j.create((LongShapeDescriptor)var21_52, false);
                }
                customOp.setOutputArgument(i, out);
            }
        } else if (df instanceof Op) {
            Op op = (Op)((Object)df);
            boolean axisArg = false;
            boolean emptyReduce = false;
            if (op instanceof ReduceOp && ((ReduceOp)op).getOpType() != Op.Type.REDUCE3 && df.argNames().length == 2) {
                SDVariable axisArgVar = df.arg(1);
                Preconditions.checkState((boolean)axisArgVar.dataType().isIntType(), (String)"Legacy op %s input 1 (axis) was expected to be an integer type, is %s", df.getClass(), (Object)axisArgVar.dataType());
                INDArray arr = this.getArray(axisArgVar, opInputs, allIterInputs);
                Preconditions.checkState((arr != null ? 1 : 0) != 0, (String)"Could not get axis argument for op %s: %s", (Object)df.getOwnName(), df.getClass());
                if (!arr.isEmpty()) {
                    int[] nArray = arr.toIntVector();
                    int rank = args[0].rank();
                    int[] nArray2 = Shape.normalizeAxis(rank, nArray);
                    df.setDimensions(nArray2);
                    ((BaseReduceOp)op).setEmptyReduce(false);
                } else {
                    df.setDimensions(null);
                    emptyReduce = true;
                    ((BaseReduceOp)op).setEmptyReduce(true);
                }
                axisArg = true;
            } else if (op instanceof ScalarOp && df.argNames().length == 2) {
                SDVariable scalarVar = df.arg(1);
                INDArray scalar = this.getArray(scalarVar, opInputs, allIterInputs);
                Preconditions.checkState((scalar != null ? 1 : 0) != 0, (String)"Could not get scalar argument for op %s: %s", (Object)df.getOwnName(), df.getClass());
                Preconditions.checkState((boolean)scalar.isScalar(), (String)"Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", (Object)df.getOwnName(), df.getClass(), (Object)scalar);
                ((ScalarOp)op).setScalar(scalar);
            }
            if (args != null && args.length > 0) {
                op.setX(args[0]);
                if (args.length == 2 && !axisArg) {
                    op.setY(args[1]);
                }
            }
            if (emptyReduce) {
                INDArray z = op.z();
                if (z == null || !op.x().equalShapes(z) || isLoop) {
                    op.setZ(op.x().ulike());
                }
            } else {
                List<LongShapeDescriptor> outputShape = ((BaseOp)op).calculateOutputShape();
                Preconditions.checkState((outputShape != null && outputShape.size() == 1 ? 1 : 0) != 0, (String)"Could not calculate output shape for op: %s", op.getClass());
                INDArray z = op.z();
                if (z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
                    if (log.isTraceEnabled()) {
                        log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", new Object[]{op.getClass().getSimpleName(), z == null ? null : Arrays.toString(z.shape()), outputShape.get(0).toString()});
                    }
                    LongShapeDescriptor longShapeDescriptor = outputShape.get(0);
                    try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                        z = Nd4j.create(longShapeDescriptor, false);
                    }
                    op.setZ(z);
                }
            }
            df.resolvePropertiesFromSameDiffBeforeExecution();
        }
        return df;
    }

    /**
     * Fetches the array currently associated with {@code sdv}.
     *
     * <p>CONSTANT and VARIABLE typed variables are served by {@link #getConstantOrVariable(String)}.
     * For any other variable type, the matching {@code VarId} is searched for first in
     * {@code opInputs} and then, if still unresolved and non-empty, in {@code allIterInputs}.
     * The array stored in {@code nodeOutputs} under that id is returned.
     *
     * @param sdv           variable whose array is requested
     * @param opInputs      candidate input ids for the op (may be null)
     * @param allIterInputs additional candidate input ids, consulted only as a fallback (may be null)
     * @return the array found for the variable
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if no matching id is found
     *         in either collection
     */
    protected INDArray getArray(SDVariable sdv, Collection<AbstractSession.VarId> opInputs, Collection<AbstractSession.VarId> allIterInputs) {
        final String varName = sdv.getVarName();
        final VariableType type = sdv.getVariableType();
        final boolean storedDirectly = type == VariableType.CONSTANT || type == VariableType.VARIABLE;
        if (storedDirectly) {
            return this.getConstantOrVariable(varName);
        }

        // Search the op inputs first; the second collection is only a fallback.
        AbstractSession.VarId found = opInputs == null ? null : InferenceSession.lookup(varName, opInputs, false);
        final boolean canSearchAllIters = allIterInputs != null && !allIterInputs.isEmpty();
        if (found == null && canSearchAllIters) {
            found = InferenceSession.lookup(varName, allIterInputs, false);
        }

        Preconditions.checkState(found != null, "Could not find array for variable %s", (Object) sdv.getVarName());
        return (INDArray) this.nodeOutputs.get(found);
    }
}

