Mélange factoriel

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Dans ce cahier , nous montrons comment utiliser tensorflow probabilité (TFP) à partir d' un mélange échantillon factoriel de la distribution gaussienne définie comme:\(p(x_1, ..., x_n) = \prod_i p_i(x_i)\) où: \(\begin{align*} p_i &\equiv \frac{1}{K}\sum_{k=1}^K \pi_{ik}\,\text{Normal}\left(\text{loc}=\mu_{ik},\, \text{scale}=\sigma_{ik}\right)\\1&=\sum_{k=1}^K\pi_{ik}, \forall i.\hphantom{MMMMMMMMMMM}\end{align*}\)

Chaque variable \(x_i\) est modélisé sous la forme d' un mélange de gaussiennes, et la distribution conjointe sur l' ensemble \(n\) des variables est un produit de ces densités.

Étant donné un ensemble de données \(x^{(1)}, ..., x^{(T)}\), nous modélisons chaque dataponit \(x^{(j)}\) comme un mélange factoriel de gaussiennes:

\[p(x^{(j)}) = \prod_i p_i (x_i^{(j)})\]

Les mélanges factoriels sont un moyen simple de créer des distributions avec un petit nombre de paramètres et un grand nombre de modes.

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
tfd = tfp.distributions

# Use try/except so we can easily re-execute the whole notebook.
try:
  tf.enable_eager_execution()
except:
  pass

Construire le mélange factoriel de gaussiennes à l'aide de la TFP

num_vars = 2        # Number of variables (`n` in formula).
var_dim = 1         # Dimensionality of each variable `x[i]`.
num_components = 3  # Number of components for each mixture (`K` in formula).
sigma = 5e-2        # Fixed standard deviation of each component.

# Choose some random (component) modes.
component_mean = tfd.Uniform().sample([num_vars, num_components, var_dim])

factorial_mog = tfd.Independent(
   tfd.MixtureSameFamily(
       # Assume uniform weight on each component.
       mixture_distribution=tfd.Categorical(
           logits=tf.zeros([num_vars, num_components])),
       components_distribution=tfd.MultivariateNormalDiag(
           loc=component_mean, scale_diag=[sigma])),
   reinterpreted_batch_ndims=1)

Notres utilisation de tfd.Independent . Cette « méta-distribution » applique une reduce_sum dans le log_prob calcul sur les plus à droite reinterpreted_batch_ndims dimensions des lots. Dans notre cas, cette somme sur les variables dimension ne laissant que la dimension du lot quand on calcule log_prob . Notez que cela n'affecte pas l'échantillonnage.

Tracer la densité

Calculez la densité sur une grille de points et montrez les emplacements des modes avec des étoiles rouges. Chaque mode du mélange factoriel correspond à une paire de modes du mélange sous-jacent de variables individuelles de gaussiennes. Nous pouvons voir 9 modes dans le tracé ci - dessous, mais nous avons seulement besoin de 6 paramètres (3 pour spécifier les emplacements des modes de \(x_1\)et 3 pour spécifier l'emplacement des modes de \(x_2\)). En revanche, un mélange de la distribution gaussienne dans l'espace 2d \((x_1, x_2)\) nécessiterait 2 * 9 = 18 paramètres pour spécifier les 9 modes.

plt.figure(figsize=(6,5))

# Compute density.
nx = 250 # Number of bins per dimension.
x = np.linspace(-3 * sigma, 1 + 3 * sigma, nx).astype('float32')
vals = tf.reshape(tf.stack(np.meshgrid(x, x), axis=2), (-1, num_vars, var_dim))
probs = factorial_mog.prob(vals).numpy().reshape(nx, nx)

# Display as image.
from matplotlib.colors import ListedColormap
cmap = ListedColormap(sns.color_palette("Blues", 256))
p = plt.pcolor(x, x, probs, cmap=cmap)
ax = plt.axis('tight');

# Plot locations of means.
means_np = component_mean.numpy().squeeze()
for mu_x in means_np[0]:
  for mu_y in means_np[1]:
    plt.scatter(mu_x, mu_y, s=150, marker='*', c='r', edgecolor='none');
plt.axis(ax);

plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.title('Density of factorial mixture of Gaussians');

png

Tracer des échantillons et des estimations de densité marginale

samples = factorial_mog.sample(1000).numpy()

g = sns.jointplot(
    x=samples[:, 0, 0],
    y=samples[:, 1, 0],
    kind="scatter",
    marginal_kws=dict(bins=50))
g.set_axis_labels("$x_1$", "$x_2$");

png