public final class EasyTrain
extends java.lang.Object
| Modifier and Type | Method and Description |
|---|---|
| `static void` | `fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset)`: Runs a basic epoch-based training loop with the given trainer. |
| `static void` | `trainBatch(Trainer trainer, Batch batch)`: Trains the model with one iteration of the given `Batch` of data. |
| `static void` | `validateBatch(Trainer trainer, Batch batch)`: Validates the given batch of data. |
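The following is a plain-Java sketch of the epoch loop that `fit` is documented to run. It makes assumptions this page does not confirm: that `fit` calls `trainBatch` on every training batch and, after each pass, `validateBatch` on every validation batch. The only behavior taken from the documentation is the epoch count and the rule that a null `validateDataset` skips validation. `EpochLoopSketch` and its `Consumer` callbacks are hypothetical stand-ins that let the sketch run without DJL.

```java
import java.util.List;
import java.util.function.Consumer;

// Hypothetical sketch of an epoch loop shaped like EasyTrain.fit; not DJL's source.
// The per-batch steps are passed in as callbacks so the sketch runs without DJL.
final class EpochLoopSketch {

    // Runs numEpoch passes over trainingDataset. After each pass, every validation
    // batch is visited unless validateDataset is null (the documented "no validation" case).
    static <B> void fit(
            int numEpoch,
            List<B> trainingDataset,
            List<B> validateDataset,
            Consumer<B> trainBatch,
            Consumer<B> validateBatch) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (B batch : trainingDataset) {
                trainBatch.accept(batch);
            }
            if (validateDataset != null) {
                for (B batch : validateDataset) {
                    validateBatch.accept(batch);
                }
            }
        }
    }

    public static void main(String[] args) {
        List<String> train = List.of("t0", "t1", "t2");
        List<String> validate = List.of("v0");
        fit(2, train, validate,
                batch -> System.out.println("train " + batch),
                batch -> System.out.println("validate " + batch));
    }
}
```

Passing the per-batch steps as callbacks keeps the loop independent of any model code. In real DJL code you would not write this loop yourself; you would call `EasyTrain.fit(trainer, numEpoch, trainingDataset, validateDataset)` directly.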
public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws java.io.IOException, TranslateException

Runs a basic epoch-based training loop with the given trainer.

Parameters:
- `trainer`: the trainer to train with
- `numEpoch`: the number of epochs to train
- `trainingDataset`: the dataset to train on
- `validateDataset`: the dataset to validate against, or `null` to skip validation

Throws:
- `java.io.IOException`: for various I/O errors, depending on the dataset
- `TranslateException`: if there is an error while processing input

public static void trainBatch(Trainer trainer, Batch batch)
Trains the model with one iteration of the given `Batch` of data.

Parameters:
- `trainer`: the trainer to train the batch with
- `batch`: a `Batch` that contains data and its respective labels

Throws:
- `java.lang.IllegalArgumentException`: if the batch engine does not match the trainer engine

public static void validateBatch(Trainer trainer, Batch batch)
Validates the given batch of data.

During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.

Parameters:
- `trainer`: the trainer to validate the batch with
- `batch`: a `Batch` of data

Throws:
- `java.lang.IllegalArgumentException`: if the batch engine does not match the trainer engine
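The documented difference between `trainBatch` and `validateBatch` can be illustrated with a toy one-parameter model `y = weight * x` under mean-squared-error loss. Everything here is hypothetical: `ToyTrainer`, `ToyBatch`, the string engine names, and the plain gradient-descent update are stand-ins, not DJL classes or DJL's optimizer. As documented, both methods reject a batch whose engine differs from the trainer's with `IllegalArgumentException`. Only the training step computes a gradient and updates the parameter.

```java
// Hypothetical toy contrasting a training step with a validation step.
// ToyTrainer and ToyBatch are illustrative stand-ins, NOT DJL classes.
final class TrainVsValidateSketch {

    // A batch of 1-D inputs and labels, tagged with the (hypothetical) engine it lives on.
    static final class ToyBatch {
        final String engine;
        final double[] data;
        final double[] labels;

        ToyBatch(String engine, double[] data, double[] labels) {
            this.engine = engine;
            this.data = data;
            this.labels = labels;
        }
    }

    // Holds the single model parameter of y = weight * x plus a learning rate.
    static final class ToyTrainer {
        final String engine;
        final double learningRate;
        double weight;
        double lastLoss; // loss from the most recent train or validate call

        ToyTrainer(String engine, double weight, double learningRate) {
            this.engine = engine;
            this.weight = weight;
            this.learningRate = learningRate;
        }
    }

    // Mirrors the documented contract: the batch engine must match the trainer engine.
    private static void checkEngine(ToyTrainer trainer, ToyBatch batch) {
        if (!trainer.engine.equals(batch.engine)) {
            throw new IllegalArgumentException(
                    "Batch engine " + batch.engine + " does not match trainer engine " + trainer.engine);
        }
    }

    // Mean squared error of the prediction weight * x against the labels.
    private static double loss(double weight, ToyBatch batch) {
        double sum = 0.0;
        for (int i = 0; i < batch.data.length; i++) {
            double err = weight * batch.data[i] - batch.labels[i];
            sum += err * err;
        }
        return sum / batch.data.length;
    }

    // One training iteration: loss, gradient d(loss)/d(weight), then a gradient-descent update.
    static void trainBatch(ToyTrainer trainer, ToyBatch batch) {
        checkEngine(trainer, batch);
        trainer.lastLoss = loss(trainer.weight, batch);
        double grad = 0.0;
        for (int i = 0; i < batch.data.length; i++) {
            grad += 2.0 * (trainer.weight * batch.data[i] - batch.labels[i]) * batch.data[i];
        }
        grad /= batch.data.length;
        trainer.weight -= trainer.learningRate * grad;
    }

    // Validation: the loss is computed, but no gradient and no parameter update.
    static void validateBatch(ToyTrainer trainer, ToyBatch batch) {
        checkEngine(trainer, batch);
        trainer.lastLoss = loss(trainer.weight, batch);
    }

    public static void main(String[] args) {
        ToyBatch batch = new ToyBatch("toy", new double[] {1.0, 2.0}, new double[] {2.0, 4.0});
        ToyTrainer trainer = new ToyTrainer("toy", 0.0, 0.1);

        validateBatch(trainer, batch);
        System.out.println("after validate: weight=" + trainer.weight + ", loss=" + trainer.lastLoss);

        trainBatch(trainer, batch);
        System.out.println("after train:    weight=" + trainer.weight + ", loss=" + trainer.lastLoss);
    }
}
```

With the values in `main`, validation reports a loss of 10.0 and leaves `weight` at 0.0. One training step computes a gradient of -10 and, with learning rate 0.1, moves `weight` to 1.0.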