Bayesian Gaussian Mixture Model and Hamiltonian MCMC


In this colab we'll explore sampling from the posterior of a Bayesian Gaussian Mixture Model (BGMM) using only TensorFlow Probability primitives.


For \(k\in\{1,\ldots, K\}\) mixture components each of dimension \(D\), we'd like to model \(i\in\{1,\ldots,N\}\) iid samples using the following Bayesian Gaussian Mixture Model:

\[\begin{align*} \theta &\sim \text{Dirichlet}(\text{concentration}=\alpha_0)\\ \mu_k &\sim \text{Normal}(\text{loc}=\mu_{0k}, \text{scale}=I_D)\\ T_k &\sim \text{Wishart}(\text{df}=5, \text{scale}=I_D)\\ Z_i &\sim \text{Categorical}(\text{probs}=\theta)\\ Y_i &\sim \text{Normal}(\text{loc}=\mu_{z_i}, \text{scale}=T_{z_i}^{-1/2})\\ \end{align*}\]

Note that the scale arguments all have Cholesky semantics. We use this convention because it is that of TF Distributions (which itself uses this convention in part because it is computationally advantageous).

Our goal is to generate samples from the posterior:

\[p\left(\theta, \{\mu_k, T_k\}_{k=1}^K \Big| \{y_i\}_{i=1}^N, \alpha_0, \{\mu_{0k}\}_{k=1}^K\right)\]

Notice that \(\{Z_i\}_{i=1}^N\) is not present--we're interested in only those random variables which don't scale with \(N\). (And luckily there's a TF distribution which handles marginalizing out \(Z_i\).)
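To make the marginalization concrete, here is a minimal NumPy sketch (not the TFP implementation) of the per-observation mixture log-density \(\log \sum_k \theta_k \, \text{Normal}(y_i \mid \mu_k, T_k^{-1})\), which is what remains once \(Z_i\) is summed out. The helper names `mvn_log_prob` and `mixture_log_prob` and the toy parameter values are assumptions made for illustration.

```python
import numpy as np

def mvn_log_prob(y, loc, precision):
  """Log-density of N(loc, precision^{-1}) at each row of y."""
  d = loc.shape[-1]
  diff = y - loc
  # Quadratic form (y - loc)^T P (y - loc), one value per row.
  quad = np.einsum('ni,ij,nj->n', diff, precision, diff)
  _, logdet = np.linalg.slogdet(precision)
  return 0.5 * (logdet - d * np.log(2. * np.pi) - quad)

def mixture_log_prob(y, theta, locs, precisions):
  """log sum_k theta_k N(y | loc_k, P_k^{-1}), with Z summed out."""
  # Shape [K, N]: log theta_k + log N(y_i | component k).
  log_terms = np.stack([
      np.log(t) + mvn_log_prob(y, mu, p)
      for t, mu, p in zip(theta, locs, precisions)])
  # Numerically stable log-sum-exp over the component axis.
  top = log_terms.max(axis=0)
  return top + np.log(np.exp(log_terms - top).sum(axis=0))

# Toy values (hypothetical): K=2 components in D=2 dimensions.
theta = np.array([0.3, 0.7])
locs = np.array([[-2., -2.], [2., 2.]])
precisions = np.stack([np.eye(2), 2. * np.eye(2)])
y = np.array([[0., 0.], [2., 2.], [-2., -1.]])
log_p = mixture_log_prob(y, theta, locs, precisions)
```

The cost of each evaluation is proportional to \(K\) per observation, and no per-observation latent variable ever enters the sampler state.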

It is not possible to directly sample from this distribution owing to a computationally intractable normalization term.

Metropolis-Hastings algorithms are techniques for sampling from intractable-to-normalize distributions.

TensorFlow Probability offers a number of MCMC options, including several based on Metropolis-Hastings. In this notebook, we'll use Hamiltonian Monte Carlo (tfp.mcmc.HamiltonianMonteCarlo). HMC is often a good choice because it can converge rapidly, samples the state space jointly (as opposed to coordinatewise), and leverages one of TF's virtues: automatic differentiation. That said, sampling from a BGMM posterior might actually be better done by other approaches, e.g., Gibbs sampling.
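For intuition about what an HMC transition does, the following is a minimal NumPy sketch of a single step with a leapfrog integrator and a Metropolis-Hastings accept/reject correction. It is not how tfp.mcmc.HamiltonianMonteCarlo is implemented; the function names, the standard-normal toy target, and the hand-written gradient (which TF would obtain by automatic differentiation) are all assumptions for illustration.

```python
import numpy as np

def leapfrog(x, p, grad_log_prob, step_size, num_steps):
  """Simulates Hamiltonian dynamics for H(x, p) = -log_prob(x) + |p|^2 / 2."""
  x, p = x.copy(), p.copy()
  p += 0.5 * step_size * grad_log_prob(x)   # initial half step for momentum
  for i in range(num_steps):
    x += step_size * p                      # full step for position
    if i < num_steps - 1:
      p += step_size * grad_log_prob(x)     # full step for momentum
  p += 0.5 * step_size * grad_log_prob(x)   # final half step for momentum
  return x, p

def hmc_step(x, log_prob, grad_log_prob, step_size, num_steps, rng):
  """One HMC transition: resample momentum, integrate, accept or reject."""
  p0 = rng.standard_normal(x.shape)
  x_new, p_new = leapfrog(x, p0, grad_log_prob, step_size, num_steps)
  # Log acceptance ratio is the negative change in total energy.
  log_accept = (log_prob(x_new) - 0.5 * np.sum(p_new**2)
                - log_prob(x) + 0.5 * np.sum(p0**2))
  if np.log(rng.uniform()) < log_accept:
    return x_new, True
  return x, False

# Toy target (assumed for illustration): a standard normal in 2 dimensions.
def log_prob(x):
  return -0.5 * np.sum(x**2)

def grad_log_prob(x):
  return -x

rng = np.random.default_rng(0)
x = np.zeros(2)
draws, accepted = [], []
for _ in range(2000):
  x, ok = hmc_step(x, log_prob, grad_log_prob,
                   step_size=0.3, num_steps=5, rng=rng)
  draws.append(x)
  accepted.append(ok)
draws = np.array(draws)
```

Every coordinate moves in each proposal, which is the "joint" sampling mentioned above, and the only model-specific inputs are the log-density and its gradient.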

%matplotlib inline

import functools

import matplotlib.pyplot as plt; plt.style.use('ggplot')
import numpy as np
import seaborn as sns; sns.set_context('notebook')

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
  tf.config.experimental.set_memory_growth(physical_devices[0], True)

Before actually building the model, we'll need to define a new type of distribution. From the model specification above, it's clear we're parameterizing the MVN with an inverse covariance matrix, i.e., a precision matrix. To accomplish this in TF, we'll need to roll out our own Bijector. This Bijector will use the forward transformation:

  • Y = tf.linalg.triangular_solve(chol_precision_tril, X, adjoint=True) + loc.

And the log_prob calculation is just the inverse, i.e.:

  • X = tf.linalg.matmul(chol_precision_tril, Y - loc, adjoint_a=True).

Since all we need for HMC is log_prob, this means we avoid ever calling tf.linalg.triangular_solve (as would be the case for tfd.MultivariateNormalTriL). This is advantageous since tf.linalg.matmul is usually faster owing to better cache locality.
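The cost argument is easier to see in a small NumPy sketch of the density evaluation. Assuming chol_precision_tril is a lower-triangular \(L\) with precision \(P = LL^\top\), the log-density only needs the product \(L^\top(y - \text{loc})\) and the log of the diagonal of \(L\); no linear solve appears. The function name and the test values are illustrative and not part of TFP.

```python
import numpy as np

def mvn_chol_precision_log_prob(y, loc, chol_precision_tril):
  """log N(y | loc, (L L^T)^{-1}) for lower-triangular L, using only matmuls."""
  d = loc.shape[-1]
  # Inverse transform: x = L^T (y - loc) is standard normal.
  # Written in row-vector form, (y - loc) @ L equals (L^T (y - loc))^T.
  x = (y - loc) @ chol_precision_tril
  # 0.5 * log det(P) = sum(log(diag(L))).
  half_log_det = np.sum(np.log(np.diag(chol_precision_tril)))
  return (-0.5 * np.sum(x**2, axis=-1)
          - 0.5 * d * np.log(2. * np.pi)
          + half_log_det)

# Illustrative values (assumed).
loc = np.array([1., -1.])
chol_precision_tril = np.array([[1., 0.],
                                [2., 8.]])
y = np.array([[0.5, -0.9], [1., -1.], [2., -1.2]])
log_p = mvn_chol_precision_log_prob(y, loc, chol_precision_tril)
```

A covariance-parameterized density would instead need to solve against the Cholesky factor of the covariance for every evaluation.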

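class MVNCholPrecisionTriL(tfd.TransformedDistribution):
  """MVN from loc and (Cholesky) precision matrix."""

  def __init__(self, loc, chol_precision_tril, name=None):
    super(MVNCholPrecisionTriL, self).__init__(
        # Base distribution: independent standard normals, one per coordinate.
        distribution=tfd.Independent(
            tfd.Normal(tf.zeros_like(loc), scale=tf.ones_like(loc)),
            reinterpreted_batch_ndims=1),
        # Forward: Y = L^{-T} X + loc, where L = chol_precision_tril.
        bijector=tfb.Chain([
            tfb.Shift(shift=loc),
            tfb.Invert(tfb.ScaleMatvecTriL(
                scale_tril=chol_precision_tril, adjoint=True)),
        ]),
        name=name)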

The tfd.Independent distribution turns independent draws of one distribution into a multivariate distribution with statistically independent coordinates. In terms of computing log_prob, this "meta-distribution" manifests as a simple sum over the event dimension(s).
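As a sanity check on that statement, the NumPy sketch below (an illustration, not TFP code) sums per-coordinate standard normal log-densities over the last axis, which is what reinterpreting one batch dimension as an event dimension amounts to, and compares the result with the density of a multivariate normal with identity covariance. The helper name `normal_log_prob` and the sample values are assumptions.

```python
import numpy as np

def normal_log_prob(x, loc=0., scale=1.):
  """Elementwise log-density of a scalar Normal(loc, scale)."""
  z = (x - loc) / scale
  return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2. * np.pi)

# A batch of 2 draws with 3 coordinates each (values assumed).
x = np.array([[0.5, -1.0, 2.0],
              [0.0, 0.0, 0.0]])

# "Independent" log_prob: sum the scalar log-densities over the event axis.
independent_log_prob = normal_log_prob(x).sum(axis=-1)

# Reference: log-density of N(0, I_3) written out directly.
d = x.shape[-1]
mvn_log_prob = -0.5 * np.sum(x**2, axis=-1) - 0.5 * d * np.log(2. * np.pi)
```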

Also notice that we took the adjoint ("transpose") of the scale matrix. This is because if precision is inverse covariance, i.e., \(P=C^{-1}\) and if \(C=AA^\top\), then \(P=BB^{\top}\) where \(B=A^{-\top}\).
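The identity can be checked numerically with a few lines of NumPy. The covariance matrix below is an arbitrary assumed example. Note that \(B = A^{-\top}\) is upper triangular, so it is a valid square root of \(P\) but not the lower Cholesky factor that np.linalg.cholesky would return for \(P\).

```python
import numpy as np

cov = np.array([[2.0, 0.6],
                [0.6, 1.0]])         # C, symmetric positive definite
a = np.linalg.cholesky(cov)          # C = A A^T with A lower triangular
b = np.linalg.inv(a).T               # B = A^{-T}, upper triangular
precision = np.linalg.inv(cov)       # P = C^{-1}
reconstructed = b @ b.T              # should equal P
```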

Since this distribution is kind of tricky, let's quickly verify that our MVNCholPrecisionTriL works as we think it should.

def compute_sample_stats(d, seed=42, n=int(1e6)):
  x = d.sample(n, seed=seed)
  sample_mean = tf.reduce_mean(x, axis=0, keepdims=True)
  s = x - sample_mean
  sample_cov = tf.linalg.matmul(s, s, adjoint_a=True) / tf.cast(n, s.dtype)
  sample_scale = tf.linalg.cholesky(sample_cov)
  sample_mean = sample_mean[0]
  return [sample_mean, sample_cov, sample_scale]

dtype = np.float32
true_loc = np.array([1., -1.], dtype=dtype)
true_chol_precision = np.array([[1., 0.],
                                [2., 8.]],
                               dtype=dtype)
true_precision = np.matmul(true_chol_precision, true_chol_precision.T)
true_cov = np.linalg.inv(true_precision)

d = MVNCholPrecisionTriL(
    loc=true_loc,
    chol_precision_tril=true_chol_precision)

[sample_mean, sample_cov, sample_scale] = [
    t.numpy() for t in compute_sample_stats(d)]

print('true mean:', true_loc)
print('sample mean:', sample_mean)
print('true cov:\n', true_cov)
print('sample cov:\n', sample_cov)
true mean: [ 1. -1.]
sample mean: [ 1.0002806 -1.000105 ]
true cov:
 [[ 1.0625   -0.03125 ]
 [-0.03125   0.015625]]
sample cov:
 [[ 1.0641273  -0.03126175]
 [-0.03126175  0.01559312]]

Since the sample mean and covariance are close to the true mean and covariance, it seems like the distribution is correctly implemented. Now, we'll use MVNCholPrecisionTriL with tfp.distributions.JointDistributionNamed to specify the BGMM model. For the observational model, we'll use tfd.MixtureSameFamily to automatically integrate out the \(\{Z_i\}_{i=1}^N\) draws.

dtype = np.float64
dims = 2
components = 3
num_samples = 1000
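bgmm = tfd.JointDistributionNamed(dict(
  # theta ~ Dirichlet(alpha_0)
  mix_probs=tfd.Dirichlet(
    concentration=np.ones(components, dtype) / 10.),
  # mu_k ~ Normal(mu_0k, I_D), one row per component
  loc=tfd.Independent(
    tfd.Normal(
        loc=np.stack([
            -np.ones(dims, dtype),
            np.zeros(dims, dtype),
            np.ones(dims, dtype),
        ]),
        scale=tf.ones([components, dims], dtype)),
    reinterpreted_batch_ndims=2),
  # T_k ~ Wishart(df=5, I_D), represented by its Cholesky factor
  precision=tfd.Independent(
    tfd.WishartTriL(
        df=5,
        scale_tril=np.stack([np.eye(dims, dtype=dtype)]*components),
        input_output_cholesky=True),
    reinterpreted_batch_ndims=1),
  # Y_i ~ mixture of MVNs with Z_i integrated out
  s=lambda mix_probs, loc, precision: tfd.Sample(tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(probs=mix_probs),
      components_distribution=MVNCholPrecisionTriL(
          loc=loc, chol_precision_tril=precision)),
      sample_shape=num_samples)))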

def joint_log_prob(observations, mix_probs, loc, chol_precision):
  """BGMM with priors: loc=Normal, precision=Wishart, mix=Dirichlet.

  Args:
    observations: `[n, d]`-shaped `Tensor` representing Bayesian Gaussian
      Mixture model draws. Each sample is a length-`d` vector.
    mix_probs: `[K]`-shaped `Tensor` representing random draw from
      `Dirichlet` prior.
    loc: `[K, d]`-shaped `Tensor` representing the location parameter of the
      `K` components.
    chol_precision: `[K, d, d]`-shaped `Tensor` representing `K` lower
      triangular `cholesky(Precision)` matrices, each being sampled from
      a Wishart distribution.

  Returns:
    log_prob: `Tensor` representing joint log-density over all inputs.
  """
  return bgmm.log_prob(
      mix_probs=mix_probs, loc=loc, precision=chol_precision, s=observations)

Generate "Training" Data

For this demo, we'll sample some random data.

true_loc = np.array([[-2., -2],
                     [0, 0],
                     [2, 2]], dtype)
random = np.random.RandomState(seed=43)

true_hidden_component = random.randint(0, components, num_samples)
observations = (true_loc[true_hidden_component] +
                random.randn(num_samples, dims).astype(dtype))

Bayesian Inference using HMC

Now that we've used TFD to specify our model and obtained some observed data, we have all the necessary pieces to run HMC.

To do this, we'll use a partial application to "pin down" the things we don't want to sample. In this case that means we need only pin down observations. (The hyper-parameters are already baked into the prior distributions and are not part of the joint_log_prob function signature.)

unnormalized_posterior_log_prob = functools.partial(joint_log_prob, observations)
initial_state = [
    tf.fill([components],
            value=np.array(1. / components, dtype),
            name='mix_probs'),
    tf.constant(np.array([[-2., -2],
                          [0, 0],
                          [2, 2]], dtype),
                name='loc'),
    tf.linalg.eye(dims, batch_shape=[components], dtype=dtype, name='chol_precision'),
]

Unconstrained Representation

Hamiltonian Monte Carlo (HMC) requires the target log-probability function to be differentiable with respect to its arguments. Furthermore, HMC can exhibit dramatically higher statistical efficiency if the state-space is unconstrained.

This means we'll have to work out two main issues when sampling from the BGMM posterior:

  1. \(\theta\) represents a discrete probability vector, i.e., must be such that \(\sum_{k=1}^K \theta_k = 1\) and \(\theta_k>0\).
  2. \(T_k\) represents an inverse covariance matrix, i.e., must be such that \(T_k \succ 0\), i.e., is positive definite.

To address this requirement we'll need to:

  1. transform the constrained variables to an unconstrained space
  2. run the MCMC in unconstrained space
  3. transform the unconstrained variables back to the constrained space.
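The bookkeeping behind these three steps is the change-of-variables formula: if \(x = g(u)\) maps an unconstrained \(u\) to the constrained \(x\), the log-density to sample in \(u\)-space is \(\log p(g(u)) + \log|\det \partial g / \partial u|\). The NumPy sketch below illustrates this for a single positive variable with \(g = \exp\). The Gamma(2, 1) target is an assumption chosen only because its normalizer is known, and none of this is TFP's implementation.

```python
import numpy as np

def constrained_log_prob(x):
  """Log-density of Gamma(shape=2, rate=1), defined only for x > 0."""
  return np.log(x) - x

def unconstrained_log_prob(u):
  """Log-density of u = log(x): p(exp(u)) times the Jacobian exp(u)."""
  x = np.exp(u)            # map the unconstrained value into x > 0
  log_det_jacobian = u     # log |d exp(u) / du| = u
  return constrained_log_prob(x) + log_det_jacobian

# The transformed density is still normalized over the whole real line.
u = np.linspace(-20., 10., 200001)
total_mass = np.sum(np.exp(unconstrained_log_prob(u))) * (u[1] - u[0])
```

A sampler run on unconstrained_log_prob can propose any real u, and applying exp to each draw returns valid positive samples. Omitting the Jacobian term would target a different distribution.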

As with MVNCholPrecisionTriL, we'll use Bijectors to transform random variables to unconstrained space.

  • The Dirichlet is transformed to unconstrained space via the softmax function.

  • Our precision random variable is a distribution over positive definite matrices, which we represent by their lower-triangular Cholesky factors. To unconstrain these we'll use the FillTriangular and TransformDiagonal bijectors. These convert vectors to lower-triangular matrices and ensure the diagonal is positive. The former is useful because it enables sampling only \(d(d+1)/2\) floats rather than \(d^2\).
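Both mappings can be sketched in a few lines of NumPy. This is an illustration under stated assumptions rather than a reimplementation of the TFP bijectors: softplus is assumed as the positivity transform for the diagonal, the triangle is filled row by row (TFP's FillTriangular uses its own element ordering, which this sketch does not reproduce), and the function names are hypothetical.

```python
import numpy as np

def softmax_centered(u):
  """Maps u in R^{K-1} to a length-K probability vector."""
  logits = np.concatenate([u, [0.]])    # pin the last logit to zero
  logits = logits - logits.max()        # numerical stability
  e = np.exp(logits)
  return e / e.sum()

def softplus(x):
  return np.logaddexp(0., x)            # log(1 + exp(x)), always positive

def to_chol_precision(v, d):
  """Maps v in R^{d(d+1)/2} to a lower-triangular matrix with positive diagonal."""
  tril = np.zeros((d, d))
  tril[np.tril_indices(d)] = v          # fill the lower triangle row by row
  diag = np.diag_indices(d)
  tril[diag] = softplus(tril[diag])     # make the diagonal strictly positive
  return tril

theta = softmax_centered(np.array([0.4, -1.3]))             # K = 3
chol = to_chol_precision(np.array([-0.7, 0.2, 1.5]), d=2)   # d(d+1)/2 = 3 floats
precision = chol @ chol.T                                   # positive definite
```

Any real-valued input produces a valid probability vector and a valid Cholesky factor, which is what lets the sampler move freely in the unconstrained space.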

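unconstraining_bijectors = [
    tfb.SoftmaxCentered(),   # mix_probs: unconstrained vector -> simplex
    tfb.Identity(),          # loc: already unconstrained
    tfb.Chain([              # chol_precision: vector -> tril with positive diagonal
        tfb.TransformDiagonal(tfb.Softplus()),
        tfb.FillTriangular(),
    ])]

# The sampler settings below (chain length, burn-in, step size, leapfrog
# steps, adaptation steps) are assumed values; tune them for your problem.
@tf.function(autograph=False)
def sample():
  return tfp.mcmc.sample_chain(
    num_results=2000,
    num_burnin_steps=500,
    current_state=initial_state,
    kernel=tfp.mcmc.SimpleStepSizeAdaptation(
        tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=unnormalized_posterior_log_prob,
                step_size=0.065,
                num_leapfrog_steps=5),
            bijector=unconstraining_bijectors),
        num_adaptation_steps=400),
    trace_fn=lambda _, pkr: pkr.inner_results.inner_results.is_accepted)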

[mix_probs, loc, chol_precision], is_accepted = sample()

We'll now execute the chain and print the posterior means.

acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32)).numpy()
mean_mix_probs = tf.reduce_mean(mix_probs, axis=0).numpy()
mean_loc = tf.reduce_mean(loc, axis=0).numpy()
mean_chol_precision = tf.reduce_mean(chol_precision, axis=0).numpy()
precision = tf.linalg.matmul(chol_precision, chol_precision, transpose_b=True)
print('acceptance_rate:', acceptance_rate)
print('avg mix probs:', mean_mix_probs)
print('avg loc:\n', mean_loc)
print('avg chol(precision):\n', mean_chol_precision)
acceptance_rate: 0.5305
avg mix probs: [0.25248723 0.60729516 0.1402176 ]
avg loc:
 [[-1.96466753 -2.12047249]
 [ 0.27628865  0.22944732]
 [ 2.06461244  2.54216122]]
avg chol(precision):
 [[[ 1.05105032  0.        ]
  [ 0.12699955  1.06553113]]

 [[ 0.76058015  0.        ]
  [-0.50332767  0.77947431]]

 [[ 1.22770457  0.        ]
  [ 0.70670027  1.50914164]]]
loc_ = loc.numpy()
ax = sns.kdeplot(loc_[:,0,0], loc_[:,0,1], shade=True, shade_lowest=False)
ax = sns.kdeplot(loc_[:,1,0], loc_[:,1,1], shade=True, shade_lowest=False)
ax = sns.kdeplot(loc_[:,2,0], loc_[:,2,1], shade=True, shade_lowest=False)
plt.title('KDE of loc draws');



This simple colab demonstrated how TensorFlow Probability primitives can be used to build hierarchical Bayesian mixture models.