Auto-Batched Joint Distributions: A Gentle Tutorial

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

简介

TensorFlow Probability (TFP) 提供了许多 JointDistribution 抽象,允许用户以类似数学的形式轻松表示概率图模型,从而使概率推断更加容易;抽象会生成用于从模型进行采样并评估来自模型的样本的对数概率的方法。在本教程中,我们将回顾“自动批处理”变体,它是在原始的 JointDistribution 抽象之后开发的。相对于原始的非自动批处理抽象,自动批处理版本使用起来更简单,且更符合工效学,可以用更少的样板来表达许多模型。在本 Colab 中,我们将详细(也许比较乏味)探索一个简单的模型,明确自动批处理能够解决的问题,并(希望)能够在此过程中向读者介绍更多有关 TFP 形状的概念。

在引入自动批处理之前,JointDistribution 有几种不同的变体,对应于用于表达概率模型的不同句法样式:JointDistributionSequentialJointDistributionNamedJointDistributionCoroutine。自动批处理以混入的形式存在,因此我们现在有了所有这些的 AutoBatched 变体。在本教程中,我们将探讨 JointDistributionSequentialJointDistributionSequentialAutoBatched 之间的区别;不过,我们在这里所做的一切都适用于其他变体,基本上没有任何变化。

依赖项和前提条件

Import and set ups

前提条件:贝叶斯回归问题

我们将考虑一个非常简单的贝叶斯回归场景:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

在此模型中,mb 从标准正态分布中抽取,观测值 Y 从均值取决于随机变量 mb 的正态分布中抽取,以及一些(非随机、已知的)协变量 X。(为了简明起见,在本例中,我们假设所有随机变量的尺度已知。)

要在此模型中执行推断,我们需要同时知道协变量 X 和观测值 Y,但在本教程中,我们只需要 X,因此我们定义一个简单的虚拟 X

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

必要条件

在概率推断中,我们通常要执行两个基本运算:

  • sample:从模型中抽取样本。
  • log_prob:计算模型中样本的对数概率。

TFP 的 JointDistribution 抽象(以及许多其他概率编程方式)的关键作用是允许用户编写一次模型,并同时访问 samplelog_prob 计算。

请注意,我们的数据集中有 7 个点 (X.shape = (7,)),我们现在可以说明出色的 JointDistribution 所需的必要条件:

  • sample() 应生成一个形状为 [(), (), (7,)] 的 Tensors 的列表,分别对应于标量斜率、标量偏差和向量观测值。
  • log_prob(sample()) 应生成一个标量:特定斜率、偏差和观测值的对数概率。
  • sample([5, 3]) 应生成一个形状为 [(5, 3), (5, 3), (5, 3, 7)]Tensors 的列表,代表模型中样本的 (5, 3) 批次
  • log_prob(sample([5, 3])) 应生成一个形状为 (5, 3) 的 Tensor

现在,我们来看一系列 JointDistribution 模型,了解如何实现上述必要条件,并希望能够在此过程中更详细地了解 TFP 形状。

剧透警告:无需添加样板即可满足上述必要条件的方式是自动批处理

首次尝试;JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

这或多或少是模型到代码的直接转换。斜率 m 和偏差 b 很简单。Y 使用 lambda 函数定义:一般模式是,JointDistributionSequential (JDS) 中 \(k\) 参数的 lambda 函数会使用模型中先前的 \(k\) 分布。请注意“反向”顺序。

我们将调用 sample_distributions,它会同时返回样本用来生成该样本的底层“子分布” 。(我们可以通过调用 sample 只生成样本;在本教程的后面,获得分布也会很方便。)我们生成的样本不错:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

log_prob 生成的结果所具有的形状不理想:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

而且多次采样也不起作用:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

我们来尝试了解一下出了什么问题。

简要回顾:批次形状和事件形状

