
tf.compat.v1.distributions.Dirichlet

Dirichlet distribution.

Inherits From: Distribution

The Dirichlet distribution is defined over the (k-1)-simplex using a positive, length-k vector concentration (k > 1). The Dirichlet is identically the Beta distribution when k = 2.
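
For example, a two-class Dirichlet reduces to a Beta distribution. The following minimal sketch (using tfp.distributions as in the examples below; the concentration values and evaluation point are illustrative, not from the original) shows the equivalence:

import tensorflow_probability as tfp
tfd = tfp.distributions

# With k = 2, Dirichlet([a, b]) evaluated at [x, 1 - x] has the same density
# as Beta(concentration1=a, concentration0=b) evaluated at x.
a, b = 2., 3.
dirichlet = tfd.Dirichlet([a, b])
beta = tfd.Beta(concentration1=a, concentration0=b)

x = 0.4
dirichlet.prob([x, 1. - x])  # same value as beta.prob(x)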

Mathematical Details

The Dirichlet is a distribution over the open (k-1)-simplex, i.e.,

S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.

The probability density function (pdf) is,

pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)

where:

  • x in S^{k-1}, i.e., the (k-1)-simplex,
  • concentration = alpha = [alpha_0, ..., alpha_{k-1}], alpha_j > 0,
  • Z is the normalization constant, a.k.a. the multivariate beta function, and
  • Gamma is the gamma function.
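
As a quick numerical check (a minimal sketch; the concentration and the point x below are illustrative), the log-density can be assembled directly from these pieces and compared against log_prob:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

alpha = tf.constant([1., 2., 3.])
x = tf.constant([.2, .3, .5])

# log Z = sum_j lgamma(alpha_j) - lgamma(sum_j alpha_j)
log_z = tf.reduce_sum(tf.math.lgamma(alpha)) - tf.math.lgamma(tf.reduce_sum(alpha))
# log pdf(x; alpha) = sum_j (alpha_j - 1) * log(x_j) - log Z
log_pdf = tf.reduce_sum((alpha - 1.) * tf.math.log(x)) - log_z
# log_pdf agrees with tfd.Dirichlet(alpha).log_prob(x)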

The concentration represents mean total counts of class occurrence, i.e.,

concentration = alpha = mean * total_concentration

where mean in S^{k-1} and total_concentration is a positive real number representing a mean total count.
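
Concretely (a small sketch with an illustrative concentration, reusing tf and tfd from the snippet above), the mean and total concentration can be read off from alpha:

alpha = tf.constant([1., 2., 3.])
total_concentration = tf.reduce_sum(alpha)  # 6.0
mean = alpha / total_concentration          # [1/6, 2/6, 3/6]
# mean agrees with tfd.Dirichlet(alpha).mean()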

Distribution parameters are automatically broadcast in all functions; see examples for details.

Samples of this distribution are reparameterized (pathwise differentiable). The derivatives are computed using the approach described in (Figurnov et al., 2018).

Examples

import tensorflow_probability as tfp
tfd = tfp.distributions

# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
dist = tfd.Dirichlet(alpha)

dist.sample([4, 5])  # shape: [4, 5, 3]

# x has one sample, one batch, three classes:
x = [.2, .3, .5]   # shape: [3]
dist.prob(x)       # shape: []

# x has two samples from one batch:
x = [[.1, .4, .5],
     [.2, .3, .5]]
dist.prob(x)         # shape: [2]

# alpha will be broadcast to shape [5, 7, 3] to match x.
x = [[...]]   # shape: [5, 7, 3]
dist.prob(x)  # shape: [5, 7]

# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
         [4, 5, 6]]   # shape: [2, 3]
dist = tfd.Dirichlet(alpha)

dist.sample([4, 5])  # shape: [4, 5, 2, 3]

x = [.2, .3, .5]
# x will be broadcast as [[.2, .3, .5],
#                         [.2, .3, .5]],
# thus matching the distribution's batch_shape=[2] and event_shape=[3].
dist.prob(x)         # shape: [2]

Compute the gradients of samples w.r.t. the parameters:

alpha = tf.constant([1.0, 2.0, 3.0])
dist = tfd.Dirichlet(alpha)
samples = dist.sample(5)  # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, alpha)
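
The snippet above assumes graph-mode execution, since tf.gradients is not available under eager execution. With TF 2.x-style eager execution, the same pathwise gradients can be taken with a GradientTape; the sketch below is an illustrative equivalent, not part of the original example:

alpha = tf.Variable([1., 2., 3.])
with tf.GradientTape() as tape:
  samples = tfd.Dirichlet(alpha).sample(5)          # reparameterized samples
  loss = tf.reduce_mean(tf.square(samples))         # arbitrary loss function
grads = tape.gradient(loss, alpha)  # unbiased stochastic gradients w.r.t. alpha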

References

Implicit Reparameterization Gradients: Figurnov et al., 2018 (arXiv:1805.08498)

Args

  • concentration: Positive floating-point Tensor indicating the mean number of class occurrences; aka "alpha". Implies self.dtype, self.batch_shape, and self.event_shape, i.e., if concentration.shape = [N1, N2, ..., Nm, k] then batch_shape = [N1, N2, ..., Nm] and event_shape = [k].
  • validate_args: Python bool, default False. When True, distribution parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
  • allow_nan_stats: Python bool, default True. When True, statistics (e.g., mean, mode, variance) use the value "NaN" to indicate the result is undefined. When False, an exception is raised if one or more of the statistic's batch members are undefined.
  • name: Python str name prefixed to Ops created by this class.
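
For illustration (a minimal sketch; the argument values are arbitrary, and tfd is imported as in the examples above), the constructor arguments can be used as follows:

dist = tfd.Dirichlet(
    concentration=[1., 2., 3.],
    validate_args=True,      # check parameters and inputs to prob/log_prob for validity
    allow_nan_stats=False,   # raise instead of returning NaN for undefined statistics
    name='Dirichlet')
dist.prob([.2, .3, .5])      # shape: []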