# JAX 上的 TensorFlow Probability

TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now also works on JAX! For those not familiar, JAX is a library for accelerated numerical computing based on composable function transformations.

TFP on JAX supports a lot of the most useful functionality of regular TFP while preserving the abstractions and APIs that many TFP users are now comfortable with.

## Setup

TFP on JAX does not depend on TensorFlow; let's uninstall TensorFlow from this Colab entirely.

```
pip uninstall tensorflow -y -q
```

```
pip install -Uq tfp-nightly[jax] > /dev/null
```

```
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
```

```
import jax.numpy as jnp
from jax import jit
from jax import random
from jax import vmap
```

## Importing TFP on JAX

```
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels
```

## Demo: Bayesian logistic regression

```
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)
```

```
Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)

dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))
```

```
init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()
```

```
def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()
```

```
all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
```
```
Average accuracy: 0.96952
BMA accuracy: 0.97999996
```

Using BMA appears to reduce our error rate by almost a third!
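A quick arithmetic check of that claim, using the accuracies printed above:

```python
# Error rates implied by the accuracies reported above
avg_err = 1 - 0.96952      # per-sample average accuracy
bma_err = 1 - 0.97999996   # Bayesian model averaging accuracy
print('relative error reduction: %.2f' % (1 - bma_err / avg_err))
```

The error rate drops from about 3.0% to 2.0%, a relative reduction of roughly 34%.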

## Fundamentals

TFP on JAX has the same API as TF, but instead of accepting TF objects like `tf.Tensor`s it accepts their JAX analogues. For example, wherever a `tf.Tensor` was previously used as input, the API now expects a JAX `DeviceArray`, and TFP methods return `DeviceArray`s instead of `tf.Tensor`s. TFP on JAX also works with nested structures of JAX objects, like a list or dictionary of `DeviceArray`s.

## Distributions

Most of TFP's distributions are supported in JAX, with semantics very similar to their TF counterparts. They are also registered as JAX Pytrees, so they can be inputs to and outputs of JAX-transformed functions.

### Basic distributions

```
dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
```
```
-0.9189385
```

```
tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
```
```
DeviceArray(-0.20584226, dtype=float32)
```

```
dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
```
```
Event shape: (5,)
Batch shape: ()
```

```
dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
```
```
Event shape: ()
Batch shape: (5,)
```

```
dist = tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist = tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
```
```
(10, 2, 5)
(10, 2)
```

```
sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()
```

`Distribution` methods are compatible with JAX transformations.

```
sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()
```

```
x = jnp.linspace(-5., 5., 100)
plt.plot(x, tfd.Normal(0., 1.).prob(x))
plt.show()
```

```
@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
```
```
0.14389051 0.081832744
```

### Transformed distributions

```
dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()
```

### Joint distributions

TFP offers `JointDistribution`s, which combine component distributions into a single distribution over multiple random variables. Currently, TFP offers three core variants (`JointDistributionSequential`, `JointDistributionNamed`, and `JointDistributionCoroutine`), all of which are supported in JAX. The `AutoBatched` variants are all supported as well.

```
dist = tfd.JointDistributionSequential([
    tfd.Normal(0., 1.),
    lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()
```
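Sampling from this sequential model is ancestral sampling: draw the first variable from its prior, then draw the second conditioned on it. A NumPy sketch of the same two-level model (using a NumPy `Generator` as a stand-in for TFP's seeded sampling):

```python
import numpy as np

# Ancestral sampling for: x ~ Normal(0, 1), y | x ~ Normal(x, 0.1)
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1000)
y = rng.normal(loc=x, scale=0.1)   # each y_i is centered on its own x_i
print(np.corrcoef(x, y)[0, 1])     # close to 1: y tracks x tightly
```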

```
joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
```
```
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
```
```
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
```
```
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))
```

### Other distributions

```
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1., maxval=1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

# Use an exponentiated quadratic (RBF) kernel for the GP.
kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()
```

Hidden Markov models are also supported in JAX.

```
initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
```
```
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]
```
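The printed observation means can be reproduced by propagating the hidden-state marginals through the transition matrix and taking the expectation under each state's emission mean. A NumPy check:

```python
import numpy as np

p = np.array([0.8, 0.2])                # initial state distribution
T = np.array([[0.7, 0.3], [0.2, 0.8]])  # transition probabilities
mu = np.array([0.0, 15.0])              # emission mean for each state

means = []
for _ in range(7):
    means.append(p @ mu)  # expected observation at this step
    p = p @ T             # propagate the state marginal one step
print(np.round(means, 4))
```

This matches the `model.mean()` output above (up to float32 rounding).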

A small number of distributions, like `PixelCNN`, are not yet supported in JAX because they depend heavily on TensorFlow or are incompatible with XLA.

## Bijectors

```
tfb.Exp().inverse(1.)
```
```
DeviceArray(0., dtype=float32)
```
```
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
```
```
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
```
```
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
```
```
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
```
```
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
```
```
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]
```

```
jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
```
```
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
```
```
x = jnp.linspace(0., 1., 100)
plt.plot(x, tfb.Sigmoid().inverse(x))
plt.show()
```

## MCMC

```
target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob
```

```
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()
```

```
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()
```

## Optimizers

TFP on JAX supports important optimizers like BFGS and L-BFGS. Let's set up a simple scaled quadratic loss function.

```
minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.
```

BFGS can find the minimum of this loss.

```
from jax import value_and_grad

# BFGS expects a function returning both the loss and its gradient.
optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
```
```
Function evaluations: 5
```

L-BFGS can find the minimum of this loss, too.

```
from jax import value_and_grad

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
```
```
Function evaluations: 5
```

```
from jax import value_and_grad

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(all_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
```
```
Function evaluations: [6 6 9 6 6 8 6 8 5 9]
```

## Caveats

TF and JAX differ in some fundamental ways, so some TFP behaviors differ between the two substrates, and not all functionality is supported.

## Appendix: Pseudorandomness in JAX

JAX's pseudorandom number generation (PRNG) model is stateless. Unlike a stateful model, where mutable global state evolves after each random draw, in JAX there is no mutable global state. In the JAX model, we start with a PRNG *key*, which acts like a pair of 32-bit integers. We can construct these keys with `jax.random.PRNGKey`.

```
key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
```
```
[0 0]
```

Random functions in JAX consume a key to deterministically produce a random variate, which means keys should not be reused. For example, we can use `key` to sample a normally distributed value, but we should not use `key` again elsewhere. Furthermore, passing the same key into `random.normal` will produce the same value. To draw several independent samples from one key, we *split* it into subkeys with `random.split`.
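The stateless idea can be mimicked with NumPy's seedable `Generator` (the `draw` helper here is hypothetical, for illustration only): the seed fully determines the draw, so reusing a "key" repeats the value, and distinct keys give independent draws.

```python
import numpy as np

# A stateless draw in the spirit of JAX keys: the seed fully
# determines the stream, so the same seed always repeats the value.
def draw(seed):
    return np.random.default_rng(seed).normal()

print(draw(0) == draw(0))  # True: same key, same sample
print(draw(0) == draw(1))  # False: different keys, different samples
```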

```
print(random.normal(key))
```
```
-0.20584226
```

```
key1, key2 = random.split(key, num=2)
print(key1, key2)
```
```
[4146024105  967050713] [2718843009 1272950319]
```

```
print(random.normal(key1), random.normal(key2))
```
```
0.14389051 -1.2515389
```
