FJORD

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Configuración

Primero instale los paquetes utilizados en esta demostración.

pip install -q dm-sonnet

Importaciones (tf, tfp con truco adjunto, etc.)

import numpy as np
import tqdm as tqdm
import sklearn.datasets as skd

# visualization
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kde

# tf and friends
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import sonnet as snt
tf.enable_v2_behavior()

tfb = tfp.bijectors
tfd = tfp.distributions

def make_grid(xmin, xmax, ymin, ymax, gridlines, pts):
  xpts = np.linspace(xmin, xmax, pts)
  ypts = np.linspace(ymin, ymax, pts)
  xgrid = np.linspace(xmin, xmax, gridlines)
  ygrid = np.linspace(ymin, ymax, gridlines)
  xlines = np.stack([a.ravel() for a in np.meshgrid(xpts, ygrid)])
  ylines = np.stack([a.ravel() for a in np.meshgrid(xgrid, ypts)])
  return np.concatenate([xlines, ylines], 1).T

grid = make_grid(-3, 3, -3, 3, 4, 100)
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Funciones de ayuda para la visualización

Biyector FFJORD

En este laboratorio colaborativo demostramos el biyector FFJORD, propuesto originalmente en el artículo de Grathwohl, Will, et al. arXiv enlace .

En el resumen de la idea detrás de este enfoque es el de establecer una correspondencia entre una distribución de base conocida y la distribución de datos.

Para establecer esta conexión, necesitamos

  1. Definir un biyectiva mapa \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}_{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) entre el espacio \(\mathcal{Y}\) en la que se define la distribución de base y el espacio \(\mathcal{X}\) del dominio de datos.
  2. De manera eficiente un seguimiento de las deformaciones que llevamos a cabo para transferir la noción de probabilidad en \(\mathcal{X}\).

La segunda condición se formaliza en la siguiente expresión para la distribución de probabilidad definida en \(\mathcal{X}\):

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

El biyector FFJORD logra esto definiendo una transformación

\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]

Esta transformación es invertible, siempre y cuando la función \(\mathbf{f}\) que describe la evolución del estado \(\mathbf{z}\) se comporta bien y el log_det_jacobian puede ser calculado mediante la integración de la siguiente expresión.

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

En esta demo vamos a entrenar a un bijector FFJORD para deformar una distribución de Gauss sobre la distribución definida por moons conjunto de datos. Esto se hará en 3 pasos:

  • Definir distribución de base
  • Definir biyector FFJORD
  • Minimizar la probabilidad de registro exacta del conjunto de datos

Primero, cargamos los datos

Conjunto de datos

png

A continuación, instanciamos una distribución base

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

Utilizamos un perceptrón multicapa con el modelo state_derivative_fn .

Aunque no es necesario para este conjunto de datos, a menudo es beneficiosa para hacer state_derivative_fn depende del tiempo. Aquí logramos esto mediante la concatenación de t a entradas de nuestra red.

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

Parámetros de modelo y entrenamiento

Ahora construimos una pila de biyectores FFJORD. Cada bijector se proporciona con ode_solve_fn y trace_augmentation_fn y su propio state_derivative_fn modelo, por lo que representan una secuencia de diferentes transformaciones.

Edificio biyector

Ahora podemos usar TransformedDistribution que es el resultado de la deformación base_distribution con stacked_ffjord bijector.

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

Ahora definimos nuestro procedimiento de entrenamiento. Simplemente minimizamos la probabilidad logarítmica negativa de los datos.

Capacitación

Muestras

Trazar muestras de distribuciones base y transformadas.

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

Entrenarlo durante más tiempo con la tasa de aprendizaje da como resultado mejoras adicionales.

No incluido en este ejemplo, el biyector FFJORD admite la estimación de trazas estocásticas de Hutchinson. El estimador particular puede ser proporcionado a través de trace_augmentation_fn . Del mismo modo integradores alternativas se pueden utilizar mediante la definición de encargo ode_solve_fn .