FFJORD

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

设置

首先,安装本演示使用的软件包。

pip install -q dm-sonnet

Imports (tf, tfp with adjoint trick, etc)

/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Helper functions for visualization

FFJORD 双射器

在此 Colab 中,我们将演示 FFJORD 双射器,此双射器最初由 Grathwohl、Will 等人在其论文(arXiv 链接)中提出。

简而言之,这种方式背后的思想是在已知的基础分布数据分布之间建立对应关系。

为了建立这种联系,我们需要进行以下操作:

  1. 在定义基础分布的空间 \(\mathcal{Y}\) 与数据域的空间 \(\mathcal{X}\) 之间定义一个双射映射 \(\mathcal{T}*{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}*{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\)。
  2. 有效地跟踪我们执行的将概率概念转移到 \(\mathcal{X}\) 上的变形。

在 \(\mathcal{X}\) 上定义的概率分布的以下表达式中对第二个条件进行了形式化:

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

FFJORD 双射器通过定义以下转换来实现这一点:\( \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \)

只要描述状态 \(\mathbf{z}\) 演化的函数 \(\mathbf{f}\) 表现良好,并且可以通过集成以下表达式来计算 log_det_jacobian,则此转换可逆。

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}*{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int*{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

在此演示中,我们将训练 FFJORD 双射器,将高斯分布扭曲到 moons 数据集定义的分布上。这将分 3 个步骤完成:

  • 定义基础分布
  • 定义 FFJORD 双射器
  • 最小化数据集的精确对数似然。

首先,我们加载数据

Dataset

png

接下来,我们实例化基础分布

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

我们使用多层感知器对 state_derivative_fn 进行建模。

虽然对此数据集来说并非必需,但使 state_derivative_fn 依赖于时间通常有益。在这里,我们通过将 t 连接到网络的输入来实现这一点。

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

Model and training parameters

现在,我们构造一个 FFJORD 双射器的堆栈。为每个双射器提供 ode_solve_fntrace_augmentation_fn,以及它自己的 state_derivative_fn 模型,因此它们表示一个不同转换的序列。

Building bijector

现在,我们可以使用 TransformedDistribution,这是使用 stacked_ffjord 双射器扭曲 base_distribution 的结果。

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

现在,我们来定义训练过程。只需使数据的负对数似然最小化。

Training

Samples

根据基础分布和转换分布绘制样本。

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

通过学习率对其训练更长时间,可获得进一步改善。

本例中未涉及,FFJORD 双射器支持哈钦森的随机迹估算。可通过 trace_augmentation_fn 提供特定的 estimator。同样,也可以通过定义自定义 ode_solve_fn 来使用替代积分器。