Factorial Mixture

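In this notebook we show how to use TensorFlow Probability (TFP) to sample from a factorial mixture of Gaussians distribution, defined as:

$$p(x_1, \ldots, x_n) = \prod_i p_i(x_i)$$

where:

$$\begin{align*}
p_i &\equiv \sum_{k=1}^K \pi_{ik}\,\text{Normal}\left(\text{loc}=\mu_{ik},\, \text{scale}=\sigma_{ik}\right) \\
1 &= \sum_{k=1}^K \pi_{ik}, \quad \forall i.
\end{align*}$$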

Each variable $x_i$ is modeled as a mixture of Gaussians, and the joint distribution over all $n$ variables is a product of these densities.

Given a dataset $x^{(1)}, \ldots, x^{(T)}$, we model each datapoint $x^{(j)}$ as a factorial mixture of Gaussians:

$$p(x^{(j)}) = \prod_i p_i(x_i^{(j)})$$

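Factorial mixtures are a simple way of creating distributions with a small number of parameters and a large number of modes: $n$ variables with $K$ components each need only $nK$ component means but can produce up to $K^n$ modes.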
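To make the product form of the density concrete, here is a minimal plain-Python sketch that evaluates it directly. It is not how TFP computes the density; the function names (normal_pdf, mixture_pdf, factorial_mog_pdf) and the parameter values are assumptions made up for illustration.

```python
import math


def normal_pdf(x, loc, scale):
    """Density of a univariate Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def mixture_pdf(x, weights, locs, scales):
    """p_i(x_i): weighted sum of K Gaussian densities (weights sum to 1)."""
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(weights, locs, scales))


def factorial_mog_pdf(xs, weights, locs, scales):
    """p(x_1, ..., x_n): product over variables of the per-variable mixtures."""
    return math.prod(
        mixture_pdf(x, w, m, s) for x, w, m, s in zip(xs, weights, locs, scales))


# Hypothetical parameters: n = 2 variables, K = 3 equally weighted components.
weights = [[1 / 3] * 3, [1 / 3] * 3]
locs = [[0.1, 0.5, 0.9], [0.2, 0.4, 0.8]]
scales = [[0.05] * 3, [0.05] * 3]

# Joint density at the point (x_1, x_2) = (0.5, 0.4).
density = factorial_mog_pdf([0.5, 0.4], weights, locs, scales)
```

The TFP model built in the next section plays the same role as factorial_mog_pdf, but works in log space and handles batches of points.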

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
tfd = tfp.distributions

# Eager execution is the default in TF2; this call only matters on TF1.
# Use try/except so we can easily re-execute the whole notebook.
try:
  tf.enable_eager_execution()
except Exception:
  pass

Build the Factorial Mixture of Gaussians using TFP

num_vars = 2        # Number of variables (`n` in formula).
var_dim = 1         # Dimensionality of each variable `x[i]`.
num_components = 3  # Number of components for each mixture (`K` in formula).
sigma = 5e-2        # Fixed standard deviation of each component.

# Choose some random (component) modes.
component_mean = tfd.Uniform().sample([num_vars, num_components, var_dim])

factorial_mog = tfd.Independent(
   tfd.MixtureSameFamily(
       # Assume uniform weight on each component.
       mixture_distribution=tfd.Categorical(
           logits=tf.zeros([num_vars, num_components])),
       components_distribution=tfd.MultivariateNormalDiag(
           loc=component_mean, scale_diag=[sigma])),
   reinterpreted_batch_ndims=1)

Notice our use of tfd.Independent. This "meta-distribution" applies a reduce_sum in the log_prob calculation over the rightmost reinterpreted_batch_ndims batch dimensions. In our case, the mixture has batch shape [num_vars], so setting reinterpreted_batch_ndims=1 sums out the variables dimension, and log_prob returns one value per input point. Note that this does not affect sampling.
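The reduction over the variables dimension can be illustrated with a plain-Python analogue. This is only a sketch of the arithmetic, not TFP's implementation, and the per-variable density values are made up for illustration.

```python
import math

# Hypothetical per-variable densities p_i(x_i) for a batch of 3 points and
# n = 2 variables, laid out as [batch, num_vars].
per_var_prob = [
    [0.5, 2.0],
    [1.5, 0.2],
    [4.0, 0.25],
]

# Per-variable log-probabilities, still of shape [batch, num_vars].
per_var_log_prob = [[math.log(p) for p in row] for row in per_var_prob]

# Summing over the rightmost (variables) axis leaves one joint
# log-probability per point, i.e. shape [batch].
joint_log_prob = [sum(row) for row in per_var_log_prob]

# Exponentiating recovers the product of the per-variable densities.
joint_prob = [math.exp(lp) for lp in joint_log_prob]
```

Summing in log space is equivalent to multiplying the per-variable densities, which is exactly the product in the definition of the factorial mixture.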

Plot the Density

Compute the density on a grid of points, and show the locations of the modes with red stars. Each mode in the factorial mixture corresponds to a pair of modes from the underlying individual-variable mixtures of Gaussians. We can see 9 modes in the plot below, but we only needed 6 parameters (3 to specify the locations of the modes in $x_1$, and 3 to specify the locations of the modes in $x_2$), counting only the component means since the weights and standard deviation are fixed. In contrast, a mixture of Gaussians distribution in the 2D space $(x_1, x_2)$ would require 2 * 9 = 18 parameters to specify the 9 modes.
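The parameter count above can be checked with a short sketch. The mean values below are hypothetical stand-ins for the randomly sampled component_mean, and the count covers only the mode locations.

```python
import itertools

# Hypothetical component means for n = 2 variables with K = 3 components each.
means_x1 = [0.1, 0.5, 0.9]
means_x2 = [0.2, 0.4, 0.8]

# Every pairing of one mean per variable is a mode of the joint density
# (assuming the components are narrow enough not to merge).
joint_modes = list(itertools.product(means_x1, means_x2))

num_modes = len(joint_modes)                          # K ** n
num_factorial_params = len(means_x1) + len(means_x2)  # n * K mean values
num_full_mixture_params = 2 * num_modes               # 2 coordinates per mode
```

Here num_modes is 9, num_factorial_params is 6, and num_full_mixture_params is 18, matching the counts in the text.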

plt.figure(figsize=(6,5))

# Compute density.
nx = 250 # Number of bins per dimension.
x = np.linspace(-3 * sigma, 1 + 3 * sigma, nx).astype('float32')
vals = tf.reshape(tf.stack(np.meshgrid(x, x), axis=2), (-1, num_vars, var_dim))
probs = factorial_mog.prob(vals).numpy().reshape(nx, nx)

# Display as image.
from matplotlib.colors import ListedColormap
cmap = ListedColormap(sns.color_palette("Blues", 256))
p = plt.pcolor(x, x, probs, cmap=cmap)
ax = plt.axis('tight');

# Plot locations of means.
means_np = component_mean.numpy().squeeze()
for mu_x in means_np[0]:
  for mu_y in means_np[1]:
    plt.scatter(mu_x, mu_y, s=150, marker='*', c='r', edgecolor='none');
plt.axis(ax);

plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.title('Density of factorial mixture of Gaussians');

Plot samples and marginal density estimates

samples = factorial_mog.sample(1000).numpy()

g = sns.jointplot(
    x=samples[:, 0, 0],
    y=samples[:, 1, 0],
    kind="scatter",
    marginal_kws=dict(bins=50))
g.set_axis_labels("$x_1$", "$x_2$");