在 TFP 中,普通的(非 JointDistribution)概率分布具有事件形状批次形状,了解两者的区别对于高效使用 TFP 至关重要:

  • 事件形状描述了从分布中单次抽样的形状;抽样可能依赖于不同维度。对于标量分布,事件形状为 []。对于 5 维多元正态,事件形状为 [5]。
  • 批次形状描述了独立的、分布不均匀的抽样,也称为分布的“批次”。在一个 Python 对象中表示一批分布是 TFP 实现大规模效率的关键方式之一。

就我们的目的而言,需要记住的一个关键事实是,如果我们对分布中的单个样本调用 log_prob,结果将始终具有与批次形状相匹配(即,具有最右侧的维度)的形状。

有关形状的更深入讨论,请参阅“理解 TensorFlow 分布形状”教程。

为什么 log_prob(sample()) 不生成标量?

我们来利用批次形状和事件形状的知识来探讨一下 log_prob(sample()) 的情况。同样,下面是我们的样本:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

下面是我们的分布:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

对数概率的计算方法是:在各部分的(匹配)元素处对子分布的对数概率求和:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

因此,有一层解释是,对数概率计算返回的是形状为 7 的张量,因为 log_prob_parts 的第三个子分量是形状为 7 的张量。但是为什么呢?

我们看到 dists 的最后一个元素(对应于我们数学公式中在 Y 上的分布)的 batch_shape[7]。换句话说,我们在 Y 上的分布是由 7 个独立的正态分布(具有不同均值,在本例中具有相同尺度)组成的批次。

现在,我们了解了问题所在:在 JDS 中,Y 上的分布具有 batch_shape=[7],来自 JDS 的样本代表 mb 的标量,以及由 7 个独立的正态分布组成的“批次”。log_prob 会计算 7 个单独的对数概率,每个概率代表抽样 mb 的对数概率,以及在某个 X[i] 处的单个观测值 Y[i]

使用Independent 修复 log_prob(sample())

回想一下,dists[2] 具有 event_shape=[]batch_shape=[7]

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

通过使用 TFP 的 Independent 元分布(将批次维度转换为事件维度),我们可以将其转换为具有 event_shape=[7]batch_shape=[] 的分布(我们将其重命名为 y_dist_i,因为它是 Y 上的分布,其中 _i 代表我们的 Independent 封装):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

现在,7 向量的 log_prob 是一个标量:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

在幕后,Independent 会对批次求和:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

事实上,我们可以用它来构造一个新的 jds_i(同样,i 代表 Independent),其中 log_prob 会返回一个标量:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

下面是几个注意事项:

  • jds_i.log_prob(s)tf.reduce_sum(jds.log_prob(s)) 不同。前者会生成联合分布的“正确”对数概率。后者会在形状为 7 的张量上求和,它的每个元素都是 mb 的对数概率,以及 Y 的对数概率的单个元素的和,因此它多算了 mb。(log_prob(m) + log_prob(b) + log_prob(Y) 会返回结果,而非引发异常,因为 TFP 遵循 TF 和 NumPy 的广播规则;将一个标量添加到一个向量会生成一个向量大小的结果。)
  • 在这种特殊情况下,我们可以使用 MultivariateNormalDiag 代替 Independent(Normal(...)) 来解决问题并获得相同的结果。MultivariateNormalDiag 是向量值分布(即,它已经具有向量事件形状)。实际上,MultivariateNormalDiag 可以(但没有)实现为 IndependentNormal 的组合。值得记住的是,给定向量 V,无法区分来自 n1 = Normal(loc=V)n2 = MultivariateNormalDiag(loc=V) 的样本;这些分布的区别在于,n1.log_prob(n1.sample()) 是向量,而 n2.log_prob(n2.sample()) 是标量。

多个样本?

抽取多个样本仍然不起作用:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

我们来思考一下原因。当我们调用 jds_i.sample([5, 3]) 时,我们会先抽取 mb 的样本,每个样本都具有形状 (5, 3)。接下来,我们将尝试通过以下方式构造 Normal 分布:

