public class EvaluativeListener extends BaseTrainingListener

| Modifier and Type | Field and Description |
|---|---|
| `protected EvaluationCallback` | `callback`: this callback is invoked after evaluation has finished |
| `protected org.nd4j.linalg.dataset.DataSet` | `ds` |
| `protected org.nd4j.linalg.dataset.api.iterator.DataSetIterator` | `dsIterator` |
| `protected IEvaluation[]` | `evaluations` |
| `protected int` | `frequency` |
| `protected AtomicLong` | `invocationCount` |
| `protected InvocationType` | `invocationType` |
| `protected ThreadLocal<AtomicLong>` | `iterationCount` |
| `protected org.nd4j.linalg.dataset.MultiDataSet` | `mds` |
| `protected org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator` | `mdsIterator` |

| Constructor and Description |
|---|
| `EvaluativeListener(org.nd4j.linalg.dataset.DataSet dataSet, int frequency, InvocationType type)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.DataSet dataSet, int frequency, InvocationType type, IEvaluation... evaluations)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency)`: evaluation is launched once every `frequency` iterations. |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, IEvaluation... evaluations)`: evaluation is launched once every `frequency` iterations. |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, InvocationType type)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, InvocationType type, IEvaluation... evaluations)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.MultiDataSet multiDataSet, int frequency, InvocationType type)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.MultiDataSet multiDataSet, int frequency, InvocationType type, IEvaluation... evaluations)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency)`: evaluation is launched once every `frequency` iterations. |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, IEvaluation... evaluations)`: evaluation is launched once every `frequency` iterations. |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, InvocationType type)` |
| `EvaluativeListener(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, InvocationType type, IEvaluation... evaluations)` |

| Modifier and Type | Method and Description |
|---|---|
| `protected void` | `evalAtIndex(IEvaluation evaluation, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] predictions, int index)` |
| `protected void` | `invokeListener(Model model)` |
| `void` | `iterationDone(Model model, int iteration, int epoch)`: event listener for each iteration. |
| `void` | `onBackwardPass(Model model)`: called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via `Model.gradient()`. |
| `void` | `onEpochEnd(Model model)`: called once at the end of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`. |
| `void` | `onEpochStart(Model model)`: called once at the start of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`. |
| `void` | `onForwardPass(Model model, List<org.nd4j.linalg.api.ndarray.INDArray> activations)`: called once per iteration (forward pass) for activations (usually for a `MultiLayerNetwork`), only at training time. |
| `void` | `onForwardPass(Model model, Map<String,org.nd4j.linalg.api.ndarray.INDArray> activations)`: called once per iteration (forward pass) for activations (usually for a `ComputationGraph`), only at training time. |
| `void` | `onGradientCalculation(Model model)`: called once per iteration (backward pass) before the gradients are updated. Gradients are available via `Model.gradient()`. |
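
`onGradientCalculation` and `onBackwardPass` see the same gradient buffer at different moments. The sketch below is a hypothetical, dependency-free illustration using plain `double[]` arrays in place of `INDArray`: a simulated updater scales the gradient in place (a stand-in for learning rate, momentum, RMSProp and so on), so only an explicit copy taken beforehand keeps the raw values. `InPlaceUpdateDemo` and `applyUpdateInPlace` are illustrative names, not DL4J API.

```java
import java.util.Arrays;

// Hypothetical illustration with plain arrays; no DL4J types are used here.
final class InPlaceUpdateDemo {
    /** Simulated plain-SGD step applied in place: grad <- lr * grad, then params -= grad. */
    static void applyUpdateInPlace(double[] params, double[] grad, double lr) {
        for (int i = 0; i < grad.length; i++) {
            grad[i] *= lr;          // the gradient buffer now holds the update
            params[i] -= grad[i];
        }
    }

    public static void main(String[] args) {
        double[] params = {1.0, 2.0};
        double[] grad = {0.5, -1.0};   // raw gradient, as a listener would see it before the update

        double[] rawCopy = grad.clone(); // a real copy survives the in-place update
        double[] alias = grad;           // NOT a copy, just another reference to the same buffer

        applyUpdateInPlace(params, grad, 0.5);

        System.out.println(Arrays.toString(rawCopy)); // [0.5, -1.0]
        System.out.println(Arrays.toString(alias));   // [0.25, -0.5] (post-update values)
    }
}
```

Holding a reference instead of a copy is exactly the mistake the in-place note warns about: by the time the reference is read, it already contains post-update values.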

`protected transient ThreadLocal<AtomicLong> iterationCount`

`protected int frequency`

`protected AtomicLong invocationCount`

`protected transient org.nd4j.linalg.dataset.api.iterator.DataSetIterator dsIterator`

`protected transient org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator mdsIterator`

`protected org.nd4j.linalg.dataset.DataSet ds`

`protected org.nd4j.linalg.dataset.MultiDataSet mds`

`protected IEvaluation[] evaluations`

`protected InvocationType invocationType`

`protected transient EvaluationCallback callback`
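
The `callback` and `invocationCount` fields suggest that the listener notifies an optional callback after each evaluation and keeps a running count of evaluations. The sketch below models that flow under stated assumptions: `ResultCallback` is a hypothetical stand-in for `EvaluationCallback` (whose real signature is not shown on this page), the callback may be `null`, and the counter is assumed to be read before it is incremented.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical callback type; DL4J's real EvaluationCallback has its own signature.
interface ResultCallback {
    void call(long invocation, String result);
}

// Hypothetical model of "evaluate, then notify, then count"; NOT the DL4J implementation.
final class CallbackModel {
    private final AtomicLong invocationCount = new AtomicLong(0);
    private ResultCallback callback; // optional, may stay null

    void setCallback(ResultCallback callback) {
        this.callback = callback;
    }

    /** Runs a (fake) evaluation, notifies the callback if one is set, then bumps the counter. */
    void evaluate(String fakeResult) {
        // A real listener would run its IEvaluation instances here.
        if (callback != null) {
            callback.call(invocationCount.get(), fakeResult);
        }
        invocationCount.incrementAndGet();
    }

    long invocations() {
        return invocationCount.get();
    }
}

class CallbackModelDemo {
    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        CallbackModel m = new CallbackModel();
        m.setCallback((n, r) -> log.add(n + ":" + r));
        m.evaluate("acc=0.90");
        m.evaluate("acc=0.92");
        System.out.println(log); // [0:acc=0.90, 1:acc=0.92]
    }
}
```

Because the field is `transient`, a real callback would presumably not survive serialization of the listener and would need to be set again afterwards.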

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency)`
Parameters: `iterator` - the data used for evaluation; `frequency` - the evaluation interval, in iterations.

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, @NonNull InvocationType type)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency)`
Parameters: `iterator` - the data used for evaluation; `frequency` - the evaluation interval, in iterations.

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, IEvaluation... evaluations)`
Parameters: `iterator` - the data used for evaluation; `frequency` - the evaluation interval, in iterations.

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int frequency, @NonNull InvocationType type, IEvaluation... evaluations)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, IEvaluation... evaluations)`
Parameters: `iterator` - the data used for evaluation; `frequency` - the evaluation interval, in iterations.

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type, IEvaluation... evaluations)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.DataSet dataSet, int frequency, @NonNull InvocationType type)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.DataSet dataSet, int frequency, @NonNull InvocationType type, IEvaluation... evaluations)`

`public EvaluativeListener(@NonNull org.nd4j.linalg.dataset.MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type, IEvaluation... evaluations)`
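
Most constructors take an `InvocationType`. A plausible reading, stated here as an assumption, is that `iterationDone`, `onEpochStart` and `onEpochEnd` each trigger an evaluation only when they match the configured type. The sketch models that dispatch with a hypothetical `Trigger` enum whose constants are assumed to mirror `InvocationType`; it deliberately omits the `frequency` gate and the actual evaluation work, and `DispatchModel` is not a DL4J class.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of InvocationType; the constant names are an assumption.
enum Trigger { EPOCH_START, EPOCH_END, ITERATION_END }

// Hypothetical model of hook dispatch; NOT the DL4J implementation.
final class DispatchModel {
    private final Trigger trigger;
    final List<String> evaluatedAt = new ArrayList<>();

    DispatchModel(Trigger trigger) {
        this.trigger = trigger;
    }

    // Each training hook forwards to evaluate() only if it matches the configured trigger.
    void iterationDone(int iteration, int epoch) {
        if (trigger == Trigger.ITERATION_END) evaluate("iter " + iteration);
    }

    void onEpochStart(int epoch) {
        if (trigger == Trigger.EPOCH_START) evaluate("start " + epoch);
    }

    void onEpochEnd(int epoch) {
        if (trigger == Trigger.EPOCH_END) evaluate("end " + epoch);
    }

    private void evaluate(String when) {
        evaluatedAt.add(when);
    }
}

class DispatchModelDemo {
    public static void main(String[] args) {
        DispatchModel m = new DispatchModel(Trigger.EPOCH_END);
        int iteration = 0;
        for (int epoch = 0; epoch < 2; epoch++) {       // simulated training loop
            m.onEpochStart(epoch);
            for (int batch = 0; batch < 3; batch++) {
                m.iterationDone(iteration++, epoch);
            }
            m.onEpochEnd(epoch);
        }
        System.out.println(m.evaluatedAt); // [end 0, end 1]
    }
}
```

Against the real library, a call such as `net.setListeners(new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END))` would presumably request one evaluation at the end of every epoch, assuming `net` is a `MultiLayerNetwork`, `testIter` is a `DataSetIterator`, and `InvocationType` has an `EPOCH_END` constant.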

`public void iterationDone(Model model, int iteration, int epoch)`
Specified by: `iterationDone` in interface `TrainingListener`.
Overrides: `iterationDone` in class `BaseTrainingListener`.
Parameters: `model` - the model iterating; `iteration` - the iteration.

`public void onEpochStart(Model model)`
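Description copied from interface: `TrainingListener`. Called once at the start of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`.
Specified by: `onEpochStart` in interface `TrainingListener`.
Overrides: `onEpochStart` in class `BaseTrainingListener`.

`public void onEpochEnd(Model model)`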
Description copied from interface: `TrainingListener`. Called once at the end of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`.
Specified by: `onEpochEnd` in interface `TrainingListener`.
Overrides: `onEpochEnd` in class `BaseTrainingListener`.

`public void onForwardPass(Model model, List<org.nd4j.linalg.api.ndarray.INDArray> activations)`
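Description copied from interface: `TrainingListener`. Called once per iteration (forward pass) for activations (usually for a `MultiLayerNetwork`), only at training time.
Specified by: `onForwardPass` in interface `TrainingListener`.
Overrides: `onForwardPass` in class `BaseTrainingListener`.
Parameters: `model` - Model; `activations` - layer activations (including input).

`public void onForwardPass(Model model, Map<String,org.nd4j.linalg.api.ndarray.INDArray> activations)`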
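Description copied from interface: `TrainingListener`. Called once per iteration (forward pass) for activations (usually for a `ComputationGraph`), only at training time.
Specified by: `onForwardPass` in interface `TrainingListener`.
Overrides: `onForwardPass` in class `BaseTrainingListener`.
Parameters: `model` - Model; `activations` - layer activations (including input).

`public void onGradientCalculation(Model model)`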
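Description copied from interface: `TrainingListener`. Called once per iteration (backward pass) before the gradients are updated. Gradients are available via `Model.gradient()`. Note that gradients will likely be updated in place, so they should be copied or processed synchronously in this method. For updates (gradients after learning rate, momentum, RMSProp, etc. have been applied), see `TrainingListener.onBackwardPass(Model)`.
Specified by: `onGradientCalculation` in interface `TrainingListener`.
Overrides: `onGradientCalculation` in class `BaseTrainingListener`.
Parameters: `model` - Model.

`public void onBackwardPass(Model model)`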
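Description copied from interface: `TrainingListener`. Called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via `Model.gradient()`. Unlike `TrainingListener.onGradientCalculation(Model)`, the gradients at this point are post-update rather than the raw (pre-update) gradients seen in that method.
Specified by: `onBackwardPass` in interface `TrainingListener`.
Overrides: `onBackwardPass` in class `BaseTrainingListener`.
Parameters: `model` - Model.

`protected void invokeListener(Model model)`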

`protected void evalAtIndex(IEvaluation evaluation, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] predictions, int index)`
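
`evalAtIndex` takes arrays of labels and predictions plus an `index`, which suggests that, for a model with several outputs, an `IEvaluation` is fed the labels and predictions of one output only. The following sketch illustrates that selection with plain `int[]` class labels instead of one-hot `INDArray`s; `MultiOutputEval` and its `Accuracy` accumulator are hypothetical stand-ins, not DL4J classes, and the real method's behavior is an assumption here.

```java
// Hypothetical stand-in for multi-output evaluation; plain arrays replace INDArray.
final class MultiOutputEval {
    /** Simple accuracy accumulator standing in for an IEvaluation. */
    static final class Accuracy {
        int correct;
        int total;

        void eval(int[] labels, int[] predictions) {
            if (labels.length != predictions.length) {
                throw new IllegalArgumentException("length mismatch");
            }
            for (int i = 0; i < labels.length; i++) {
                if (labels[i] == predictions[i]) correct++;
                total++;
            }
        }

        double value() {
            return total == 0 ? Double.NaN : (double) correct / total;
        }
    }

    /** Feeds only output number `index` of a multi-output model to the evaluation. */
    static void evalAtIndex(Accuracy evaluation, int[][] labels, int[][] predictions, int index) {
        evaluation.eval(labels[index], predictions[index]);
    }

    public static void main(String[] args) {
        int[][] labels      = { {0, 1, 1, 0}, {2, 2} }; // output 0, output 1
        int[][] predictions = { {0, 1, 0, 0}, {2, 1} };
        Accuracy acc0 = new Accuracy();
        evalAtIndex(acc0, labels, predictions, 0);
        System.out.println(acc0.value()); // 0.75
    }
}
```

Keeping one evaluation object per output avoids mixing metrics from heads that may have different label spaces.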
Copyright © 2018. All rights reserved.