public class Trainer
extends java.lang.Object
implements java.lang.AutoCloseable
The Trainer class provides a session for model training.
Trainer provides an easy and manageable interface for training. Trainer is
not thread-safe.
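Trainer implements java.lang.AutoCloseable, so it can be scoped with a try-with-resources statement. That way close() runs even if training throws. The sketch below does not use DJL. ToySession, trainStep, and SessionDemo are hypothetical stand-ins that only model the create, use, and close lifecycle of such a session on a single thread.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a training session; this is NOT the DJL Trainer class.
class ToySession implements AutoCloseable {
    private final List<String> log = new ArrayList<>();
    private boolean closed;

    // Records one unit of work; refuses to run once the session is closed.
    void trainStep(int batchIndex) {
        if (closed) {
            throw new IllegalStateException("session already closed");
        }
        log.add("step " + batchIndex);
    }

    List<String> log() {
        return log;
    }

    boolean isClosed() {
        return closed;
    }

    // Declared without "throws Exception", so try-with-resources needs no catch block.
    @Override
    public void close() {
        // A real trainer would typically release its resources here.
        closed = true;
    }
}

class SessionDemo {
    static ToySession runSession(int batches) {
        ToySession session = new ToySession();
        // try-with-resources calls close() automatically, even if trainStep throws.
        try (ToySession s = session) {
            for (int i = 0; i < batches; i++) {
                s.trainStep(i);
            }
        }
        return session;
    }

    public static void main(String[] args) {
        ToySession s = runSession(3);
        System.out.println(s.log() + " closed=" + s.isClosed());
    }
}
```

The session is used from one thread only; the sketch makes no attempt at synchronization.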
See the tutorials on:
| Constructor and Description |
|---|
| Trainer(Model model, TrainingConfig trainingConfig) |
| Modifier and Type | Method | Description |
|---|---|---|
| void | addMetric(java.lang.String metricName, long begin) | Helper to add a metric for a time difference. |
| void | close() | |
| NDList | evaluate(NDList input) | Evaluates the function of the model once on the given input NDList. |
| protected void | finalize() | |
| NDList | forward(NDList input) | Applies the forward function of the model once on the given input NDList. |
| NDList | forward(NDList data, NDList labels) | Applies the forward function of the model once with both data and labels. |
| DataManager | getDataManager() | Returns the DataManager. |
| Device[] | getDevices() | Returns the devices used for training. |
| java.util.List<Evaluator> | getEvaluators() | Gets all Evaluators. |
| Loss | getLoss() | Gets the training Loss function of the trainer. |
| NDManager | getManager() | Gets the NDManager from the model. |
| Metrics | getMetrics() | Returns the Metrics param used for benchmarking. |
| Model | getModel() | Returns the model used to create this trainer. |
| TrainingResult | getTrainingResult() | Returns the TrainingResult. |
| void | initialize(Shape... shapes) | Initializes the Model that the Trainer is going to train. |
| java.lang.Iterable<Batch> | iterateDataset(Dataset dataset) | Fetches an iterator that can iterate through the given Dataset. |
| GradientCollector | newGradientCollector() | Returns a new instance of GradientCollector. |
| void | notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer) | Executes a method on each of the TrainingListeners. |
| void | setMetrics(Metrics metrics) | Attaches a Metrics param to use for benchmarking. |
| void | step() | Updates all of the parameters of the model once. |
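notifyListeners takes a java.util.function.Consumer<TrainingListener>. The caller chooses which listener callback to run, and the trainer applies it to every registered listener. The sketch below is DJL-independent: EpochListener, ListenerHub, and ListenerDemo are hypothetical stand-ins, not the TrainingListener API, and onEpoch(int) is an assumed callback.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical listener type; a stand-in for TrainingListener, not the DJL interface.
interface EpochListener {
    void onEpoch(int epoch);
}

// Hypothetical holder mirroring the shape of notifyListeners(Consumer<...>).
class ListenerHub {
    private final List<EpochListener> listeners = new ArrayList<>();

    void addListener(EpochListener listener) {
        listeners.add(listener);
    }

    // Applies the caller-supplied action to every registered listener, in registration order.
    void notifyListeners(Consumer<EpochListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }
}

class ListenerDemo {
    public static void main(String[] args) {
        ListenerHub hub = new ListenerHub();
        List<String> events = new ArrayList<>();
        hub.addListener(epoch -> events.add("logger saw epoch " + epoch));
        hub.addListener(epoch -> events.add("checkpointer saw epoch " + epoch));

        for (int epoch = 0; epoch < 2; epoch++) {
            final int e = epoch;
            // The lambda decides which listener method is invoked.
            hub.notifyListeners(listener -> listener.onEpoch(e));
        }
        events.forEach(System.out::println);
    }
}
```

Passing a Consumer keeps the hub independent of how many callback methods a listener has: adding a new callback only requires a new lambda at the call site.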
public Trainer(Model model, TrainingConfig trainingConfig)
Parameters:
model - the model the trainer will train on
trainingConfig - the configuration used by the trainer

public void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
Parameters:
shapes - an array of Shape of the inputs

public java.lang.Iterable<Batch> iterateDataset(Dataset dataset) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the given Dataset.
Parameters:
dataset - the dataset to iterate through
Returns:
an Iterable of Batch that contains batches of data from the dataset
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input

public GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
Returns:
a GradientCollector

public NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
Parameters:
input - the input NDList

public NDList forward(NDList data, NDList labels)
Applies the forward function of the model once with both data and labels.

public NDList evaluate(NDList input)
Evaluates the function of the model once on the given input NDList.
Parameters:
input - the input NDList

public void step()
Updates all of the parameters of the model once.
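step() updates all of the model's parameters once. In a typical training loop, it is called after the gradients for the current batch have been computed. The sketch below makes two assumptions. First, it uses plain stochastic gradient descent (w := w - lr * grad) as a stand-in for whatever optimizer a real trainer is configured with. Second, it uses a hypothetical one-parameter model, ToySgdTrainer, rather than DJL types.

```java
// Hypothetical one-parameter model y = w * x; not a DJL API.
class ToySgdTrainer {
    private double w;                 // the single trainable parameter
    private double grad;              // gradient from the most recent backward pass
    private final double learningRate;

    ToySgdTrainer(double initialW, double learningRate) {
        this.w = initialW;
        this.learningRate = learningRate;
    }

    double forward(double x) {
        return w * x;
    }

    // Computes d/dw of the mean squared error over one batch and stores it.
    void backward(double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            sum += 2.0 * xs[i] * (forward(xs[i]) - ys[i]);
        }
        grad = sum / xs.length;
    }

    // Plain SGD: one update of every parameter (here just w) using the stored gradient.
    void step() {
        w -= learningRate * grad;
    }

    double weight() {
        return w;
    }
}

class StepDemo {
    static double train(int iterations) {
        double[] xs = {1.0, 2.0, 3.0};
        double[] ys = {3.0, 6.0, 9.0}; // generated from y = 3x
        ToySgdTrainer trainer = new ToySgdTrainer(0.0, 0.05);
        for (int i = 0; i < iterations; i++) {
            trainer.backward(xs, ys); // compute gradients first
            trainer.step();           // then apply one parameter update
        }
        return trainer.weight();
    }

    public static void main(String[] args) {
        System.out.println(train(100)); // close to 3.0
    }
}
```

With these inputs the first update moves w from 0.0 to 1.4, and repeated calls to step() converge toward the true slope of 3.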
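public Metrics getMetrics()
Returns the Metrics param used for benchmarking.

public void setMetrics(Metrics metrics)
Attaches a Metrics param to use for benchmarking.
Parameters:
metrics - the Metrics object to attach

public Device[] getDevices()
Returns the devices used for training.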
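This reference does not say how a trainer uses several devices at once. One common approach is data parallelism, where each batch is split into one shard per device. The hypothetical BatchSplitter below sketches an even, contiguous split, with devices represented as plain strings. It illustrates the idea and is not DJL's batch-splitting logic.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: splits one batch into contiguous shards, one per device.
class BatchSplitter {
    static List<int[]> split(int[] batch, String[] devices) {
        if (devices.length == 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        List<int[]> shards = new ArrayList<>();
        int base = batch.length / devices.length;
        int remainder = batch.length % devices.length;
        int start = 0;
        for (int d = 0; d < devices.length; d++) {
            // The first `remainder` devices each take one extra sample.
            int size = base + (d < remainder ? 1 : 0);
            shards.add(Arrays.copyOfRange(batch, start, start + size));
            start += size;
        }
        return shards;
    }

    public static void main(String[] args) {
        int[] batch = {0, 1, 2, 3, 4, 5, 6};
        String[] devices = {"device-0", "device-1", "device-2"};
        List<int[]> shards = split(batch, devices);
        for (int d = 0; d < devices.length; d++) {
            System.out.println(devices[d] + " -> " + Arrays.toString(shards.get(d)));
        }
    }
}
```

A batch of 7 samples on 3 devices yields shards of sizes 3, 2, and 2, so no device differs from another by more than one sample.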
public Loss getLoss()
Gets the training Loss function of the trainer.
Returns:
the Loss function

public Model getModel()
Returns the model used to create this trainer.
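getLoss returns the training Loss function. To illustrate what such a function computes, the sketch below implements mean squared error behind a hypothetical ToyLoss interface. The interface, its evaluate(labels, predictions) signature, and the class names are assumptions, not the DJL Loss API.

```java
// Hypothetical loss interface; a stand-in for the Loss type, not the DJL class.
interface ToyLoss {
    String name();

    double evaluate(double[] labels, double[] predictions);
}

// Mean squared error: the average of (prediction - label)^2 over the batch.
class MeanSquaredError implements ToyLoss {
    @Override
    public String name() {
        return "MSE";
    }

    @Override
    public double evaluate(double[] labels, double[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must have the same non-zero length");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = predictions[i] - labels[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }
}

class LossDemo {
    public static void main(String[] args) {
        ToyLoss loss = new MeanSquaredError();
        double value = loss.evaluate(new double[] {1.0, 2.0}, new double[] {1.5, 1.0});
        System.out.println(loss.name() + " = " + value); // prints "MSE = 0.625"
    }
}
```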
public DataManager getDataManager()
Returns the DataManager.
Returns:
the DataManager

public java.util.List<Evaluator> getEvaluators()
Gets all Evaluators.

public void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer)
Executes a method on each of the TrainingListeners.
Parameters:
listenerConsumer - a consumer that executes the method

public TrainingResult getTrainingResult()
Returns the TrainingResult.
Returns:
the TrainingResult

protected void finalize() throws java.lang.Throwable
Overrides:
finalize in class java.lang.Object
Throws:
java.lang.Throwable

public void close()
Specified by:
close in interface java.lang.AutoCloseable

public void addMetric(java.lang.String metricName, long begin)
Helper to add a metric for a time difference.
Parameters:
metricName - the metric name
begin - the time difference start (this method is called at the time difference end)
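addMetric is called at the end of a timed region with the start time begin, and records the difference. The sketch below rests on two assumptions. First, begin was read from the same clock the metric store uses, for example System.nanoTime(). Second, each metric name maps to a list of recorded durations. TimingMetrics and TimingDemo are hypothetical stand-ins, not the Metrics class. The clock is injected as a LongSupplier so that the recorded difference can be checked deterministically.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongSupplier;

// Hypothetical metrics store; a stand-in for Metrics, not the DJL class.
class TimingMetrics {
    private final Map<String, List<Long>> values = new HashMap<>();
    private final LongSupplier clock;

    TimingMetrics(LongSupplier clock) {
        this.clock = clock;
    }

    // Called at the END of the timed region; `begin` was read from the same clock at the start.
    void addMetric(String metricName, long begin) {
        long elapsed = clock.getAsLong() - begin;
        values.computeIfAbsent(metricName, k -> new ArrayList<>()).add(elapsed);
    }

    List<Long> get(String metricName) {
        return values.getOrDefault(metricName, List.of());
    }
}

class TimingDemo {
    public static void main(String[] args) {
        TimingMetrics metrics = new TimingMetrics(System::nanoTime);
        long begin = System.nanoTime();     // start of the timed region
        double acc = 0;
        for (int i = 1; i <= 1_000; i++) {
            acc += Math.sqrt(i);             // some work to time
        }
        metrics.addMetric("work", begin);   // end of the timed region
        System.out.println("acc=" + acc + " elapsed ns=" + metrics.get("work").get(0));
    }
}
```

Recording each duration separately, rather than a running total, leaves room to compute averages or percentiles later.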