public class MultinomialVertex extends IntegerVertex implements ProbabilisticInteger, SamplableWithManyScalars<IntegerTensor>
The most common use case is a single scalar value for n (the number of trials) and a vector for p (the category probabilities). For example, with n = 5 of shape () and p = [0.2, 0.2, 0.6] of shape (3), a sample could return x = [1, 3, 1] with shape (3). Here logProb([1, 3, 1]) is valid, and logProb([[1, 3, 1], [2, 2, 1]]) is a batch logProb equivalent to logProb([1, 3, 1]) + logProb([2, 2, 1]).
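As a rough illustration of what logProb computes in the vector case, the sketch below evaluates the standard multinomial log-probability in plain Java, with no Keanu dependency. MultinomialLogProbSketch and its methods are hypothetical names for this example, not part of Keanu. The sketch handles only a scalar n and a vector p, and it shows that a batch logProb is the sum of the per-row values.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of the multinomial log-probability
 * for a scalar n and a vector p. Not Keanu's implementation.
 */
class MultinomialLogProbSketch {

    /** Natural log of m!, computed as a sum of logs (adequate for small counts). */
    static double logFactorial(int m) {
        double sum = 0.0;
        for (int i = 2; i <= m; i++) {
            sum += Math.log(i);
        }
        return sum;
    }

    /** log P(x | n, p) for one outcome x; -Infinity if x is not a valid outcome for n. */
    static double logProb(int n, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must both have length k");
        }
        if (Arrays.stream(x).anyMatch(c -> c < 0) || Arrays.stream(x).sum() != n) {
            return Double.NEGATIVE_INFINITY; // counts must be non-negative and sum to n
        }
        double result = logFactorial(n);
        for (int i = 0; i < x.length; i++) {
            result -= logFactorial(x[i]);
            if (x[i] > 0) {
                result += x[i] * Math.log(p[i]); // skip 0 * log(0) for unused categories
            }
        }
        return result;
    }

    /** Batch logProb: each row of xs is an independent outcome, so log-probs add. */
    static double batchLogProb(int n, double[] p, int[][] xs) {
        double total = 0.0;
        for (int[] x : xs) {
            total += logProb(n, p, x);
        }
        return total;
    }

    public static void main(String[] args) {
        double[] p = {0.2, 0.2, 0.6};
        System.out.println(logProb(5, p, new int[]{1, 3, 1}));
        System.out.println(logProb(5, p, new int[]{2, 2, 1}));
        System.out.println(batchLogProb(5, p, new int[][]{{1, 3, 1}, {2, 2, 1}}));
    }
}
```

For [1, 3, 1] the probability works out to 5!/(1!·3!·1!) · 0.2 · 0.2³ · 0.6 = 20 · 0.00096 = 0.0192, so the first printed value is ln(0.0192).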
More complex cases are also acceptable and use broadcasting semantics.
If the number of categories is k, then p has shape (a...b, k), where a...b is any shape of any rank. When p is a vector, a...b is empty (rank 0) and the shape of p is simply (k). Given p of shape (a...b, k), n can have any shape that is broadcastable with a...b. The result shape is n's shape broadcast with a...b, followed by k. For example, with n = [[1, 2], [3, 4]] of shape (2, 2) and p = [[0.2, 0.2, 0.6], [0.5, 0.25, 0.25]] of shape (2, 3), k = 3 and the result shape is (2, 2, 3): (2, 2) broadcast with (2), with k appended.
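The shape rule can be sketched as a small helper. It assumes NumPy-style broadcasting (shapes aligned from the trailing dimension, size-1 dimensions stretched), which is an assumption about what "broadcasting semantics" means here; MultinomialShapeSketch is a hypothetical class, not part of Keanu.

```java
import java.util.Arrays;

/**
 * Hypothetical helper computing the result shape described above:
 * broadcast(shape of n, batch shape a...b of p), then k appended.
 * Assumes NumPy-style broadcasting rules.
 */
class MultinomialShapeSketch {

    static long[] broadcast(long[] left, long[] right) {
        int rank = Math.max(left.length, right.length);
        long[] result = new long[rank];
        for (int i = 0; i < rank; i++) {
            // Walk both shapes from their last dimension backwards; missing dims count as 1.
            int li = left.length - 1 - i;
            int ri = right.length - 1 - i;
            long l = li >= 0 ? left[li] : 1;
            long r = ri >= 0 ? right[ri] : 1;
            if (l != r && l != 1 && r != 1) {
                throw new IllegalArgumentException("Shapes " + Arrays.toString(left)
                        + " and " + Arrays.toString(right) + " are not broadcastable");
            }
            result[rank - 1 - i] = (l == 1) ? r : l; // a size-1 dim stretches to the other
        }
        return result;
    }

    static long[] resultShape(long[] nShape, long[] pShape) {
        if (pShape.length == 0) {
            throw new IllegalArgumentException("p needs at least rank 1 (the k dimension)");
        }
        long k = pShape[pShape.length - 1];
        long[] batchShape = Arrays.copyOf(pShape, pShape.length - 1); // the a...b part
        long[] broadcasted = broadcast(nShape, batchShape);
        long[] result = Arrays.copyOf(broadcasted, broadcasted.length + 1);
        result[broadcasted.length] = k; // append k as the last dimension
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(resultShape(new long[]{}, new long[]{3})));        // [3]
        System.out.println(Arrays.toString(resultShape(new long[]{2, 2}, new long[]{2, 3}))); // [2, 2, 3]
    }
}
```

Splitting p's shape into the batch part a...b and the trailing k before broadcasting is what lets n omit the category dimension entirely.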
| Constructor and Description |
|---|
| MultinomialVertex(int n, DoubleTensor p) |
| MultinomialVertex(int n, DoubleVertex p) |
| MultinomialVertex(IntegerTensor n, DoubleTensor p) |
| MultinomialVertex(IntegerTensor n, DoubleVertex p) |
| MultinomialVertex(IntegerVertex n, DoubleVertex p) |
| MultinomialVertex(long[] tensorShape, IntegerVertex n, DoubleVertex p) |

| Modifier and Type | Method and Description |
|---|---|
| java.util.Map<Vertex,DoubleTensor> | dLogProb(IntegerTensor value, java.util.Set<? extends Vertex> withRespectTo): The partial derivatives of the natural log prob. |
| IntegerVertex | getN() |
| DoubleVertex | getP() |
| double | logProb(IntegerTensor xTensor): This is the natural log of the probability at the supplied value. |
| IntegerTensor | sampleWithShape(long[] shape, KeanuRandom random) |
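For reference, the standard multinomial log-probability for a single outcome x = (x_1, ..., x_k) with x_1 + ... + x_k = n is given below, together with its partial derivative with respect to each p_i when the p_i are treated as independent variables. Whether dLogProb returns exactly this unconstrained form, rather than one that accounts for p summing to 1, is an assumption not stated on this page.

$$
\log P(x \mid n, p) = \log n! - \sum_{i=1}^{k} \log x_i! + \sum_{i=1}^{k} x_i \log p_i
$$

$$
\frac{\partial \log P(x \mid n, p)}{\partial p_i} = \frac{x_i}{p_i}
$$

For n = 5, p = [0.2, 0.2, 0.6] and x = [1, 3, 1], these partial derivatives are [5, 15, 5/3]. Because n is an integer, no derivative with respect to n is defined.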

Methods inherited from class IntegerVertex: abs, concat, div, div, divideBy, divideBy, divideBy, equalTo, getValue, greaterThan, greaterThanOrEqualTo, lambda, lambda, lessThan, lessThanOrEqualTo, loadValue, max, min, minus, minus, minus, multiply, multiply, multiply, notEqualTo, observe, observe, plus, plus, plus, pow, pow, reshape, reverseDiv, reverseMinus, saveValue, setAndCascade, setAndCascade, setValue, setValue, slice, sum, sum, take, times, times, toDouble, unaryMinus

Methods inherited from class Vertex: addChild, addParent, addParents, equals, eval, getChildren, getConnectedGraph, getDegree, getId, getIndentation, getLabel, getObservedValue, getParents, getRank, getReference, getShape, getState, getValue, hashCode, hasValue, isDifferentiable, isObserved, isProbabilistic, lazyEval, observe, observeOwnValue, print, print, removeLabel, save, setAndCascade, setLabel, setLabel, setParents, setParents, setState, setValue, toString, unobserve

Methods inherited from class java.lang.Object: clone, finalize, getClass, notify, notifyAll, wait, wait, wait

Methods inherited from interface ProbabilisticInteger: dLogPmf, dLogPmf, dLogPmf, logPmf, logPmf, logPmf

Methods inherited from interface Probabilistic: dLogProb, dLogProbAtValue, dLogProbAtValue, getValue, keepOnlyProbabilisticVertices, logProbAtValue

Other inherited methods: getObservedValue, isObserved, observe, unobserve

Methods inherited from interface SamplableWithManyScalars: sample, sampleManyScalars, sampleManyScalars

Methods inherited from interface SamplableWithShape: sampleWithShape

public MultinomialVertex(long[] tensorShape, IntegerVertex n, DoubleVertex p)
public MultinomialVertex(IntegerVertex n, DoubleVertex p)
public MultinomialVertex(int n, DoubleVertex p)
public MultinomialVertex(int n, DoubleTensor p)
public MultinomialVertex(IntegerTensor n, DoubleVertex p)
public MultinomialVertex(IntegerTensor n, DoubleTensor p)
public double logProb(IntegerTensor xTensor)
Description copied from interface: Probabilistic
This is the natural log of the probability at the supplied value.
Specified by: logProb in interface Probabilistic<IntegerTensor>
Parameters: xTensor - The supplied value.

public java.util.Map<Vertex,DoubleTensor> dLogProb(IntegerTensor value, java.util.Set<? extends Vertex> withRespectTo)
Description copied from interface: Probabilistic
The partial derivatives of the natural log prob.
Specified by: dLogProb in interface Probabilistic<IntegerTensor>
Parameters: value - at a given value; withRespectTo - list of parents to differentiate with respect to

public IntegerTensor sampleWithShape(long[] shape, KeanuRandom random)
Specified by: sampleWithShape in interface SamplableWithShape<IntegerTensor>

public DoubleVertex getP()
public IntegerVertex getN()
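One textbook way to draw a single multinomial sample is to run n independent categorical trials and count how many land in each category. The sketch below does this in plain Java with java.util.Random. It is a hypothetical illustration (MultinomialSampleSketch is not a Keanu class), and Keanu's sampleWithShape, which takes an output shape and a KeanuRandom, may use a different algorithm.

```java
import java.util.Random;

/**
 * Hypothetical multinomial sampler for a scalar n and a vector p:
 * n independent categorical draws, counted per category.
 */
class MultinomialSampleSketch {

    static int[] sample(int n, double[] p, Random random) {
        int[] counts = new int[p.length];
        for (int trial = 0; trial < n; trial++) {
            double u = random.nextDouble(); // uniform in [0, 1)
            double cumulative = 0.0;
            int category = p.length - 1;    // fallback in case rounding leaves the total just below 1
            for (int i = 0; i < p.length; i++) {
                cumulative += p[i];
                if (u < cumulative) {
                    category = i;
                    break;
                }
            }
            counts[category]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] x = sample(5, new double[]{0.2, 0.2, 0.6}, new Random(42));
        System.out.println(java.util.Arrays.toString(x)); // counts always sum to 5
    }
}
```

This costs O(n·k) per sample, which is fine for small n; for large n, drawing each count from a conditional binomial distribution is the usual faster alternative.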