public class LSTM extends RecurrentBlock
LSTM is an implementation of recurrent neural networks that applies a Long Short-Term Memory (LSTM) recurrent layer to the input.
Reference paper: Long Short-Term Memory, Hochreiter and Schmidhuber, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf
The LSTM operator is formulated as follows, where $x_t$ is the input at time $t$, $h_t$ is the hidden state, $c_t$ is the cell state, $i_t$, $f_t$, $g_t$ and $o_t$ are the input, forget, cell (candidate) and output gates, and $*$ denotes elementwise multiplication:
$$
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}
$$
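To make the equations concrete, here is a minimal plain-Java sketch of a single LSTM time step. It is not DJL's implementation: the class name `LstmCellSketch`, the `Gate` holder, and the `[outputUnit][inputUnit]` weight layout are assumptions chosen for illustration. Each gate carries its own input weights, recurrent weights and two bias vectors, exactly as in the formulas.

```java
/**
 * Hypothetical sketch (not DJL code) of one LSTM time step,
 * following the gate equations term by term.
 */
public final class LstmCellSketch {

    /** Parameters of one gate: W_i* and b_i* (input side), W_h* and b_h* (recurrent side). */
    public static final class Gate {
        final double[][] wx; // [hiddenSize][inputSize]
        final double[] bx;   // [hiddenSize]
        final double[][] wh; // [hiddenSize][hiddenSize]
        final double[] bh;   // [hiddenSize]

        public Gate(double[][] wx, double[] bx, double[][] wh, double[] bh) {
            this.wx = wx; this.bx = bx; this.wh = wh; this.bh = bh;
        }

        /** Pre-activation W_x x + b_x + W_h h + b_h, computed row by row. */
        double[] preActivation(double[] x, double[] h) {
            double[] z = new double[bx.length];
            for (int j = 0; j < z.length; j++) {
                double s = bx[j] + bh[j];
                for (int k = 0; k < x.length; k++) s += wx[j][k] * x[k];
                for (int k = 0; k < h.length; k++) s += wh[j][k] * h[k];
                z[j] = s;
            }
            return z;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns {h_t, c_t} given x_t, h_{t-1}, c_{t-1} and the gates i, f, g, o. */
    public static double[][] step(double[] x, double[] hPrev, double[] cPrev,
                                  Gate in, Gate forget, Gate cand, Gate out) {
        double[] zi = in.preActivation(x, hPrev);
        double[] zf = forget.preActivation(x, hPrev);
        double[] zg = cand.preActivation(x, hPrev);
        double[] zo = out.preActivation(x, hPrev);
        int n = hPrev.length;
        double[] h = new double[n];
        double[] c = new double[n];
        for (int j = 0; j < n; j++) {
            double i = sigmoid(zi[j]);     // input gate i_t
            double f = sigmoid(zf[j]);     // forget gate f_t
            double g = Math.tanh(zg[j]);   // candidate g_t
            double o = sigmoid(zo[j]);     // output gate o_t
            c[j] = f * cPrev[j] + i * g;   // c_t = f_t * c_{t-1} + i_t * g_t (elementwise)
            h[j] = o * Math.tanh(c[j]);    // h_t = o_t * tanh(c_t)
        }
        return new double[][] {h, c};
    }

    public static void main(String[] args) {
        // Hidden size 1, input size 1, all weights and biases zero:
        // every sigmoid gate evaluates to 0.5 and g_t = tanh(0) = 0.
        Gate zero = new Gate(new double[][] {{0}}, new double[] {0},
                             new double[][] {{0}}, new double[] {0});
        double[][] hc = step(new double[] {1.0}, new double[] {0.0}, new double[] {2.0},
                             zero, zero, zero, zero);
        // c_t = 0.5 * 2.0 + 0.5 * 0.0 = 1.0, and h_t = 0.5 * tanh(1.0)
        System.out.println("c_t = " + hc[1][0] + ", h_t = " + hc[0][0]);
    }
}
```

Counting from the equations as written, each of the four gates has an H×I input matrix, an H×H recurrent matrix and two bias vectors of length H, so one unidirectional layer with input size I and hidden size H has 4(HI + H² + 2H) parameters; for I = 10 and H = 20 that is 4(200 + 400 + 40) = 2560.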
| Modifier and Type | Class and Description |
|---|---|
| `static class` | `LSTM.Builder` |
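Nested classes inherited from class `RecurrentBlock`: `RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>`

Fields inherited from class `RecurrentBlock`: `beginState`, `dropRate`, `gates`, `mode`, `numDirections`, `numStackedLayers`, `stateOutputs`, `stateSize`, `useSequenceLength`

Fields inherited from class `AbstractBlock`: `children`, `inputNames`, `inputShapes`, `parameters`, `parameterShapeCallbacks`, `version`

| Modifier and Type | Method and Description |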
|---|---|
| `static LSTM.Builder` | `builder()`: Creates a builder to build an `LSTM`. |
| `NDList` | `forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)`: Applies the operating function of the block once. |
| `protected NDList` | `opInputs(ParameterStore parameterStore, NDList inputs, boolean training)` |
| `protected void` | `resetBeginStates()` |
| `void` | `setBeginStates(NDList beginStates)`: Sets the initial `NDArray` value for the hidden states. |
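The summary says that `setBeginStates` sets the initial value for the hidden states and that `forward` applies the block. As a conceptual illustration only, the plain-Java sketch below unrolls a recurrent cell over a sequence, starting either from user-supplied begin states (h_0, c_0) or, after a reset, from zeros. It is not DJL code: the class `BeginStateSketch`, its methods, the `Cell` interface, and the zero default after a reset are assumptions. A toy accumulator cell stands in for the LSTM equations so that the effect of the begin states is easy to check by hand.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL code): unrolls a recurrent cell over time,
 * starting from configurable begin states. Assumes that after a reset
 * the begin states default to zeros.
 */
public final class BeginStateSketch {

    /** One recurrent step: maps (x_t, h_{t-1}, c_{t-1}) to {h_t, c_t}. */
    public interface Cell {
        double[][] step(double[] x, double[] hPrev, double[] cPrev);
    }

    private final int stateSize;
    private double[] beginH; // null means no begin state has been set
    private double[] beginC;

    public BeginStateSketch(int stateSize) {
        this.stateSize = stateSize;
    }

    /** Uses copies of the given arrays as h_0 and c_0 for later forward passes. */
    public void setBeginStates(double[] h0, double[] c0) {
        this.beginH = h0.clone();
        this.beginC = c0.clone();
    }

    /** Clears the begin states; forward then starts from zeros (assumed default). */
    public void resetBeginStates() {
        this.beginH = null;
        this.beginC = null;
    }

    /** Applies the cell once per time step and returns h_t for every t. */
    public double[][] forward(double[][] sequence, Cell cell) {
        double[] h = beginH != null ? beginH.clone() : new double[stateSize];
        double[] c = beginC != null ? beginC.clone() : new double[stateSize];
        double[][] outputs = new double[sequence.length][];
        for (int t = 0; t < sequence.length; t++) {
            double[][] hc = cell.step(sequence[t], h, c);
            h = hc[0];
            c = hc[1];
            outputs[t] = h;
        }
        return outputs;
    }

    public static void main(String[] args) {
        // Toy cell (not an LSTM): c_t = c_{t-1} + x_t and h_t = c_t, elementwise.
        Cell accumulate = (x, hPrev, cPrev) -> {
            double[] c = new double[cPrev.length];
            for (int j = 0; j < c.length; j++) c[j] = cPrev[j] + x[j];
            return new double[][] {c.clone(), c};
        };
        double[][] seq = {{1.0}, {2.0}, {3.0}};
        BeginStateSketch block = new BeginStateSketch(1);
        block.setBeginStates(new double[] {10.0}, new double[] {10.0});
        System.out.println(Arrays.deepToString(block.forward(seq, accumulate))); // [[11.0], [13.0], [16.0]]
        block.resetBeginStates();
        System.out.println(Arrays.deepToString(block.forward(seq, accumulate))); // [[1.0], [3.0], [6.0]]
    }
}
```

Only the first time step sees the begin states directly; every later step consumes the states produced by the step before it, which is why changing h_0 and c_0 shifts the whole output sequence here.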
Methods inherited from class `RecurrentBlock`: `beforeInitialize`, `getOutputShapes`, `getParameterShape`, `isBidirectional`, `loadMetadata`, `setStateOutputs`, `updateInputLayoutToTNC`, `validateInputSize`

Methods inherited from class `AbstractBlock`: `addChildBlock`, `addParameter`, `addParameter`, `addParameter`, `cast`, `clear`, `describeInput`, `getChildren`, `getDirectParameters`, `getParameters`, `initialize`, `initializeChildBlocks`, `isInitialized`, `loadParameters`, `readInputShapes`, `saveInputShapes`, `saveMetadata`, `saveParameters`, `setInitializer`, `setInitializer`, `toString`

Methods inherited from class `java.lang.Object`: `clone`, `equals`, `finalize`, `getClass`, `hashCode`, `notify`, `notifyAll`, `wait`, `wait`, `wait`

Methods inherited from interface `Block`: `forward`, `forward`, `validateLayout`

`public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)`

Applies the operating function of the block once.

Overrides or implements: `forward` in interface `Block`; `forward` in class `RecurrentBlock`

Parameters:
- `parameterStore` - the parameter store
- `inputs` - the input `NDList`
- `training` - `true` for a training forward pass
- `params` - optional parameters

`public void setBeginStates(NDList beginStates)`

Sets the initial `NDArray` value for the hidden states.

Overrides or implements: `setBeginStates` in class `RecurrentBlock`

Parameters:
- `beginStates` - the `NDArray` value for the hidden states

`protected void resetBeginStates()`

Overrides or implements: `resetBeginStates` in class `RecurrentBlock`

`protected NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean training)`

Overrides or implements: `opInputs` in class `RecurrentBlock`

`public static LSTM.Builder builder()`

Creates a builder to build an `LSTM`.