/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops;

import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * A custom op whose inputs, outputs, integer arguments ("iArgs") and floating-point arguments
 * ("tArgs") are supplied at runtime instead of being fixed by a dedicated op class.
 *
 * <p>The op is identified by its name: {@link #opHash()} and {@link #getDescriptor()} look that
 * name up in {@code Nd4j.getExecutioner().getCustomOperations()}. Instances can be created through
 * the constructors (direct INDArray execution or SameDiff graphs), through
 * {@link #builder(String)} for direct INDArray execution, or through
 * {@link #sameDiffBuilder(String, SameDiff)}.
 */
public class DynamicCustomOp
extends DifferentialFunction
implements CustomOp {
    private static final Logger log = LoggerFactory.getLogger(DynamicCustomOp.class);
    /** Name used to look up this op in the executioner's custom-op registry. */
    private String opName;
    protected List<INDArray> inputArguments = new ArrayList<INDArray>();
    protected List<INDArray> outputArguments = new ArrayList<INDArray>();
    protected List<Double> tArguments = new ArrayList<Double>();
    protected List<Long> iArguments = new ArrayList<Long>();
    protected boolean inplaceCall;
    /** Cached op hash; 0 means "not resolved yet" (see {@link #opHash()}). */
    private long hash;
    /** Output variables, resolved lazily by {@link #outputVariables(String)}. */
    protected SDVariable[] outputVariables;
    /**
     * Output shapes supplied up front by a builder. When non-null they are returned by
     * {@link #calculateOutputShape()} instead of asking the executioner.
     */
    private List<long[]> outputShapes;

    /** Creates an op with no name and empty argument lists. */
    public DynamicCustomOp() {
    }

    /**
     * Creates an op that is part of a SameDiff graph.
     *
     * @param opName   registered custom-op name
     * @param sameDiff owning graph
     * @param args     input variables
     */
    public DynamicCustomOp(String opName, SameDiff sameDiff, SDVariable[] args) {
        super(sameDiff, args);
        this.opName = opName;
    }

    /**
     * Convenience constructor for a single input and a single output array.
     *
     * @param opName     registered custom-op name
     * @param input      input array, or {@code null} for none
     * @param output     output array, or {@code null} for none
     * @param tArguments floating-point arguments, may be {@code null}
     * @param iArguments integer arguments, may be {@code null}
     */
    public DynamicCustomOp(String opName, INDArray input, INDArray output, List<Double> tArguments, int[] iArguments) {
        // Explicit constructor invocation must be the first statement, so the single arrays are
        // wrapped inline rather than in preceding statements.
        this(opName,
                input == null ? null : new INDArray[]{input},
                output == null ? null : new INDArray[]{output},
                tArguments, iArguments);
    }

    /**
     * Creates an op for direct execution with integer arguments given as an {@code int[]}.
     *
     * @param iArguments integer arguments, may be {@code null}
     */
    public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs, List<Double> tArguments, int[] iArguments) {
        this(opName, inputs, outputs, tArguments, iArguments == null ? null : ArrayUtil.toList((int[])iArguments));
    }

    /**
     * Creates an op for direct execution.
     *
     * @param opName     registered custom-op name
     * @param inputs     input arrays, may be {@code null}
     * @param outputs    output arrays, may be {@code null}
     * @param tArguments floating-point arguments, may be {@code null}; copied, not aliased
     * @param iArguments integer arguments, may be {@code null}; each is widened to {@code long}
     */
    public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs, List<Double> tArguments, List<Integer> iArguments) {
        if (inputs != null) {
            this.inputArguments = new ArrayList<INDArray>(Arrays.asList(inputs));
        }
        if (outputs != null) {
            this.outputArguments = new ArrayList<INDArray>(Arrays.asList(outputs));
        }
        this.opName = opName;
        // Copy so later addTArgument calls work even if the caller passed an immutable list.
        this.tArguments = tArguments == null ? new ArrayList<Double>() : new ArrayList<Double>(tArguments);
        if (iArguments != null) {
            for (Integer a : iArguments) {
                this.iArguments.add(a.longValue());
            }
        }
    }

    /** Creates an op for direct execution with no tArgs or iArgs. */
    public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) {
        this(opName, inputs, outputs, new ArrayList<Double>(), new ArrayList<Integer>());
    }

    /**
     * Creates an op that is part of a SameDiff graph, optionally executed in place.
     *
     * @param inPlace whether the op writes its result into its first input
     */
    public DynamicCustomOp(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
        super(sameDiff, inPlace, args);
        this.opName = opName;
        this.inplaceCall = inPlace;
    }

    /** Creates an op with only a name; used by the builders. */
    protected DynamicCustomOp(String opName) {
        this.opName = opName;
    }

    @Override
    public String opName() {
        return this.opName;
    }

    /**
     * Returns the output variables, naming new ones after {@code getOwnName()}, or after
     * {@code opName()} when there is no own name.
     */
    @Override
    public SDVariable[] outputVariables() {
        return this.outputVariables(this.getOwnName() != null ? this.getOwnName() : this.opName());
    }

    /**
     * Returns this op's output variables, resolving them on first use.
     *
     * <p>First call: output variables already registered in the graph are reused. Otherwise new
     * ones are generated using {@code baseName}, and arrays are allocated where shapes can be
     * calculated. On later calls, arrays are allocated for any output variable that still has none.
     *
     * @param baseName base name for newly generated output variables
     * @return the output variables; an empty array if the op is known to have no outputs
     */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        if (this.outputVariables == null) {
            return this.resolveOutputVariables(baseName);
        }
        this.allocateMissingOutputArrays();
        return this.outputVariables;
    }

    /** First-time resolution: reuse graph-registered outputs or generate and register new ones. */
    private SDVariable[] resolveOutputVariables(String baseName) {
        String[] outputNames = this.sameDiff.getOutputsForFunction(this);
        if (outputNames != null) {
            this.outputVariables = new SDVariable[outputNames.length];
            for (int i = 0; i < outputNames.length; ++i) {
                this.outputVariables[i] = this.sameDiff.getVariable(outputNames[i]);
            }
            return this.outputVariables;
        }
        SDVariable[] newVars = this.sameDiff.generateOutputVariableForOp(this, baseName);
        if (this.isInplaceCall()) {
            // In-place: the first output variable shares the first input's array.
            // Note: this.outputVariables is not cached on this path.
            INDArray arr;
            if (this.args().length >= 1 && (arr = this.args()[0].getArr()) != null) {
                String outName = newVars[0].getVarName();
                if (this.sameDiff.getArrForVarName(outName) == null) {
                    this.sameDiff.putArrayForVarName(outName, arr);
                } else {
                    this.sameDiff.updateArrayForVarName(outName, arr);
                }
                this.addOutputArgument(arr);
            }
            return newVars;
        }
        List<long[]> calculatedShapes = this.calculateOutputShape();
        if (newVars == null || calculatedShapes == null || calculatedShapes.isEmpty()) {
            // Shapes are not known yet: keep the variables without allocating arrays.
            CustomOpDescriptor descriptor = this.getDescriptor();
            if (descriptor != null && descriptor.getNumOutputs() < 1 && this.getNumOutputs() < 1) {
                return new SDVariable[0];
            }
            this.outputVariables = newVars;
            return newVars;
        }
        for (int i = 0; i < newVars.length; ++i) {
            if (newVars[i] != null) {
                this.attemptToGetOrCreateArrForVar(newVars[i], calculatedShapes.get(i));
            }
        }
        this.outputVariables = newVars;
        if (this.sameDiff.getOutputsForFunction(this) == null) {
            this.sameDiff.addOutgoingFor(this.outputVariables, (DifferentialFunction)this);
        }
        return newVars;
    }

    /** Allocates arrays, from calculated shapes, for cached output variables that have none. */
    private void allocateMissingOutputArrays() {
        boolean missingArray = false;
        for (SDVariable v : this.outputVariables) {
            if (v.getArr() == null) {
                missingArray = true;
                break;
            }
        }
        if (!missingArray) {
            return;
        }
        List<long[]> shape;
        try {
            shape = this.calculateOutputShape();
        }
        catch (Exception e) {
            throw new RuntimeException("Error calculating shape for op " + this.opName() + " of type " + this.getClass().getSimpleName() + " with name " + this.getOwnName(), e);
        }
        if (shape == null || shape.isEmpty()) {
            return;
        }
        Preconditions.checkState(shape.size() == this.outputVariables.length, "Different number of calculated shapes (%s) vs. number of output variables (%s) - op %s", (Object)shape.size(), (Object)this.outputVariables.length, (Object)this.opName());
        for (int i = 0; i < this.outputVariables.length; ++i) {
            SDVariable var = this.outputVariables[i];
            if (var.getShape() == null) {
                this.attemptToGetOrCreateArrForVar(var, shape.get(i));
            }
        }
    }

    /**
     * Obtains or allocates the array for an output variable.
     *
     * <p>If the array is non-null, it is associated with {@code var} and appended to the output
     * arguments.
     *
     * @param var       output variable
     * @param currShape calculated shape for {@code var}; used only when {@code var} has no shape
     * @return the array, or {@code null} if none could be obtained
     */
    private INDArray attemptToGetOrCreateArrForVar(SDVariable var, long[] currShape) {
        INDArray arr = null;
        if (Shape.isPlaceholderShape(var.getShape())) {
            if (var.getShape() == null) {
                List<long[]> shape = this.calculateOutputShape();
                // Allocate only when shapes are calculable and the supplied shape is fully defined.
                if (shape != null && !shape.isEmpty() && currShape != null && !Shape.isPlaceholderShape(currShape)) {
                    this.sameDiff.putShapeForVarName(var.getVarName(), currShape);
                    arr = var.storeAndAllocateNewArray();
                }
            }
        } else if (this.sameDiff.getArrForVarName(var.getVarName()) == null) {
            if (var.getShape() != null) {
                arr = var.storeAndAllocateNewArray();
            }
        } else {
            arr = var.getArr();
        }
        if (arr != null) {
            this.sameDiff.associateArrayWithVariable(arr, var);
            this.addOutputArgument(arr);
        }
        return arr;
    }

    /**
     * Returns the registered hash of this op, looking it up once and caching it.
     *
     * @throws ND4JIllegalStateException if no op with this name is registered
     */
    @Override
    public long opHash() {
        if (this.hash == 0L) {
            Map<String, CustomOpDescriptor> map = Nd4j.getExecutioner().getCustomOperations();
            CustomOpDescriptor desc = map.get(this.opName());
            if (desc == null) {
                throw new ND4JIllegalStateException("Op name " + this.opName() + " is missing!");
            }
            this.hash = desc.getHash();
        }
        return this.hash;
    }

    /** Returns a snapshot array of the output arguments (never {@code null}). */
    @Override
    public INDArray[] outputArguments() {
        return this.outputArguments.toArray(new INDArray[0]);
    }

    /** Returns a snapshot array of the input arguments (never {@code null}). */
    @Override
    public INDArray[] inputArguments() {
        return this.inputArguments.toArray(new INDArray[0]);
    }

    @Override
    public long[] iArgs() {
        return Longs.toArray(this.iArguments);
    }

    @Override
    public double[] tArgs() {
        return Doubles.toArray(this.tArguments);
    }

    /** Appends integer arguments, widened to {@code long}. */
    @Override
    public void addIArgument(int ... arg) {
        for (int a : arg) {
            this.iArguments.add((long)a);
        }
    }

    @Override
    public void addIArgument(long ... arg) {
        for (long a : arg) {
            this.iArguments.add(a);
        }
    }

    private void addIArgument(Integer ... arg) {
        for (Integer a : arg) {
            this.addIArgument(a.longValue());
        }
    }

    /** Removes the first integer argument equal to {@code arg}, if present. */
    @Override
    public void removeIArgument(Integer arg) {
        // iArguments stores Longs. List.remove(Object) with an Integer would never match a Long,
        // so convert first.
        this.iArguments.remove(arg == null ? null : Long.valueOf(arg.longValue()));
    }

    @Override
    public Long getIArgument(int index) {
        return this.iArguments.get(index);
    }

    @Override
    public int numIArguments() {
        return this.iArguments == null ? 0 : this.iArguments.size();
    }

    /** Appends floating-point arguments; a {@code null} array is ignored. */
    @Override
    public void addTArgument(double ... arg) {
        if (arg != null) {
            this.addTArgument(Doubles.asList((double[])arg).toArray(new Double[arg.length]));
        }
    }

    private void addTArgument(Double ... arg) {
        this.tArguments.addAll(Arrays.asList(arg));
    }

    @Override
    public void removeTArgument(Double arg) {
        this.tArguments.remove(arg);
    }

    @Override
    public Double getTArgument(int index) {
        return this.tArguments.get(index);
    }

    @Override
    public int numTArguments() {
        return this.tArguments == null ? 0 : this.tArguments.size();
    }

    /**
     * Appends input arrays.
     *
     * <p>If this op belongs to a SameDiff graph, each accumulated input array must match the
     * shape of the corresponding graph argument.
     *
     * @throws ND4JIllegalStateException if any array is {@code null} or a shape does not match
     */
    @Override
    public void addInputArgument(INDArray ... arg) {
        for (int i = 0; i < arg.length; ++i) {
            if (arg[i] == null) {
                throw new ND4JIllegalStateException("Input " + i + " was null!");
            }
        }
        this.inputArguments.addAll(Arrays.asList(arg));
        SDVariable[] args = this.sameDiff != null ? this.args() : null;
        if (args == null) {
            return;
        }
        INDArray[] arrsSoFar = this.inputArguments();
        for (int i = 0; i < args.length && i < arrsSoFar.length; ++i) {
            if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape())) {
                // Report the array actually compared (arrsSoFar[i]). The varargs array may be
                // shorter than i.
                throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arrsSoFar[i].shape()));
            }
        }
    }

    @Override
    public void removeInputArgument(INDArray arg) {
        this.inputArguments.remove(arg);
    }

    @Override
    public INDArray getInputArgument(int index) {
        return this.inputArguments.get(index);
    }

    public void setInputArgument(int index, INDArray input) {
        this.inputArguments.set(index, input);
    }

    public void setOutputArgument(int index, INDArray output) {
        this.outputArguments.set(index, output);
    }

    @Override
    public int numInputArguments() {
        return this.inputArguments.size();
    }

    /**
     * Appends output arrays.
     *
     * @throws ND4JIllegalStateException if any array is {@code null}
     */
    @Override
    public void addOutputArgument(INDArray ... arg) {
        for (int i = 0; i < arg.length; ++i) {
            if (arg[i] == null) {
                throw new ND4JIllegalStateException("Output " + i + " was null!");
            }
        }
        this.outputArguments.addAll(Arrays.asList(arg));
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
        this.outputArguments.remove(arg);
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return this.outputArguments.get(index);
    }

    @Override
    public int numOutputArguments() {
        return this.outputArguments.size();
    }

    /** Returns the op hash truncated to {@code int}. */
    @Override
    public int opNum() {
        return (int)this.opHash();
    }

    /**
     * Creates a builder for direct (non-SameDiff) execution of a registered custom op.
     *
     * @param opName op name; an exact match is tried first, then the lower-cased name
     * @throws ND4JIllegalStateException if no such op is registered
     */
    public static DynamicCustomOpsBuilder builder(String opName) {
        Map<String, CustomOpDescriptor> map = Nd4j.getExecutioner().getCustomOperations();
        // Locale.ROOT keeps lower-casing stable in locales such as Turkish (dotless i).
        String lcName = map.containsKey(opName) ? opName : opName.toLowerCase(java.util.Locale.ROOT);
        CustomOpDescriptor desc = map.get(lcName);
        if (desc == null) {
            throw new ND4JIllegalStateException("Unknown operations requested: [" + opName + "]");
        }
        return new DynamicCustomOpsBuilder(lcName, desc.getHash(), desc.getNumInputs(), desc.getNumOutputs(), desc.isAllowsInplace(), desc.getNumTArgs(), desc.getNumIArgs());
    }

    /**
     * Calculates the output shapes of this op.
     *
     * <p>Returns an empty list when shapes cannot be calculated yet: a graph input is a
     * placeholder without a known shape, or fewer iArgs/tArgs/inputs are present than the
     * descriptor requires. Shapes supplied via a builder are returned as-is.
     *
     * @throws IllegalStateException if no descriptor is registered for this op name
     */
    @Override
    public List<long[]> calculateOutputShape() {
        CustomOpDescriptor descriptor = this.getDescriptor();
        // Ops built for direct execution have no SameDiff, so there are no graph args to inspect.
        if (this.sameDiff != null) {
            for (SDVariable arg : this.args()) {
                if (!this.sameDiff.isPlaceHolder(arg.getVarName()) || this.sameDiff.shapeAlreadyExistsForVarName(arg.getVarName())) {
                    continue;
                }
                if (log.isTraceEnabled()) {
                    log.trace("Could not calculate output shape for op {}: arg \"{}\" is placeholder", (Object)this.getClass().getName(), (Object)arg.getVarName());
                }
                return Collections.emptyList();
            }
        }
        if (this.outputShapes != null) {
            return this.outputShapes;
        }
        if (descriptor == null) {
            throw new IllegalStateException("Could not find descriptor for op: " + this.opName() + (DynamicCustomOp.class == this.getClass() ? "" : " - class: " + this.getClass().getName()));
        }
        if (descriptor.getNumIArgs() >= 0 && this.numIArguments() < descriptor.getNumIArgs()) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: not fully initialized ({} IArgs specified, {} required)", this.getClass().getName(), this.numIArguments(), descriptor.getNumIArgs());
            }
            return Collections.emptyList();
        }
        if (descriptor.getNumTArgs() >= 0 && this.numTArguments() < descriptor.getNumTArgs()) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: not fully initialized ({} TArgs specified, {} required)", this.getClass().getName(), this.numTArguments(), descriptor.getNumTArgs());
            }
            return Collections.emptyList();
        }
        if (descriptor.getNumInputs() >= 0 && this.numInputArguments() < descriptor.getNumInputs()) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: not fully initialized ({} input (INDArray) args specified, {} required)", this.getClass().getName(), this.numInputArguments(), descriptor.getNumInputs());
            }
            return Collections.emptyList();
        }
        return Nd4j.getExecutioner().calculateOutputShape(this);
    }

    /** Returns the registered descriptor for this op name, or {@code null} if none exists. */
    @Override
    public CustomOpDescriptor getDescriptor() {
        Map<String, CustomOpDescriptor> map = Nd4j.getExecutioner().getCustomOperations();
        return map.get(this.opName());
    }

    /**
     * Checks that enough inputs, outputs, iArgs and tArgs are present for execution.
     *
     * @throws NoOpNameFoundException    if no descriptor is registered for this op name
     * @throws ND4JIllegalStateException if any argument count is below the descriptor's minimum
     */
    @Override
    public void assertValidForExecution() {
        CustomOpDescriptor descriptor = this.getDescriptor();
        if (descriptor == null) {
            throw new NoOpNameFoundException("No descriptor found for op name " + this.opName());
        }
        if (descriptor.getNumInputs() > 0 && this.numInputArguments() < descriptor.getNumInputs()) {
            throw new ND4JIllegalStateException("Op [" + this.opName() + "] failure for [" + this.getOwnName() + "]: Number of inputs is invalid for execution. Specified [" + this.numInputArguments() + "] but should be [" + descriptor.getNumInputs() + "]");
        }
        if (descriptor.getNumOutputs() > 0 && this.numOutputArguments() < descriptor.getNumOutputs()) {
            throw new ND4JIllegalStateException("Op [" + this.opName() + "] failure for [" + this.getOwnName() + "]: Number of outputs is invalid for execution. Specified [" + this.numOutputArguments() + "] but should be [" + descriptor.getNumOutputs() + "]");
        }
        if (descriptor.getNumIArgs() >= 0 && this.numIArguments() < descriptor.getNumIArgs()) {
            throw new ND4JIllegalStateException("Op [" + this.opName() + "] failure for [" + this.getOwnName() + "]: Number of integer arguments is invalid for execution. Specified [" + this.numIArguments() + "] but should be [" + descriptor.getNumIArgs() + "]");
        }
        if (descriptor.getNumTArgs() >= 0 && this.numTArguments() < descriptor.getNumTArgs()) {
            throw new ND4JIllegalStateException("Op [" + this.opName() + "] failure for [" + this.getOwnName() + "]: Number of floating point arguments is invalid for execution. Specified [" + this.numTArguments() + "] but should be [" + descriptor.getNumTArgs() + "]");
        }
    }

    /**
     * Rebuilds the input and output argument lists from the SameDiff graph.
     *
     * <p>If any graph input has no array, a warning is logged and both lists are left empty.
     * Otherwise output arrays are reused when their shape matches the calculated shape, and
     * (re)allocated when it does not.
     *
     * @throws ND4JIllegalStateException if no descriptor exists or an output shape is unresolvable
     */
    @Override
    public void populateInputsAndOutputsFromSameDiff() {
        CustomOpDescriptor descriptor = this.getDescriptor();
        if (descriptor == null) {
            throw new ND4JIllegalStateException("No custom op descriptor found for op name \"" + this.opName() + "\"");
        }
        log.debug("Op <{}>, isInplace: {}", (Object)this.opName(), (Object)this.isInplaceCall());
        this.inputArguments.clear();
        boolean nullArr = false;
        for (SDVariable arg : this.args()) {
            if (arg.getArr() == null) {
                nullArr = true;
                log.warn("No input found for {} and op name {}", arg.getVarName(), this.opName());
            }
        }
        if (!nullArr) {
            for (SDVariable arg : this.args()) {
                this.inputArguments.add(arg.getArr());
            }
        }
        this.outputArguments.clear();
        if (!nullArr) {
            List<long[]> shapes = this.calculateOutputShape();
            SDVariable[] outputVars = this.outputVariables();
            Preconditions.checkState(shapes.size() == outputVars.length, "Mismatch between number of shapes (%s) and number of output variables (%s) - these must match", shapes.size(), outputVars.length);
            for (int i = 0; i < outputVars.length; ++i) {
                INDArray currArr = outputVars[i].getArr();
                long[] calculatedShape = shapes.get(i);
                if (currArr == null && calculatedShape == null) {
                    throw new ND4JIllegalStateException("Unable to resolve shape for variable " + outputVars[i].getVarName());
                }
                if (currArr == null || !Arrays.equals(currArr.shape(), calculatedShape)) {
                    this.sameDiff.putOrUpdateShapeForVarName(outputVars[i].getVarName(), calculatedShape, true);
                    currArr = outputVars[i].storeAndAllocateNewArray();
                }
                this.outputArguments.add(currArr);
            }
        }
        if (log.isTraceEnabled()) {
            log.trace("Populating inputs and outputs for op {}: {}", (Object)this.opName, (Object)(nullArr ? "Unsuccessful" : "Successful"));
        }
    }

    @Override
    public FunctionProperties asProperties() {
        return FunctionProperties.builder().name(this.opName()).l(this.iArguments).d(this.tArguments).fieldNames(this.propertiesForFunction()).build();
    }

    /** Not supported by the generic op; subclasses must override to provide gradients. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Please extend DynamicCustomOp.doDiff to support SameDiff backprop operations. Op: " + this.getClass().getName());
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /** No-op in the generic class. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** No-op in the generic class. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    /** Creates a builder that attaches the built op to the given SameDiff graph. */
    public static SameDiffBuilder sameDiffBuilder(String opName, SameDiff sameDiff) {
        return new SameDiffBuilder(opName, sameDiff);
    }

    @Override
    public boolean isInplaceCall() {
        return this.inplaceCall;
    }

    public void setInplaceCall(boolean inplaceCall) {
        this.inplaceCall = inplaceCall;
    }

    /** Returns the cached hash field directly; 0 until set by a builder or {@link #opHash()}. */
    public long getHash() {
        return this.hash;
    }

    /**
     * Builder for a {@link DynamicCustomOp} executed directly on INDArrays.
     *
     * <p>Argument counts from the op descriptor are enforced as minimums when they are
     * non-negative.
     */
    public static class DynamicCustomOpsBuilder {
        protected String opName;
        protected int numInputs;
        protected int numOutputs;
        protected int numTArguments;
        protected int numIArguments;
        protected boolean inplaceCall;
        protected boolean inplaceAllowed;
        protected long opHash;
        protected List<long[]> outputShapes = new ArrayList<long[]>();
        private List<INDArray> inputArguments = new ArrayList<INDArray>();
        private List<INDArray> outputArguments = new ArrayList<INDArray>();
        private List<Double> tArguments = new ArrayList<Double>();
        private List<Long> iArguments = new ArrayList<Long>();

        protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
            this.opHash = hash;
            this.opName = opName;
            this.numInputs = numInputs;
            this.numOutputs = numOutputs;
            this.numIArguments = numIArguments;
            this.numTArguments = numTArguments;
            this.inplaceAllowed = inplaceAllowed;
        }

        /**
         * Appends input arrays.
         *
         * @throws ND4JIllegalStateException if fewer than the required number are given
         */
        public DynamicCustomOpsBuilder addInputs(INDArray ... inputs) {
            if (this.numInputs >= 0) {
                if (inputs == null) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numInputs + " arguments. Null was passed instead.");
                }
                if (this.numInputs > inputs.length) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numInputs + " arguments, but " + inputs.length + " was passed to constructor");
                }
            }
            if (inputs != null) {
                this.inputArguments.addAll(Arrays.asList(inputs));
            }
            return this;
        }

        /**
         * Appends output arrays.
         *
         * @throws ND4JIllegalStateException if fewer than the required number are given
         */
        public DynamicCustomOpsBuilder addOutputs(INDArray ... outputs) {
            if (this.numOutputs >= 0) {
                if (outputs == null) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numOutputs + " arguments. Null was passed instead.");
                }
                if (this.numOutputs > outputs.length) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numOutputs + " arguments, but " + outputs.length + " was passed to constructor");
                }
            }
            if (outputs != null) {
                this.outputArguments.addAll(Arrays.asList(outputs));
            }
            return this;
        }

        /**
         * Requests in-place execution.
         *
         * @throws ND4JIllegalStateException if {@code reallyCall} is true but the op disallows it
         */
        public DynamicCustomOpsBuilder callInplace(boolean reallyCall) {
            if (reallyCall && !this.inplaceAllowed) {
                throw new ND4JIllegalStateException("Requested op can't be called inplace");
            }
            this.inplaceCall = reallyCall;
            return this;
        }

        /** Appends integer arguments, widened to {@code long}. */
        public DynamicCustomOpsBuilder addIntegerArguments(List<Integer> iargs) {
            if (this.numIArguments >= 0) {
                if (iargs == null) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects " + this.numIArguments + " integer arguments. Null was passed instead.");
                }
                if (this.numIArguments > iargs.size()) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numIArguments + " integer arguments, but " + iargs.size() + " was passed to constructor");
                }
            }
            if (iargs != null) {
                for (Integer in : iargs) {
                    this.iArguments.add(in.longValue());
                }
            }
            return this;
        }

        /** Appends a single integer argument; rejected when the op requires more than one. */
        public DynamicCustomOpsBuilder addIntegerArguments(long arg) {
            if (this.numIArguments != 1 && this.numIArguments > 0) {
                throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects " + this.numIArguments + " integer arguments. One arg was passed instead.");
            }
            this.iArguments.add(arg);
            return this;
        }

        /** Appends integer arguments, widened to {@code long}. */
        public DynamicCustomOpsBuilder addIntegerArguments(int ... iargs) {
            if (this.numIArguments >= 0) {
                if (iargs == null) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numIArguments + " integer arguments. Null was passed instead.");
                }
                if (this.numIArguments > iargs.length) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numIArguments + " integer arguments, but " + iargs.length + " was passed to constructor");
                }
            }
            if (iargs != null) {
                for (int in : iargs) {
                    this.iArguments.add((long)in);
                }
            }
            return this;
        }

        /** Appends floating-point arguments. */
        public DynamicCustomOpsBuilder addFloatingPointArguments(Double ... targs) {
            if (this.numTArguments >= 0) {
                if (targs == null) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numTArguments + " floating point arguments. Null was passed instead.");
                }
                if (this.numTArguments > targs.length) {
                    throw new ND4JIllegalStateException("CustomOp [" + this.opName + "] expects at least " + this.numTArguments + " floating point arguments, but " + targs.length + " was passed to constructor");
                }
            }
            if (targs != null) {
                this.tArguments.addAll(Arrays.asList(targs));
            }
            return this;
        }

        /** Records a known output shape; the built op returns these from calculateOutputShape. */
        public DynamicCustomOpsBuilder addOutputShape(int[] shape) {
            this.outputShapes.add(ArrayUtil.toLongArray((int[])shape));
            return this;
        }

        /** Records a known output shape; the built op returns these from calculateOutputShape. */
        public DynamicCustomOpsBuilder addOutputShape(long[] shape) {
            this.outputShapes.add(shape);
            return this;
        }

        /**
         * Builds the op.
         *
         * <p>The op shares (does not copy) this builder's argument lists.
         */
        public DynamicCustomOp build() {
            DynamicCustomOp result = new DynamicCustomOp(this.opName);
            result.inputArguments = this.inputArguments;
            result.outputArguments = this.outputArguments;
            result.iArguments = this.iArguments;
            result.tArguments = this.tArguments;
            result.inplaceCall = this.inplaceCall;
            result.hash = this.opHash;
            result.outputShapes = this.outputShapes;
            return result;
        }

        public int getNumOutputs() {
            return -1;
        }
    }

    /**
     * Builder that registers the built op with a SameDiff graph.
     *
     * <p>Direct INDArray inputs/outputs are rejected.
     */
    public static class SameDiffBuilder
    extends DynamicCustomOpsBuilder {
        private SameDiff sameDiff;
        // NOTE(review): populated by addInputs(DifferentialFunction...) but never read by build();
        // confirm whether inputs are expected to be wired into the graph here.
        private List<DifferentialFunction> args = new ArrayList<DifferentialFunction>();
        private List<DifferentialFunction> outputs = new ArrayList<DifferentialFunction>();

        private SameDiffBuilder(String opName, SameDiff sameDiff) {
            this(opName, sameDiff, 0L, 0, 0, false, 0, 0);
        }

        protected SameDiffBuilder(String opName, SameDiff sameDiff, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) {
            super(opName, hash, numInputs, numOutputs, inplaceAllowed, numTArguments, numIArguments);
            this.sameDiff = sameDiff;
        }

        public SameDiffBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        @Override
        public DynamicCustomOpsBuilder addInputs(INDArray ... inputs) {
            throw new UnsupportedOperationException("Unable to add direct ndarrays. Please use the normal builder for that.");
        }

        @Override
        public DynamicCustomOpsBuilder addOutputs(INDArray ... outputs) {
            throw new UnsupportedOperationException("Unable to add direct ndarrays. Please use the normal builder for that.");
        }

        public DynamicCustomOpsBuilder addInputs(DifferentialFunction ... inputs) {
            this.args.addAll(Arrays.asList(inputs));
            return this;
        }

        /** Appends output functions; build() stores them as SDVariables, so they must be SDVariables. */
        public DynamicCustomOpsBuilder addOutputs(DifferentialFunction ... outputs) {
            this.outputs.addAll(Arrays.asList(outputs));
            return this;
        }

        /**
         * Builds the op, registers it with the graph and sets its output variables.
         *
         * <p>If no outputs were given but output shapes were, one new variable is created per shape.
         */
        @Override
        public DynamicCustomOp build() {
            DynamicCustomOp ret = super.build();
            ret.setSameDiff(this.sameDiff);
            // Register once; the original registered the same id a second time after creating outputs.
            this.sameDiff.putFunctionForId(ret.getOwnName(), ret);
            if (this.outputs.isEmpty() && !this.outputShapes.isEmpty()) {
                for (int i = 0; i < this.outputShapes.size(); ++i) {
                    this.outputs.add(this.sameDiff.var(this.sameDiff.generateNewVarName("dynamiccustomop", i), (long[])this.outputShapes.get(i)));
                }
            }
            ret.outputVariables = this.outputs.toArray(new SDVariable[this.outputs.size()]);
            return ret;
        }
    }
}

