
tf.compat.v1.distributions.DirichletMultinomial

Dirichlet-Multinomial compound distribution.

Inherits From: Distribution

The Dirichlet-Multinomial distribution is parameterized by a (batch of) length-K concentration vectors (K > 1) and a total_count number of trials, i.e., the number of trials per draw from the DirichletMultinomial. It is defined over a (batch of) length-K count vectors, counts, such that tf.reduce_sum(counts, -1) = total_count. When K = 2, the Dirichlet-Multinomial is identical to the Beta-Binomial distribution.
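For an integer total_count N and K classes, the support is finite and contains C(N + K - 1, K - 1) count vectors. The minimal sketch below enumerates it with the stars-and-bars construction. It assumes an integer total_count, and the helper name `support` is hypothetical, not part of TensorFlow.

```python
import itertools
import math


def support(total_count, k):
    """Yield every length-k tuple of non-negative ints that sums to total_count.

    Hypothetical helper: places k - 1 "dividers" among total_count + k - 1 slots;
    the gaps between consecutive dividers are the class counts.
    """
    slots = total_count + k - 1
    for dividers in itertools.combinations(range(slots), k - 1):
        bounds = (-1,) + dividers + (slots,)
        yield tuple(bounds[i + 1] - bounds[i] - 1 for i in range(k))


# Example: N = 2 trials over K = 3 classes.
outcomes = list(support(2, 3))
print(outcomes)
print(len(outcomes) == math.comb(2 + 3 - 1, 3 - 1))
```

With K = 2 the support is just the N + 1 vectors [k, N - k], which is why the distribution reduces to the Beta-Binomial over k.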

Mathematical Details

The Dirichlet-Multinomial is a distribution over K-class counts, i.e., a length-K vector of non-negative integer counts = n = [n_0, ..., n_{K-1}].

The probability mass function (pmf) is,

pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
Z = Beta(alpha) / N!

where:

  • concentration = alpha = [alpha_0, ..., alpha_{K-1}], alpha_j > 0,
  • total_count = N, N a positive integer,
  • N! is N factorial, and,
  • Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j) is the multivariate beta function, and,
  • Gamma is the gamma function.
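The pmf above can be evaluated directly in log space with the standard library's `math.lgamma`, which keeps the factorials and gamma functions from overflowing. This is a minimal sketch, not TensorFlow's implementation, and the function names are hypothetical. The K = 2 comparison at the end illustrates the Beta-Binomial identity stated in the introduction.

```python
import math


def log_multivariate_beta(x):
    # log Beta(x) = sum_j log Gamma(x_j) - log Gamma(sum_j x_j)
    return sum(math.lgamma(v) for v in x) - math.lgamma(sum(x))


def dirichlet_multinomial_pmf(counts, total_count, concentration):
    """pmf(n; alpha, N) = N! / prod_j n_j! * Beta(alpha + n) / Beta(alpha)."""
    if sum(counts) != total_count:
        raise ValueError("counts must sum to total_count")
    # log of the multinomial coefficient N! / prod_j n_j!
    log_coeff = math.lgamma(total_count + 1) - sum(math.lgamma(n + 1) for n in counts)
    # log of Beta(alpha + n) / Beta(alpha), i.e. the log of 1 / Z times Beta(alpha + n) / N!
    log_ratio = (log_multivariate_beta([a + n for a, n in zip(concentration, counts)])
                 - log_multivariate_beta(concentration))
    return math.exp(log_coeff + log_ratio)


def beta_binomial_pmf(k, n, a, b):
    # Beta-Binomial pmf: C(n, k) * B(k + a, n - k + b) / B(a, b)
    def log_b(x, y):
        return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)
    return math.comb(n, k) * math.exp(log_b(k + a, n - k + b) - log_b(a, b))


alpha = [1.0, 2.0, 3.0]
print(dirichlet_multinomial_pmf([0, 0, 2], 2, alpha))  # about 0.2857 (exactly 2/7)

# With K = 2 the two pmfs agree for every k.
for k in range(6):
    print(k, dirichlet_multinomial_pmf([k, 5 - k], 5, [1.5, 2.5]), beta_binomial_pmf(k, 5, 1.5, 2.5))
```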

Dirichlet-Multinomial is a compound distribution, i.e., its samples are generated as follows.

  1. Choose class probabilities: probs = [p_0,...,p_{K-1}] ~ Dir(concentration)
  2. Draw integers: counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)
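The two-step compound construction can be sketched with Python's `random` module. Step 1 uses the standard normalized-Gamma construction of a Dirichlet draw. That choice is an assumption about how one might implement it, not a description of TensorFlow's sampler, and the function name is hypothetical.

```python
import random


def sample_dirichlet_multinomial(total_count, concentration, rng=random):
    """Hypothetical sampler following the two steps above (integer total_count)."""
    # Step 1: probs ~ Dir(concentration), via independent Gamma(alpha_j, 1) draws
    # normalized to sum to one.
    gammas = [rng.gammavariate(a, 1.0) for a in concentration]
    total = sum(gammas)
    probs = [g / total for g in gammas]
    # Step 2: counts ~ Multinomial(total_count, probs), realized as total_count
    # independent categorical draws that are then tallied per class.
    counts = [0] * len(concentration)
    for j in rng.choices(range(len(concentration)), weights=probs, k=total_count):
        counts[j] += 1
    return counts


rng = random.Random(0)
print(sample_dirichlet_multinomial(4, [1.0, 2.0, 3.0], rng))  # some length-3 list summing to 4
```

Because a fresh probs vector is drawn for every sample, the counts are more dispersed than those of a Multinomial with fixed probabilities, which is the point of the compound construction.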

The last concentration dimension parametrizes a single Dirichlet-Multinomial distribution. When calling distribution functions (e.g., dist.prob(counts)), concentration, total_count and counts are broadcast to the same shape. The last dimension of counts corresponds to a single Dirichlet-Multinomial distribution.

Distribution parameters are automatically broadcast in all functions; see examples for details.
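Broadcasting determines the shape returned by dist.prob(counts). The trailing K axis of concentration and counts is consumed, and the remaining leading axes broadcast together with the shape of total_count. The sketch below assumes NumPy-style broadcasting semantics, which TensorFlow follows. For simplicity it requires the K axes to match exactly, and both helper names are hypothetical. It reproduces the shapes stated in the examples that follow.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples: right-align, size-1 axes stretch."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-ndim, 0):
        # Sizes present on this axis, ignoring missing axes and size-1 axes.
        sizes = {s[axis] for s in shapes if len(s) >= -axis and s[axis] != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def prob_output_shape(concentration_shape, total_count_shape, counts_shape):
    # The last axis (K) is the event axis and does not appear in the output.
    if concentration_shape[-1] != counts_shape[-1]:
        raise ValueError("concentration and counts must agree on K")
    return broadcast_shapes(tuple(concentration_shape[:-1]),
                            tuple(total_count_shape),
                            tuple(counts_shape[:-1]))


print(prob_output_shape((3,), (), (2, 3)))     # (2,)
print(prob_output_shape((3,), (), (5, 7, 3)))  # (5, 7)
print(prob_output_shape((2, 3), (2,), (3,)))   # (2,)
```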

Pitfalls

The number of classes, K, must not exceed:

  • the largest integer representable by self.dtype, i.e., 2**(mantissa_bits+1) (IEEE 754),
  • the maximum Tensor index, i.e., 2**31-1.

In other words,

K <= min(2**31-1, {
  tf.float16: 2**11,
  tf.float32: 2**24,
  tf.float64: 2**53 }[param.dtype])
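These limits follow from the significand width. An IEEE 754 format with m explicit mantissa bits represents every integer up to 2**(m+1) exactly, and the next integer cannot be stored. The sketch below checks this with the standard `struct` module, whose format codes `e`, `f` and `d` are half, single and double precision. It then computes the resulting bound on K. `max_classes` is a hypothetical helper, not a TensorFlow API.

```python
import struct

# struct format code and explicit mantissa bits for each IEEE 754 format.
FORMATS = {"float16": ("e", 10), "float32": ("f", 23), "float64": ("d", 52)}


def round_trip(value, fmt):
    """Store value in the given IEEE 754 format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, float(value)))[0]


def max_classes(dtype_name):
    # K <= min(max Tensor index, largest exactly representable integer run)
    _, mantissa_bits = FORMATS[dtype_name]
    return min(2**31 - 1, 2**(mantissa_bits + 1))


for name, (fmt, bits) in FORMATS.items():
    limit = 2 ** (bits + 1)
    # The limit itself survives the round trip; limit + 1 does not.
    print(name, round_trip(limit, fmt) == limit, round_trip(limit + 1, fmt) == limit + 1,
          max_classes(name))
```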

Examples

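import tensorflow as tf
DirichletMultinomial = tf.compat.v1.distributions.DirichletMultinomial

alpha = [1., 2., 3.]
n = 2.
dist = DirichletMultinomial(n, alpha)  # (total_count, concentration)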

Creates a 3-class distribution in which the 3rd class is the most likely to be drawn. The distribution functions can be evaluated on counts.
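"Most likely to be drawn" is best read per class. The expected count E[n_j] = N * alpha_j / sum(alpha) is largest for the 3rd class. The most probable single count vector is not unique: with alpha = [1, 2, 3] and N = 2, both [0, 0, 2] and [0, 1, 1] have probability 2/7. The exact-arithmetic sketch below checks both statements. It assumes positive integer concentrations, so that Gamma(x) = (x - 1)!, and its helper names are hypothetical.

```python
import math
from fractions import Fraction


def dirichlet_multinomial_mean(total_count, concentration):
    # E[n_j] = N * alpha_j / sum_k alpha_k, computed exactly with Fractions.
    alpha = [Fraction(a) for a in concentration]
    alpha_sum = sum(alpha)
    return [Fraction(total_count) * a / alpha_sum for a in alpha]


def exact_pmf_integer_alpha(counts, concentration):
    """N! / prod n_j! * Beta(alpha + n) / Beta(alpha); integer alpha only."""
    def beta(xs):
        # Beta(x) = prod_j (x_j - 1)! / (sum_j x_j - 1)! for positive integers.
        return Fraction(math.prod(math.factorial(x - 1) for x in xs),
                        math.factorial(sum(xs) - 1))
    n = sum(counts)
    coeff = Fraction(math.factorial(n), math.prod(math.factorial(c) for c in counts))
    return coeff * beta([a + c for a, c in zip(concentration, counts)]) / beta(concentration)


print(dirichlet_multinomial_mean(2, [1, 2, 3]))        # 1/3, 2/3 and 1 (as Fractions)
print(exact_pmf_integer_alpha([0, 0, 2], [1, 2, 3]))   # 2/7
print(exact_pmf_integer_alpha([0, 1, 1], [1, 2, 3]))   # 2/7
```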

# counts same shape as alpha.
counts = [0., 0., 2.]
dist.prob(counts)  # Shape []

# alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
counts = [[1., 1., 0.], [1., 0., 1.]]
dist.prob(counts)  # Shape [2]

# alpha will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]]  # Shape [5, 7, 3]
dist.prob(counts)  # Shape [5, 7]

Creates a 2-batch of 3-class distributions.

alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
n = [3., 3.]
dist = DirichletMultinomial(n, alpha)

# counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
counts = [2., 1., 0.]
dist.prob(counts)  # Shape [2]
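In the 2-batch example, the single counts vector is reused for each member of the batch, giving one probability per distribution. The pure-Python sketch below mirrors that broadcast with an explicit loop. It is not TensorFlow code, the `pmf` helper is hypothetical, and the exact values 1/28 and 5/68 follow from the pmf formula in Mathematical Details.

```python
import math


def log_beta(xs):
    # log of the multivariate beta function
    return sum(math.lgamma(x) for x in xs) - math.lgamma(sum(xs))


def pmf(counts, total_count, concentration):
    if sum(counts) != total_count:
        raise ValueError("counts must sum to total_count")
    log_coeff = math.lgamma(total_count + 1) - sum(math.lgamma(c + 1) for c in counts)
    return math.exp(log_coeff
                    + log_beta([a + c for a, c in zip(concentration, counts)])
                    - log_beta(concentration))


alpha = [[1., 2., 3.], [4., 5., 6.]]  # batch of 2 distributions, K = 3
n = [3., 3.]                          # one total_count per batch member
counts = [2., 1., 0.]                 # broadcast across the batch

probs = [pmf(counts, n_b, alpha_b) for n_b, alpha_b in zip(n, alpha)]
print(probs)  # two values, about 0.0357 and 0.0735
```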

total_count Non-negative floating point tensor, whose dtype is the same as concentration. The shape is broadcastable to [N1,..., Nm] with