
tf.compat.v1.distributions.Multinomial

Multinomial distribution.

Inherits From: Distribution

This Multinomial distribution is parameterized by probs, a (batch of) length-K probability vector (K > 1) such that tf.reduce_sum(probs, -1) = 1, and by total_count, the number of trials, i.e., the number of trials per draw from the Multinomial. It is defined over a (batch of) length-K vector counts such that tf.reduce_sum(counts, -1) = total_count. The Multinomial is identically the Binomial distribution when K = 2.
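
For concreteness, a minimal sketch of constructing the distribution from these two parameters (this assumes eager execution and that tf.compat.v1.distributions is still available in your TensorFlow build; tfp.distributions.Multinomial accepts the same arguments):

import tensorflow as tf

probs = [.2, .3, .5]            # length-K (K = 3); tf.reduce_sum(probs, -1) = 1
total_count = 4.                # number of trials per draw

dist = tf.compat.v1.distributions.Multinomial(
    total_count=total_count, probs=probs)

counts = [1., 0., 3.]           # length-K; tf.reduce_sum(counts, -1) = total_count
dist.prob(counts)               # probability of observing exactly these counts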

Mathematical Details

The Multinomial is a distribution over K-class counts, i.e., a length-K vector of non-negative integer counts = n = [n_0, ..., n_{K-1}].

The probability mass function (pmf) is,

pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
Z = (prod_j n_j!) / N!

where:

  • probs = pi = [pi_0, ..., pi_{K-1}], pi_j > 0, sum_j pi_j = 1,
  • total_count = N, N a positive integer,
  • Z is the normalization constant, and,
  • N! denotes N factorial.
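
As a quick sanity check of the formula, the pmf for a small example can be computed directly in plain Python (no TensorFlow needed); the values of pi, n, and N below are chosen only for illustration:

from math import factorial

pi = [0.2, 0.3, 0.5]   # probs
n = [1, 0, 3]          # counts; sum(n) == N
N = sum(n)             # total_count = 4

Z = 1.0
for n_j in n:
    Z *= factorial(n_j)
Z /= factorial(N)      # Z = (prod_j n_j!) / N!

pmf = 1.0
for pi_j, n_j in zip(pi, n):
    pmf *= pi_j ** n_j
pmf /= Z               # pmf(n; pi, N) = prod_j (pi_j)**n_j / Z

print(pmf)             # 0.1 (up to float rounding): (4!/(1! 0! 3!)) * 0.2 * 0.5**3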

Distribution parameters are automatically broadcast in all functions; see examples for details.

Pitfalls

The number of classes, K, must not exceed:

  • the largest integer representable by self.dtype, i.e., 2**(mantissa_bits+1) (IEEE 754),
  • the maximum Tensor index, i.e., 2**31-1.

In other words,

K <= min(2**31-1, {
  tf.float16: 2**11,
  tf.float32: 2**24,
  tf.float64: 2**53 }[param.dtype])
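
A small illustrative helper that encodes this limit (hypothetical, not part of the TensorFlow API):

import tensorflow as tf

_MAX_K_BY_DTYPE = {
    tf.float16: 2**11,  # 10 mantissa bits + 1 (IEEE 754 half precision)
    tf.float32: 2**24,  # 23 mantissa bits + 1 (single precision)
    tf.float64: 2**53,  # 52 mantissa bits + 1 (double precision)
}

def max_num_classes(dtype):
    # Largest K for a Multinomial parameterized in `dtype`, also capped by
    # the maximum Tensor index (2**31 - 1).
    return min(2**31 - 1, _MAX_K_BY_DTYPE[dtype])

print(max_num_classes(tf.float32))  # 16777216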

Examples

Create a 3-class distribution, with the 3rd class being the most likely to be drawn, using logits.

logits = [-50., -43, 0]
dist = Multinomial(total_count=4., logits=logits)

Create a 3-class distribution, with the 3rd class being the most likely to be drawn.

p = [.2, .3, .5]
dist = Multinomial(total_count=4., probs=p)

The distribution functions can be evaluated on counts.

# counts same shape as p.
counts = [1., 0, 3]
dist.prob(counts)  # Shape []

# p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
counts = [[1., 2, 1], [2, 2, 0]]
dist.prob(counts)  # Shape [2]

# p will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]]  # Shape [5, 7, 3]
dist.prob(counts)  # Shape [5, 7]

Create a 2-batch of 3-class distributions.

p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
dist = Multinomial(total_count=[4., 5], probs=p)

counts = [[2., 1, 1], [3, 1, 1]]
dist.prob(counts)  # Shape [2]

dist.sample(5)  # Shape [5, 2, 3]

Args

  • total_count: Non-negative floating point tensor with shape broadcastable to [N1,..., Nm] with m >= 0. Defines this as a batch of N1 x ... x Nm different Multinomial distributions. Its components should be equal to integer values.
  • logits: Floating point tensor representing unnormalized log-probabilities of a positive event with shape broadcastable to [N1,..., Nm, K] with m >= 0, and the same dtype as total_count. Defines this as a batch of N1 x ... x Nm different K-class Multinomial distributions. Only one of logits or probs should be passed in.
  • probs: Positive floating point tensor with shape broadcastable to [N1,..., Nm, K] with m >= 0 and the same dtype as total_count. Defines this as a batch of N1 x ... x Nm different K-class Multinomial distributions. The components of probs along the last dimension of its shape should sum to 1. Only one of logits or probs should be passed in.
  • validate_args: Python bool, default False. When True, distribution parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
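
A minimal sketch of the effect of validate_args (assuming eager execution; in graph mode the validity assertions fire at run time instead):

import tensorflow as tf

dist = tf.compat.v1.distributions.Multinomial(
    total_count=4., probs=[.2, .3, .5], validate_args=True)

dist.prob([1., 0., 3.])    # OK: counts sum to total_count
# dist.prob([1., 1., 3.])  # rejected: counts sum to 5, not total_count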