public interface Function<T extends Tensor>
High-level interface for functions in an artificial neural network. This interface defines only
the operations in the forward pass. Training a network typically also requires the gradient;
those additional operations are found in DFunction, which extends this interface.
Forward-only implementations can potentially have a lower memory footprint and be faster, more specialized, or simpler.
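As a rough illustration of why a forward-only implementation can be simpler, the sketch below implements a parameter-free function that keeps no intermediate results. Everything in it is an assumption made for illustration: `Tensor` and `FloatTensor` are stand-in types, the nested `Function` interface is a reduced copy that declares only four of the documented methods, and `Relu` is a hypothetical implementation, not a class taken from the library.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class ForwardOnlySketch {
    // Stand-in types (hypothetical), declared only to keep the sketch self-contained.
    interface Tensor {}

    static class FloatTensor implements Tensor {
        final float[] data;
        FloatTensor(int length) { data = new float[length]; }
    }

    // Reduced stand-in for the documented interface: only the methods this sketch needs.
    interface Function<T extends Tensor> {
        void initialize(int... shapeInput);
        void forward(T input, T output);
        int[] getOutputShape();
        List<int[]> getParameterShapes();
    }

    // Hypothetical parameter-free function: output = max(0, input), element by element.
    static class Relu implements Function<FloatTensor> {
        private int[] shape = new int[0];

        @Override public void initialize(int... shapeInput) { shape = shapeInput.clone(); }

        @Override public void forward(FloatTensor input, FloatTensor output) {
            // Nothing is cached here, because no backward pass will ask for it.
            for (int i = 0; i < input.data.length; i++)
                output.data[i] = Math.max(0f, input.data[i]);
        }

        @Override public int[] getOutputShape() { return shape.clone(); }

        @Override public List<int[]> getParameterShapes() { return Collections.emptyList(); }
    }

    static float[] run() {
        Relu relu = new Relu();
        relu.initialize(3);                  // three elements per sample
        FloatTensor in = new FloatTensor(3); // flat storage for a mini-batch of one
        in.data[0] = -1f; in.data[1] = 0.5f; in.data[2] = 2f;
        FloatTensor out = new FloatTensor(3);
        relu.forward(in, out);
        return out.data;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run())); // prints [0.0, 0.5, 2.0]
    }
}
```

An implementation that also supported learning would typically need to keep or recompute values from the forward pass, which is where the extra memory and complexity tend to come from.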
| Modifier and Type | Method and Description |
|---|---|
| void | forward(T input, T output) - Performs forward pass of the function on the provided inputs. |
| int[] | getOutputShape() - Returns the output tensor's shape, without the mini-batch dimension. |
| java.util.List<T> | getParameters() - If the parameters have been set, then this returns the list of parameters. |
| java.util.List<int[]> | getParameterShapes() - Returns the shape of input tensors, without the mini-batch dimension. |
| java.lang.Class<T> | getTensorType() - Returns the type of tensor it can process. |
| void | initialize(int... shapeInput) - Initializes internal data structures given the shape of the input tensor, minus the stacked input dimension. |
| void | setParameters(java.util.List<T> parameters) - Specifies learnable function parameters, e.g. weights for linear functions. |
void initialize(int... shapeInput)
shapeInput - Shape of the input tensor
java.lang.IllegalArgumentException - If input tensor shapes are not valid
void setParameters(java.util.List<T> parameters)
Specifies learnable function parameters, e.g. weights for linear functions. This function only
needs to be called once each time a parameter has been modified. Must be called before forward(T, T).
parameters - Tensors containing parameters which are optimized. Not modified.
java.util.List<T> getParameters()
void forward(T input, T output)
Input tensor shape = (N, variable ...)
- N is the mini-batch size
- Other dimensions are implementation specific.
input - Input to the function.
output - Output tensor. Modified.
java.util.List<int[]> getParameterShapes()
initialize(int...) has been called.initialize(int...).
int[] getOutputShape()
initialize(int...) has been called.initialize(int...).
java.lang.Class<T> getTensorType()
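The call order implied by the method descriptions above (initialize, then setParameters, then forward) can be sketched as follows. This is a minimal sketch under stated assumptions, not the library's implementation: `Tensor` and `FloatTensor` are stand-in types, the nested `Function` interface only copies the documented signatures, and `ElementwiseScale` is a hypothetical function whose single parameter tensor has the same shape as one input sample.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class FunctionLifecycleSketch {
    // Stand-in types (hypothetical), declared only to keep the sketch self-contained.
    interface Tensor {}

    static class FloatTensor implements Tensor {
        final int[] shape;
        final float[] data;
        FloatTensor(int... shape) {
            this.shape = shape.clone();
            int length = 1;
            for (int s : shape) length *= s;
            this.data = new float[length];
        }
    }

    // Same method signatures as the interface documented above.
    interface Function<T extends Tensor> {
        void initialize(int... shapeInput);
        void setParameters(List<T> parameters);
        List<T> getParameters();
        void forward(T input, T output);
        List<int[]> getParameterShapes();
        int[] getOutputShape();
        Class<T> getTensorType();
    }

    // Hypothetical function: multiplies each input element by a learnable weight.
    static class ElementwiseScale implements Function<FloatTensor> {
        private int[] shape;
        private List<FloatTensor> parameters;

        @Override public void initialize(int... shapeInput) {
            if (shapeInput.length == 0)
                throw new IllegalArgumentException("Input shape needs at least one dimension");
            for (int s : shapeInput)
                if (s <= 0) throw new IllegalArgumentException("Dimensions must be positive");
            shape = shapeInput.clone();
        }

        @Override public void setParameters(List<FloatTensor> parameters) { this.parameters = parameters; }

        @Override public List<FloatTensor> getParameters() { return parameters; }

        @Override public void forward(FloatTensor input, FloatTensor output) {
            float[] w = parameters.get(0).data;
            // The leading dimension is the mini-batch, so the weights repeat every w.length elements.
            for (int i = 0; i < input.data.length; i++)
                output.data[i] = w[i % w.length] * input.data[i];
        }

        @Override public List<int[]> getParameterShapes() {
            List<int[]> shapes = new ArrayList<>();
            shapes.add(shape.clone());
            return shapes;
        }

        @Override public int[] getOutputShape() { return shape.clone(); }

        @Override public Class<FloatTensor> getTensorType() { return FloatTensor.class; }
    }

    static float[] run() {
        ElementwiseScale f = new ElementwiseScale();
        f.initialize(2);                          // 1. input shape without the mini-batch dimension
        FloatTensor w = new FloatTensor(f.getParameterShapes().get(0)); // 2. allocate parameters
        w.data[0] = 2f;
        w.data[1] = -1f;
        List<FloatTensor> params = new ArrayList<>();
        params.add(w);
        f.setParameters(params);                  // 3. must be called before forward
        int n = 3;                                // mini-batch size N
        FloatTensor in = new FloatTensor(n, 2);   // shape (N, 2), values 1..6
        for (int i = 0; i < in.data.length; i++) in.data[i] = i + 1;
        FloatTensor out = new FloatTensor(n, f.getOutputShape()[0]);
        f.forward(in, out);                       // 4. writes the result into out
        return out.data;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run())); // prints [2.0, -2.0, 6.0, -4.0, 10.0, -6.0]
    }
}
```

Note that the caller prepends the mini-batch size N to the shape returned by getOutputShape() when allocating the output tensor, since the documented shapes exclude that dimension.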