
Factorial Mixture


In this notebook we show how to use TensorFlow Probability (TFP) to sample from a factorial mixture of Gaussians distribution, defined as

\[p(x_1, \ldots, x_n) = \prod_{i=1}^n p_i(x_i),\]

where each factor is a mixture of \(K\) Gaussians:

\[p_i \equiv \sum_{k=1}^K \pi_{ik}\,\text{Normal}\left(\text{loc}=\mu_{ik},\, \text{scale}=\sigma_{ik}\right), \qquad \sum_{k=1}^K \pi_{ik} = 1 \quad \forall i.\]

Each variable \(x_i\) is modeled as a mixture of Gaussians, and the joint distribution over all \(n\) variables is a product of these densities.

Given a dataset \(x^{(1)}, \ldots, x^{(T)}\), we model each datapoint \(x^{(j)}\) as a factorial mixture of Gaussians:

\[p(x^{(j)}) = \prod_i p_i (x_i^{(j)})\]
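Assuming the datapoints are drawn independently, the log-likelihood of the whole dataset then decomposes into a sum over datapoints and variables:

\[\log p(x^{(1)}, \ldots, x^{(T)}) = \sum_{j=1}^T \sum_{i=1}^n \log p_i(x_i^{(j)}).\]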

Factorial mixtures are a simple way of creating distributions with a small number of parameters and a large number of modes.
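To make the tradeoff concrete: with \(n\) variables and \(K\) components per variable, the mixture needs only \(nK\) component means, yet the joint density has up to \(K^n\) modes, one for each way of choosing a component for every variable. The \(n = 2\), \(K = 3\) example below has \(2 \times 3 = 6\) means and \(3^2 = 9\) modes.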

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
tfd = tfp.distributions

# Use try/except so we can easily re-execute the whole notebook
# (enabling eager execution raises a ValueError once a graph exists;
# in TF2 eager execution is already the default).
try:
  tf.compat.v1.enable_eager_execution()
except ValueError:
  pass

Build the Factorial Mixture of Gaussians using TFP

num_vars = 2        # Number of variables (`n` in formula).
var_dim = 1         # Dimensionality of each variable `x[i]`.
num_components = 3  # Number of components for each mixture (`K` in formula).
sigma = 5e-2        # Fixed standard deviation of each component.

# Choose some random (component) modes.
component_mean = tfd.Uniform().sample([num_vars, num_components, var_dim])

factorial_mog = tfd.Independent(
    tfd.MixtureSameFamily(
        # Assume uniform weight on each component.
        mixture_distribution=tfd.Categorical(
            logits=tf.zeros([num_vars, num_components])),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=component_mean, scale_diag=[sigma])),
    reinterpreted_batch_ndims=1)

Notice our use of tfd.Independent. This "meta-distribution" applies a reduce_sum in the log_prob calculation over the rightmost reinterpreted_batch_ndims batch dimensions. In our case, this sums out the variables dimension leaving only the batch dimension when we compute log_prob. Note that this does not affect sampling.
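To see the effect concretely, here is a quick shape check. This is a minimal sketch in which mixtures rebuilds the inner MixtureSameFamily from above without the Independent wrapper, and xs is an arbitrary batch of samples:

# The inner mixture alone: batch shape [num_vars], event shape [var_dim].
mixtures = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tf.zeros([num_vars, num_components])),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=component_mean, scale_diag=[sigma]))

xs = factorial_mog.sample(7)
print(mixtures.log_prob(xs).shape)       # (7, 2): one log-prob per variable.
print(factorial_mog.log_prob(xs).shape)  # (7,): summed over the 2 variables.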

Plot the Density

Compute the density on a grid of points, and show the locations of the modes with red stars. Each mode in the factorial mixture corresponds to a pair of modes from the underlying individual-variable mixture of Gaussians. We can see 9 modes in the plot below, but we only needed 6 parameters (3 to specify the locations of the modes in \(x_1\), and 3 to specify the locations of the modes in \(x_2\)). In contrast, a mixture of Gaussians distribution in the 2d space \((x_1, x_2)\) would require 2 * 9 = 18 parameters to specify the 9 modes.
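Before plotting, we can sanity-check the mode count directly: the joint modes are the Cartesian product of the per-variable component means. A minimal sketch, reusing the component_mean sampled above:

import itertools

# Per-variable component means, shape (num_vars, num_components).
mus = component_mean.numpy().squeeze()

# One joint mode for every way of pairing a component of x1 with one of x2.
joint_modes = list(itertools.product(*mus))
print(len(joint_modes))  # 9 == num_components**num_vars, from only 6 means.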

from matplotlib.colors import ListedColormap

plt.figure(figsize=(6, 5))

# Compute density on a grid covering the unit square, plus a 3-sigma margin.
nx = 250  # Number of bins per dimension.
x = np.linspace(-3 * sigma, 1 + 3 * sigma, nx).astype('float32')
vals = tf.reshape(tf.stack(np.meshgrid(x, x), axis=2), (-1, num_vars, var_dim))
probs = factorial_mog.prob(vals).numpy().reshape(nx, nx)

# Display as image.
cmap = ListedColormap(sns.color_palette("Blues", 256))
plt.pcolor(x, x, probs, cmap=cmap)
ax = plt.axis('tight');

# Plot locations of means.
means_np = component_mean.numpy().squeeze()
for mu_x in means_np[0]:
  for mu_y in means_np[1]:
    plt.scatter(mu_x, mu_y, s=150, marker='*', c='r', edgecolor='none');
plt.axis(ax);

plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.title('Density of factorial mixture of Gaussians');

[Figure: heatmap of the density of the factorial mixture of Gaussians over \((x_1, x_2)\), with red stars marking the 9 modes.]

Plot samples and marginal density estimates

# Draw 1,000 samples; shape (1000, num_vars, var_dim) = (1000, 2, 1).
samples = factorial_mog.sample(1000).numpy()

# Scatter the two variables against each other, with histogram estimates
# of the marginal densities along the sides.
g = sns.jointplot(
    x=samples[:, 0, 0],
    y=samples[:, 1, 0],
    kind="scatter",
    marginal_kws=dict(bins=50))
g.set_axis_labels("$x_1$", "$x_2$");

[Figure: joint scatter plot of the samples with marginal histograms for \(x_1\) and \(x_2\).]
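Each marginal histogram above should approximate a 1-d mixture of \(K = 3\) Gaussians. As a quick check (a sketch; marginal_x1 is an illustrative name), we can construct the analytic marginal of \(x_1\) and evaluate it on the same grid used for the density plot:

# The marginal of x1 is itself a plain 1-d mixture of Gaussians.
marginal_x1 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.zeros([num_components])),
    components_distribution=tfd.Normal(
        loc=component_mean[0, :, 0], scale=sigma))

plt.plot(x, marginal_x1.prob(x).numpy())
plt.xlabel('$x_1$')
plt.title('Analytic marginal density of $x_1$');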