tfd.Normal(loc=m*X + b, scale=1.)

但是,如果 m 的形状为 (5, 3)X 的形状为 7,我们就无法将它们相乘,事实上这就是我们遇到的错误:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

要解决这个问题,我们来思考一下 Y 上的分布必须具备哪些属性。如果我们调用了 jds_i.sample([5, 3]),那么我们知道 mb 的形状都将为 (5, 3)。在 Y 分布上调用 sample 应该生成什么形状?显而易见的答案是 (5, 3, 7):对于每个批次点,我们希望得到一个与 X 大小相同的样本。我们可以通过使用 TensorFlow 的广播功能并添加额外维度来实现此目的:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

mb 各添加一个轴,我们可以定义一个支持多个样本的新 JDS:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

作为额外的检查,我们将验证单个批次点的对数概率是否与我们之前的匹配:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

自动批处理的胜利

非常好!我们现在有一个能够处理我们所有必要条件的 JointDistribution 版本:由于使用了 tfd.Independentlog_prob 会返回标量,并且由于我们通过添加额外的轴修复了广播,现在多个样本也能正常工作。

如果我告诉您还有一种更简单、更好的方式呢?这种方式确实存在,它叫作 JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

它的工作原理是什么?虽然您可以尝试通过阅读代码来进行深入了解,但我们将给出一个简短的概述,这对大多数用例来说已经足够:

  • 回想一下,我们的第一个问题是,Y 的分布具有 batch_shape=[7]event_shape=[],我们使用了 Independent 将批次维度转换为事件维度。JDSAB 会忽略分量分布的批次形状;而会将批次形状视为模型的整体属性,假定为 [](除非通过设置 batch_ndims > 0 另行指定)。这个效果等同于使用 tfd.Independent 将分量分布的所有批次维度转换为事件维度,就像我们在上面手动完成的那样。
  • 我们的第二个问题是需要对 mb 的形状进行改动,以便在创建多个样本时可以使用 X 进行适当的广播。您可以使用 JDSAB 编写一个模型来生成单个样本,然后使用 TensorFlow 的 vectorized_map “提升”整个模型来生成多个样本。(此功能与 JAX 的 vmap 类似。)

更深入地探索批次形状问题,我们可以比较原始“不良”联合分布 jds,批次固定分布 jds_ijds_ia,以及自动批处理的 jds_ab

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

我们看到,原始 jds 的子分布具有不同的批次形状。jds_ijds_ia 通过创建具有相同(空)批次形状的子分布来解决这个问题。jds_ab 只有一个(空)批次形状。

值得注意的是,JointDistributionSequentialAutoBatched 免费提供了一些其他通用性。假设我们将协变量 X(并隐式地将观测值 Y)设为二维:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

我们的 JointDistributionSequentialAutoBatched 无需更改即可工作(我们需要重新定义模型,因为X 的形状已由 jds_ab.log_prob 缓存):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

另一方面,我们精心设计的 JointDistributionSequential 已不再有效:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

要解决此问题,我们必须向 mb 都添加一个 tf.newaxis 以匹配形状,并在 Independent 的调用中将 reinterpreted_batch_ndims 增加到 2。在本例中,让自动批处理机制来处理形状问题更简短、更轻松,且更符合工效学。

再一次,我们注意到,虽然此笔记本探讨了 JointDistributionSequentialAutoBatched,但 JointDistribution 的其他变体具有等效的 AutoBatched。(对于 JointDistributionCoroutine 的用户来说,JointDistributionCoroutineAutoBatched 的另一个好处是您不再需要指定 Root 节点;如果您从未使用过 JointDistributionCoroutine,可以放心地忽略此语句。)

结束语

在此笔记本中,我们介绍了 JointDistributionSequentialAutoBatched,并详细演示了一个简单的示例。希望您从中了解到了一些有关 TFP 形状和自动批处理的知识!