FFJORD

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Mempersiapkan

Instal paket pertama yang digunakan dalam demo ini.

pip install -q dm-sonnet

Impor (tf, tfp dengan trik adjoint, dll)

import numpy as np
import tqdm as tqdm
import sklearn.datasets as skd

# visualization
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kde

# tf and friends
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import sonnet as snt
tf.enable_v2_behavior()

tfb = tfp.bijectors
tfd = tfp.distributions

def make_grid(xmin, xmax, ymin, ymax, gridlines, pts):
  xpts = np.linspace(xmin, xmax, pts)
  ypts = np.linspace(ymin, ymax, pts)
  xgrid = np.linspace(xmin, xmax, gridlines)
  ygrid = np.linspace(ymin, ymax, gridlines)
  xlines = np.stack([a.ravel() for a in np.meshgrid(xpts, ygrid)])
  ylines = np.stack([a.ravel() for a in np.meshgrid(xgrid, ypts)])
  return np.concatenate([xlines, ylines], 1).T

grid = make_grid(-3, 3, -3, 3, 4, 100)
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Fungsi pembantu untuk visualisasi

Bijektor FFJORD

Dalam colab ini kami mendemonstrasikan bijektor FFJORD, yang awalnya diusulkan dalam makalah oleh Grathwohl, Will, et al. arXiv Link .

Singkatnya itu ide di balik pendekatan tersebut adalah untuk membangun korespondensi antara distribusi dasar dikenal dan distribusi data.

Untuk membuat koneksi ini, kita perlu

  1. Tentukan bijective peta \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}_{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) antara ruang \(\mathcal{Y}\) yang distribusi dasar didefinisikan dan ruang \(\mathcal{X}\) dari domain data.
  2. Efisien melacak deformasi kita melakukan untuk mentransfer gagasan probabilitas ke \(\mathcal{X}\).

Kondisi kedua adalah diformalkan dalam ekspresi berikut untuk distribusi probabilitas didefinisikan pada \(\mathcal{X}\):

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

Bijektor FFJORD menyelesaikan ini dengan mendefinisikan transformasi

\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]

Transformasi ini dapat dibalik, selama fungsi \(\mathbf{f}\) menggambarkan evolusi dari negara \(\mathbf{z}\) berperilaku baik dan log_det_jacobian dapat dihitung dengan mengintegrasikan ekspresi berikut.

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

Dalam demo ini kita akan melatih bijector FFJORD untuk warp distribusi gaussian ke distribusi didefinisikan oleh moons dataset. Ini akan dilakukan dalam 3 langkah:

  • Tentukan distribusi dasar
  • Tentukan bijektor FFJORD
  • Minimalkan kemungkinan log yang tepat dari kumpulan data

Pertama, kita memuat data

Himpunan data

png

Selanjutnya, kami membuat instance distribusi dasar

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

Kami menggunakan multi-layer perceptron model state_derivative_fn .

Meskipun tidak diperlukan untuk dataset ini, sering benefitial untuk membuat state_derivative_fn tergantung pada waktu. Di sini kita mencapai hal ini dengan menggabungkan t untuk masukan dari jaringan kami.

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

Model dan parameter pelatihan

Sekarang kita membangun setumpuk bijector FFJORD. Setiap bijector disediakan dengan ode_solve_fn dan trace_augmentation_fn dan itu sendiri state_derivative_fn Model, sehingga mereka mewakili urutan transformasi yang berbeda.

Bijektor bangunan

Sekarang kita dapat menggunakan TransformedDistribution yang merupakan hasil dari warping base_distribution dengan stacked_ffjord bijector.

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

Sekarang kita mendefinisikan prosedur pelatihan kita. Kami hanya meminimalkan kemungkinan log negatif dari data.

Pelatihan

sampel

Plot sampel dari distribusi dasar dan transformasi.

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

Melatihnya lebih lama dengan tingkat pembelajaran menghasilkan peningkatan lebih lanjut.

Tidak dikonversi dalam contoh ini, bijektor FFJORD mendukung estimasi jejak stokastik hutchinson. Estimator tertentu dapat disediakan melalui trace_augmentation_fn . Demikian pula integrator alternatif dapat digunakan dengan mendefinisikan kustom ode_solve_fn .