public abstract class BaseScoreCalculator<T extends Model> extends Object implements ScoreCalculator<T>
| Modifier and Type | Field and Description |
|---|---|
| `protected int` | `exampleCount` |
| `protected org.nd4j.linalg.dataset.api.iterator.DataSetIterator` | `iterator` |
| `protected org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator` | `mdsIterator` |
| `protected int` | `minibatchCount` |
| `protected double` | `scoreSum` |
| Modifier | Constructor and Description |
|---|---|
| `protected` | `BaseScoreCalculator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator)` |
| `protected` | `BaseScoreCalculator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator)` |
| Modifier and Type | Method and Description |
|---|---|
| `static org.nd4j.linalg.api.ndarray.INDArray[]` | `arr(org.nd4j.linalg.api.ndarray.INDArray in)` |
| `double` | `calculateScore(T network)`<br>Calculate the score for the given MultiLayerNetwork |
| `protected abstract double` | `finalScore(double scoreSum, int minibatchCount, int exampleCount)` |
| `static org.nd4j.linalg.api.ndarray.INDArray` | `get0(org.nd4j.linalg.api.ndarray.INDArray[] in)` |
| `protected abstract org.nd4j.linalg.api.ndarray.INDArray[]` | `output(T network, org.nd4j.linalg.api.ndarray.INDArray[] input, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask)` |
| `protected abstract org.nd4j.linalg.api.ndarray.INDArray` | `output(T network, org.nd4j.linalg.api.ndarray.INDArray input, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask)` |
| `protected abstract void` | `reset()` |
| `protected abstract double` | `scoreMinibatch(T network, org.nd4j.linalg.api.ndarray.INDArray[] features, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask, org.nd4j.linalg.api.ndarray.INDArray[] output)` |
| `protected double` | `scoreMinibatch(T network, org.nd4j.linalg.api.ndarray.INDArray features, org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask, org.nd4j.linalg.api.ndarray.INDArray output)` |
protected org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator mdsIterator
protected org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator
protected double scoreSum
protected int minibatchCount
protected int exampleCount
protected BaseScoreCalculator(@NonNull
org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator)
protected BaseScoreCalculator(@NonNull
org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iterator)
public double calculateScore(T network)
Specified by: calculateScore in interface ScoreCalculator<T extends Model>
protected abstract void reset()
protected abstract org.nd4j.linalg.api.ndarray.INDArray output(T network, org.nd4j.linalg.api.ndarray.INDArray input, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask)
protected abstract org.nd4j.linalg.api.ndarray.INDArray[] output(T network, org.nd4j.linalg.api.ndarray.INDArray[] input, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask)
protected double scoreMinibatch(T network, org.nd4j.linalg.api.ndarray.INDArray features, org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask, org.nd4j.linalg.api.ndarray.INDArray output)
protected abstract double scoreMinibatch(T network, org.nd4j.linalg.api.ndarray.INDArray[] features, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask, org.nd4j.linalg.api.ndarray.INDArray[] output)
protected abstract double finalScore(double scoreSum,
int minibatchCount,
int exampleCount)
public static org.nd4j.linalg.api.ndarray.INDArray[] arr(org.nd4j.linalg.api.ndarray.INDArray in)
public static org.nd4j.linalg.api.ndarray.INDArray get0(org.nd4j.linalg.api.ndarray.INDArray[] in)
Copyright © 2018. All rights reserved.