FFJORD

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

Kurmak

İlk önce bu demoda kullanılan paketleri kurun.

pip install -q dm-sonnet

İthalatlar (tf, birleşik hileli tfp vb.)

import numpy as np
import tqdm as tqdm
import sklearn.datasets as skd

# visualization
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kde

# tf and friends
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import sonnet as snt
tf.enable_v2_behavior()

tfb = tfp.bijectors
tfd = tfp.distributions

def make_grid(xmin, xmax, ymin, ymax, gridlines, pts):
  xpts = np.linspace(xmin, xmax, pts)
  ypts = np.linspace(ymin, ymax, pts)
  xgrid = np.linspace(xmin, xmax, gridlines)
  ygrid = np.linspace(ymin, ymax, gridlines)
  xlines = np.stack([a.ravel() for a in np.meshgrid(xpts, ygrid)])
  ylines = np.stack([a.ravel() for a in np.meshgrid(xgrid, ypts)])
  return np.concatenate([xlines, ylines], 1).T

grid = make_grid(-3, 3, -3, 3, 4, 100)
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Görselleştirme için yardımcı işlevler

FFJORD bijektör

Bu ortak çalışmada, orijinal olarak Grathwohl, Will ve diğerleri tarafından makalede önerilen FFJORD bijektörünü gösteriyoruz. bağlantıyı arXiv .

Özetle bu tür yaklaşımın arkasındaki fikir, bilinen bir baz dağıtım ve veri dağıtım arasındaki yazışmalar kurmaktır.

Bu bağlantıyı kurmak için yapmamız gereken

  1. Bir örten harita tanımlama \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}_{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) alanı arasındaki \(\mathcal{Y}\) taban dağılımı tarif edildiği ve boşluk \(\mathcal{X}\) veriler alan.
  2. Verimli biz üzerine olasılık kavramını aktarmak yerine deformasyonların takip \(\mathcal{X}\).

İkinci koşul tanımlanan olasılık dağılımı için aşağıdaki ifade düzenlenmesini \(\mathcal{X}\):

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

FFJORD bijector bunu bir dönüşüm tanımlayarak başarır

\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]

Uzun bir fonksiyon olarak bu dönüşüm, tersinirdir \(\mathbf{f}\) durumu arasında değişimini tanımlayan \(\mathbf{z}\) iyi huylu ve log_det_jacobian aşağıdaki ifade entegre hesaplanabilir.

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

Bu tanıtımda biz tarafından tanımlanan dağıtım üzerine Gauss dağılımı çarpıtmak için bir FFJORD bijector eğitecek moons veri kümesi. Bu 3 adımda yapılacaktır:

  • Baz dağılımını tanımla
  • FFJORD bijektörünü tanımlayın
  • Veri kümesinin tam günlük olasılığını en aza indirin

İlk önce verileri yüklüyoruz

veri kümesi

png

Ardından, bir temel dağılımı başlatıyoruz

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

Biz modeli için perseptron çok katmanı kullanır state_derivative_fn .

Bu veri kümesi için gerekli olmasa da, bunu yapmak için sık sık benefitial olan state_derivative_fn zaman bağımlı. Burada birleştirerek bunu başarmak t ağımızın girişlerine.

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

Model ve eğitim parametreleri

Şimdi bir FFJORD bijektör yığını oluşturuyoruz. Her bijector ile sağlanan ode_solve_fn ve trace_augmentation_fn ve kendi var state_derivative_fn farklı dönüşümlerin bir diziyi temsil böylece, modeli.

Bina bijektörü

Şimdi kullanabilirsiniz TransformedDistribution çözgü sonucudur base_distribution ile stacked_ffjord bijector.

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

Şimdi eğitim prosedürümüzü tanımlıyoruz. Verilerin negatif log olasılığını en aza indiririz.

Eğitim

örnekler

Temel ve dönüştürülmüş dağılımlardan örnekleri çizin.

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

Öğrenme oranıyla daha uzun süre eğitmek, daha fazla iyileştirme sağlar.

Bu örnekte dönüştürülmeyen FFJORD bijektörü, hutchinson'ın stokastik iz tahminini destekler. Özellikle tahmin yoluyla sağlanabilir trace_augmentation_fn . Benzer şekilde, alternatif entegratörleri özel tanımlayarak kullanılabilir ode_solve_fn .