public class TrainableWordEmbedding extends Embedding<java.lang.String> implements WordEmbedding
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a SimpleVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.

| Modifier and Type | Class and Description |
|---|---|
| static class | TrainableWordEmbedding.Builder: A builder for a TrainableWordEmbedding. |

Nested classes inherited from Embedding: Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem

Fields inherited from Embedding: dataType, embedding, embeddingSize, fallthroughEmbedding, numItems, sparseGrad

Fields inherited from AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version

| Constructor and Description |
|---|
| TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items): Constructs a pretrained embedding. |
| TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, boolean sparseGrad): Constructs a pretrained embedding. |
| TrainableWordEmbedding(TrainableWordEmbedding.Builder builder): Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder. |
| TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize): Constructs a new instance of TrainableWordEmbedding from a SimpleVocabulary and a given embedding size. |

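The two pretrained constructors take an embedding array together with a list of items, and the parameter documentation requires the items to be in matching order to the embedding array. The following is a minimal plain-Java sketch of that alignment contract, not DJL code: a float[][] stands in for the NDArray, and PretrainedTableSketch and rowFor are hypothetical names introduced only for illustration.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration only; this class is not part of DJL.
class PretrainedTableSketch {

    // Looks up the vector for an item, assuming row i of the table
    // belongs to items.get(i) (the "matching order" requirement).
    static float[] rowFor(float[][] embedding, List<String> items, String item) {
        if (embedding.length != items.size()) {
            throw new IllegalArgumentException(
                    "expected one embedding row per item, got "
                            + embedding.length + " rows for " + items.size() + " items");
        }
        int index = items.indexOf(item);
        if (index < 0) {
            throw new IllegalArgumentException("unknown item: " + item);
        }
        return embedding[index];
    }

    public static void main(String[] args) {
        List<String> items = Arrays.asList("the", "cat", "sat");
        float[][] embedding = {
            {0.1f, 0.2f},
            {0.3f, 0.4f},
            {0.5f, 0.6f}
        };
        // Prints the row stored at the same position as "cat" in the item list.
        System.out.println(Arrays.toString(rowFor(embedding, items, "cat")));
    }
}
```

If the list and the array were out of step, every lookup would silently return another word's vector, which is why the sketch checks the row count before indexing.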
| Modifier and Type | Method and Description |
|---|---|
| static TrainableWordEmbedding.Builder | builder(): Creates a builder to build an Embedding. |
| java.lang.String | decode(byte[] byteArray): Decodes the given byte array into an object of input parameter type. |
| long | embed(java.lang.String item): Embeds an item. |
| NDArray | embedWord(NDArray index): Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String). |
| byte[] | encode(java.lang.String input): Encodes an object of input type into a byte array. |
| boolean | hasItem(java.lang.String item): Returns whether an item is in the embedding. |
| long | preprocessWordToEmbed(java.lang.String word): Pre-processes the word to embed into an array to pass into the model. |
| java.util.Optional<java.lang.String> | unembed(long index): Returns the item corresponding to the given index. |
| java.lang.String | unembedWord(NDArray word): Returns the closest matching word for the given index. |
| boolean | vocabularyContains(java.lang.String word): Returns whether an embedding exists for a word. |

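The methods summarized above suggest a two-step lookup: preprocessWordToEmbed turns a word into a long index, and the index is then embedded into a vector, while unembed maps an index back to an optional item. The sketch below illustrates that flow in plain Java under stated assumptions. ToyWordEmbedding is a hypothetical stand-in, not the DJL class: it uses float[] instead of NDArray, fills the table with placeholder values instead of trained weights, and returns -1 for unknown words, which is an assumption of this sketch and not documented behavior.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical stand-in that mirrors the method names above; NOT the DJL class.
class ToyWordEmbedding {
    private final List<String> items = new ArrayList<>();
    private final Map<String, Long> indexOf = new HashMap<>();
    private final float[][] table; // one row per item, stands in for the NDArray

    ToyWordEmbedding(List<String> words, int embeddingSize) {
        table = new float[words.size()][embeddingSize];
        for (int i = 0; i < words.size(); i++) {
            items.add(words.get(i));
            indexOf.put(words.get(i), (long) i);
            // Deterministic placeholder values; a trainable embedding would learn these.
            for (int j = 0; j < embeddingSize; j++) {
                table[i][j] = i + j / 10.0f;
            }
        }
    }

    boolean vocabularyContains(String word) {
        return indexOf.containsKey(word);
    }

    // Step 1: word -> index. Assumption of this sketch: unknown words map to -1.
    long preprocessWordToEmbed(String word) {
        return indexOf.getOrDefault(word, -1L);
    }

    // Step 2: index -> vector (a copy of the matching table row).
    float[] embedWord(long index) {
        return table[(int) index].clone();
    }

    // Reverse lookup: index -> item, empty when the index is out of range.
    Optional<String> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }

    public static void main(String[] args) {
        ToyWordEmbedding embedding =
                new ToyWordEmbedding(Arrays.asList("the", "cat", "sat"), 4);
        long index = embedding.preprocessWordToEmbed("cat");
        System.out.println("index of cat: " + index);
        System.out.println("vector length: " + embedding.embedWord(index).length);
        System.out.println("unembed: " + embedding.unembed(index).orElse("<none>"));
    }
}
```

Keeping the index step separate from the vector step means the string lookup can happen during data preprocessing, leaving only numeric indices to pass into the model.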
Methods inherited from Embedding: embed, forward, getOutputShapes, loadParameters, saveParameters

Methods inherited from AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, beforeInitialize, cast, clear, describeInput, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, toString

Methods inherited from java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from WordEmbedding: embedWord, embedWord

Methods inherited from Block: forward, forward, validateLayout

public TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
Parameters:
builder - the TrainableWordEmbedding.Builder

public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
Constructs a new instance of TrainableWordEmbedding from a SimpleVocabulary and a given embedding size.
Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size

public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)

public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, boolean sparseGrad)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
sparseGrad - whether to compute row sparse gradient in the backward calculation

public boolean vocabularyContains(java.lang.String word)
Returns whether an embedding exists for a word.
Specified by:
vocabularyContains in interface WordEmbedding
Parameters:
word - the word to check

public long preprocessWordToEmbed(java.lang.String word)
Pre-processes the word to embed into an array to pass into the model.
Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
Specified by:
preprocessWordToEmbed in interface WordEmbedding
Parameters:
word - the word to embed

public NDArray embedWord(NDArray index) throws EmbeddingException
Description copied from interface: WordEmbedding
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
Specified by:
embedWord in interface WordEmbedding
Parameters:
index - the index of the word to embed
Throws:
EmbeddingException - if there is an error while trying to embed

public java.lang.String unembedWord(NDArray word)
Returns the closest matching word for the given index.
Specified by:
unembedWord in interface WordEmbedding
Parameters:
word - the word embedding to find the matching string word for

public byte[] encode(java.lang.String input)
Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
Specified by:
encode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
input - the input object to be encoded

public java.lang.String decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
Specified by:
decode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
byteArray - the byte array to be decoded

public long embed(java.lang.String item)
Description copied from interface: AbstractIndexedEmbedding
Embeds an item.
Specified by:
embed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
item - the item to embed

public java.util.Optional<java.lang.String> unembed(long index)
Description copied from interface: AbstractIndexedEmbedding
Returns the item corresponding to the given index.
Specified by:
unembed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
index - the index

public static TrainableWordEmbedding.Builder builder()
Creates a builder to build an Embedding.

public boolean hasItem(java.lang.String item)
Description copied from interface: AbstractEmbedding
Returns whether an item is in the embedding.
Specified by:
hasItem in interface AbstractEmbedding<java.lang.String>
Parameters:
item - the item to test
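
The encode and decode methods described above convert between a String item and a byte array. A round trip of that kind can be sketched in plain Java as follows. StringCodecSketch is a hypothetical name, and the choice of UTF-8 is an assumption of this sketch; the page does not say which charset the class uses.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical illustration only; the charset is an assumption, not documented behavior.
class StringCodecSketch {

    // String item -> bytes, using UTF-8 so the result is platform independent.
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    // Bytes -> String item, using the same charset as encode.
    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String word = "caf\u00e9";
        byte[] bytes = encode(word);
        // The round trip should give back an equal string.
        System.out.println(decode(bytes).equals(word));
    }
}
```

Whatever charset is used, encode and decode have to agree on it, otherwise items containing non-ASCII characters would not survive a save and load cycle.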