public class VAEReconProbScoreCalculator extends BaseScoreCalculator<Model>
See `VariationalAutoencoder.reconstructionProbability(INDArray, int)` for more details.

**Field Summary**

| Modifier and Type | Field and Description |
|---|---|
| `protected boolean` | `average` |
| `protected boolean` | `logProb` |
| `protected int` | `reconstructionProbNumSamples` |
Fields inherited from class `BaseScoreCalculator<Model>`: `exampleCount`, `iterator`, `mdsIterator`, `minibatchCount`, `scoreSum`

**Constructor Summary**

| Constructor and Description |
|---|
| `VAEReconProbScoreCalculator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb)`<br>Constructor for average reconstruction probability. |
| `VAEReconProbScoreCalculator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb, boolean average)`<br>Constructor for reconstruction probability. |
**Method Summary**

| Modifier and Type | Method and Description |
|---|---|
| `protected double` | `finalScore(double scoreSum, int minibatchCount, int exampleCount)` |
| `protected org.nd4j.linalg.api.ndarray.INDArray[]` | `output(Model network, org.nd4j.linalg.api.ndarray.INDArray[] input, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask)` |
| `protected org.nd4j.linalg.api.ndarray.INDArray` | `output(Model network, org.nd4j.linalg.api.ndarray.INDArray input, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask)` |
| `protected void` | `reset()` |
| `protected double` | `scoreMinibatch(Model network, org.nd4j.linalg.api.ndarray.INDArray[] features, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask, org.nd4j.linalg.api.ndarray.INDArray[] output)` |
| `protected double` | `scoreMinibatch(Model net, org.nd4j.linalg.api.ndarray.INDArray features, org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask, org.nd4j.linalg.api.ndarray.INDArray output)` |
Methods inherited from class `BaseScoreCalculator<Model>`: `arr`, `calculateScore`, `get0`

**Field Detail**

`protected final int reconstructionProbNumSamples`

`protected final boolean logProb`

`protected final boolean average`
**Constructor Detail**

`public VAEReconProbScoreCalculator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb)`

Parameters:
- `iterator` - Iterator providing the data to score.
- `reconstructionProbNumSamples` - Number of samples. See `VariationalAutoencoder.reconstructionProbability(INDArray, int)` for details.
- `logProb` - If true: calculate the (negative) log probability. False: calculate the probability.

`public VAEReconProbScoreCalculator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb, boolean average)`

Parameters:
- `iterator` - Iterator providing the data to score.
- `reconstructionProbNumSamples` - Number of samples. See `VariationalAutoencoder.reconstructionProbability(INDArray, int)` for details.
- `logProb` - If true: calculate the (negative) log probability. False: calculate the probability.
- `average` - If true: return the average (log) probability. False: return the sum of the (log) probabilities.

**Method Detail**

`protected void reset()`
Specified by: `reset` in class `BaseScoreCalculator<Model>`

`protected org.nd4j.linalg.api.ndarray.INDArray output(Model network, org.nd4j.linalg.api.ndarray.INDArray input, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask)`

Specified by: `output` in class `BaseScoreCalculator<Model>`

`protected org.nd4j.linalg.api.ndarray.INDArray[] output(Model network, org.nd4j.linalg.api.ndarray.INDArray[] input, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask)`

Specified by: `output` in class `BaseScoreCalculator<Model>`

`protected double scoreMinibatch(Model net, org.nd4j.linalg.api.ndarray.INDArray features, org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray fMask, org.nd4j.linalg.api.ndarray.INDArray lMask, org.nd4j.linalg.api.ndarray.INDArray output)`

Specified by: `scoreMinibatch` in class `BaseScoreCalculator<Model>`

`protected double scoreMinibatch(Model network, org.nd4j.linalg.api.ndarray.INDArray[] features, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] fMask, org.nd4j.linalg.api.ndarray.INDArray[] lMask, org.nd4j.linalg.api.ndarray.INDArray[] output)`

Specified by: `scoreMinibatch` in class `BaseScoreCalculator<Model>`

`protected double finalScore(double scoreSum, int minibatchCount, int exampleCount)`

Specified by: `finalScore` in class `BaseScoreCalculator<Model>`

Copyright © 2018. All rights reserved.