public abstract class BaseFunction<T extends Tensor> extends java.lang.Object implements Function<T>
| Modifier and Type | Field and Description |
|---|---|
| protected int | miniBatchSize: Number of inputs in the mini-batch |
| protected java.util.List<T> | parameters |
| protected int[] | shapeInput |
| protected int[] | shapeOutput |
| protected java.util.List<int[]> | shapeParameters |
| Constructor and Description |
|---|
| BaseFunction() |
| Modifier and Type | Method and Description |
|---|---|
| abstract void | _forward(T input, T output) |
| abstract void | _initialize() |
| abstract void | _setParameters(java.util.List<T> parameters) |
| void | forward(T input, T output): Performs the forward pass of the function on the provided inputs. |
| int[] | getOutputShape(): Returns the output tensor's shape, without the mini-batch dimension. |
| java.util.List<T> | getParameters(): If the parameters have been set, returns the list of parameters. |
| java.util.List<int[]> | getParameterShapes(): Returns the shapes of the parameter tensors. |
| void | initialize(int... shapeInput): Initializes internal data structures given the shape of the input tensor, minus the mini-batch (stacked input) dimension. |
| void | setParameters(java.util.List<T> parameters): Specifies learnable function parameters, e.g. weights for linear functions. |
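The summary pairs each public method with an abstract hook (initialize with _initialize, setParameters with _setParameters, forward with _forward), which suggests a template-method design. The following is a minimal sketch of that pattern under stated assumptions: SimpleTensor, SketchBaseFunction and ScaleFunction are hypothetical names, the generic T and the Function interface are omitted, and the method bodies are illustrative rather than taken from the library's source.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a tensor: a shape plus a flat row-major data array.
class SimpleTensor {
    final int[] shape;
    final double[] data;

    SimpleTensor(int... shape) {
        this.shape = shape.clone();
        this.data = new double[Arrays.stream(shape).reduce(1, (a, b) -> a * b)];
    }
}

// Hypothetical sketch: public methods record state, then delegate to abstract hooks.
abstract class SketchBaseFunction {
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected List<SimpleTensor> parameters;
    protected int miniBatchSize;

    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone(); // shape without the mini-batch dimension
        _initialize();                         // subclass derives shapeOutput here
    }

    public void setParameters(List<SimpleTensor> parameters) {
        this.parameters = parameters;
        _setParameters(parameters);
    }

    public List<SimpleTensor> getParameters() {
        return parameters;
    }

    public int[] getOutputShape() {
        return shapeOutput;
    }

    public void forward(SimpleTensor input, SimpleTensor output) {
        if (input.shape.length != shapeInput.length + 1)
            throw new IllegalArgumentException("input must have shape (N, shapeInput...)");
        miniBatchSize = input.shape[0]; // first dimension is the mini-batch size N
        _forward(input, output);
    }

    public abstract void _initialize();

    public abstract void _setParameters(List<SimpleTensor> parameters);

    public abstract void _forward(SimpleTensor input, SimpleTensor output);
}

// Hypothetical example subclass: multiplies every element by one learnable scalar.
class ScaleFunction extends SketchBaseFunction {
    private double scale;

    @Override
    public void _initialize() {
        shapeOutput = shapeInput.clone(); // element-wise, so output shape equals input shape
    }

    @Override
    public void _setParameters(List<SimpleTensor> parameters) {
        scale = parameters.get(0).data[0];
    }

    @Override
    public void _forward(SimpleTensor input, SimpleTensor output) {
        for (int i = 0; i < input.data.length; i++)
            output.data[i] = scale * input.data[i];
    }
}
```

Calling initialize(3), then setParameters with a one-element tensor holding 2.0, then forward on a (2, 3) input sets miniBatchSize to 2 and doubles every element. Keeping shape bookkeeping in the base class lets each subclass implement only the hooks.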
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Function: getTensorType
protected int[] shapeInput
protected java.util.List<int[]> shapeParameters
protected int[] shapeOutput
protected int miniBatchSize
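public void initialize(int... shapeInput)
Description copied from interface: Function
Initializes internal data structures given the shape of the input tensor, minus the mini-batch (stacked input) dimension.
Specified by: initialize in interface Function<T extends Tensor>
Parameters: shapeInput - Shape of the input tensor
public abstract void _initialize()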
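A subclass's _initialize is a natural place to derive shapeOutput and shapeParameters from shapeInput. The sketch below assumes a hypothetical fully connected function (FullyConnectedShapes, built on a hypothetical ShapeOnlyBase) that flattens the input to D values and produces M outputs, with weights of shape (D, M) and a bias of shape (M); that parameter layout is an illustrative assumption, not necessarily the one the library uses.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical minimal base holding the same shape fields as BaseFunction.
abstract class ShapeOnlyBase {
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected List<int[]> shapeParameters = new ArrayList<>();

    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone(); // excludes the mini-batch dimension
        _initialize();
    }

    public abstract void _initialize();
}

// Hypothetical fully connected function: input (D) -> output (M).
class FullyConnectedShapes extends ShapeOnlyBase {
    private final int numOutputs;

    FullyConnectedShapes(int numOutputs) {
        this.numOutputs = numOutputs;
    }

    @Override
    public void _initialize() {
        // Flatten all non-batch input dimensions into a single length D.
        int d = Arrays.stream(shapeInput).reduce(1, (a, b) -> a * b);
        shapeOutput = new int[]{numOutputs};
        shapeParameters.clear();
        shapeParameters.add(new int[]{d, numOutputs}); // weights
        shapeParameters.add(new int[]{numOutputs});    // bias
    }
}
```

For example, initializing with an input shape of (2, 2) and three outputs yields an output shape of (3) and parameter shapes (4, 3) and (3).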
public void setParameters(java.util.List<T> parameters)
Description copied from interface: Function
Specifies learnable function parameters, e.g. weights for linear functions. It only needs to be called once after each modification of the parameters, and it must be called before Function.forward(T, T).
Specified by: setParameters in interface Function<T extends Tensor>
Parameters: parameters - Tensors containing parameters which are optimized. Not modified.
public abstract void _setParameters(java.util.List<T> parameters)
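The rule that setParameters must be called again after a parameter is modified matters when an implementation derives cached state from its parameters. The hypothetical CachedScale class below (plain double[] arrays stand in for tensors) computes its cached value in _setParameters, so editing the array afterwards has no effect until setParameters is called again. Whether a given subclass actually caches anything is implementation specific.

```java
import java.util.List;

// Hypothetical function that caches a value derived from its parameters.
class CachedScale {
    private List<double[]> parameters;
    private double cachedScale;

    public void setParameters(List<double[]> parameters) {
        this.parameters = parameters;
        _setParameters(parameters);
    }

    // Derived state is computed once here instead of on every forward pass.
    protected void _setParameters(List<double[]> parameters) {
        cachedScale = parameters.get(0)[0];
    }

    public double forward(double x) {
        return cachedScale * x;
    }
}
```

With a weight of 2.0, forward(3.0) returns 6.0; changing the weight to 5.0 in place still yields 6.0 until setParameters is called again, after which it yields 15.0.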
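public java.util.List<T> getParameters()
Description copied from interface: Function
If the parameters have been set, returns the list of parameters.
Specified by: getParameters in interface Function<T extends Tensor>
public void forward(T input, T output)
Description copied from interface: Function
Performs the forward pass of the function on the provided inputs.
Input tensor shape = (N, variable ...)
- N is the mini-batch size
- Other dimensions are implementation specific.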
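In forward, the first input dimension is the mini-batch size N. Below is a minimal sketch of extracting N and building a batched output shape. Beyond what the documentation states, it assumes that the remaining input dimensions must equal the shape passed to initialize; BatchShapes and its methods are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helpers for the (N, ...) shape convention used by forward().
final class BatchShapes {
    private BatchShapes() {}

    // Returns N after checking that inputShape == (N, shapeInput...).
    static int miniBatchSize(int[] inputShape, int[] shapeInput) {
        if (inputShape.length != shapeInput.length + 1)
            throw new IllegalArgumentException(
                    "expected " + (shapeInput.length + 1) + " dimensions, got " + inputShape.length);
        int[] rest = Arrays.copyOfRange(inputShape, 1, inputShape.length);
        if (!Arrays.equals(rest, shapeInput))
            throw new IllegalArgumentException("non-batch dimensions do not match shapeInput");
        return inputShape[0];
    }

    // Prepends the mini-batch dimension to a per-sample shape such as getOutputShape().
    static int[] withBatch(int n, int[] shape) {
        int[] out = new int[shape.length + 1];
        out[0] = n;
        System.arraycopy(shape, 0, out, 1, shape.length);
        return out;
    }
}
```

For an input of shape (5, 3, 32, 32) and a function initialized with (3, 32, 32), N is 5; prepending it to a per-sample output shape of (10) gives (5, 10).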
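public java.util.List<int[]> getParameterShapes()
Description copied from interface: Function
Returns the shapes of the parameter tensors. Only valid after Function.initialize(int...) has been called.
Specified by: getParameterShapes in interface Function<T extends Tensor>
See also: Function.initialize(int...)
public int[] getOutputShape()
Description copied from interface: Function
Returns the output tensor's shape, without the mini-batch dimension. Only valid after Function.initialize(int...) has been called.
Specified by: getOutputShape in interface Function<T extends Tensor>
See also: Function.initialize(int...)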
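Because the parameter shapes are only available after initialization, a caller could use them to allocate parameter storage before calling setParameters. The hypothetical ParameterAllocator below allocates flat, zero-filled, row-major double arrays from a list of shapes like the one getParameterShapes returns; real code would allocate the library's own tensor type instead.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: one zero-filled, flat row-major array per parameter shape.
final class ParameterAllocator {
    private ParameterAllocator() {}

    // Number of elements in a tensor with the given shape (product of dimensions).
    static int numElements(int[] shape) {
        int total = 1;
        for (int d : shape) total *= d;
        return total;
    }

    static List<double[]> allocate(List<int[]> shapes) {
        List<double[]> out = new ArrayList<>();
        for (int[] shape : shapes) out.add(new double[numElements(shape)]);
        return out;
    }
}
```

Given parameter shapes (4, 3) and (3), this allocates arrays of 12 and 3 elements.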