了解 TensorFlow Distributions 形状

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本
import collections

import tensorflow as tf
tf.compat.v2.enable_v2_behavior()

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

基础知识

与 TensorFlow Distributions 形状相关的三个重要概念如下:

  • 事件形状描述了从分布中单次抽样的形状;抽样可能依赖于不同维度。对于标量分布,事件形状为 []。对于 5 维多元正态分布,事件形状为 [5]
  • 批次形状描述了独立但并非同分布的抽样,也称为分布的“批次”。
  • 样本形状描述了来自分布系列的独立同分布批次抽样。

事件形状和批次形状是 Distribution 对象的属性,而样本形状则与对 samplelog_prob 的特定调用相关联。

此笔记本的目的是通过示例说明这些概念,因此,如果没有立即看到效果,请不要担心!

有关这些概念的另一种概念性概述,请参阅此博文

关于 TensorFlow Eager 的说明

整个笔记本使用 TensorFlow Eager 编写。在 Eager 模式下,尽管在 Python 中创建 Distribution 对象时会评估(并因此已知)分布批次和事件形状,提出的概念也完全不依赖于 Eager,而在计算图(非 Eager 模式)中,可以定义事件和批次的形状在运行计算图之后才会确定的分布。

标量分布

如上所述,Distribution 对象已定义事件和批次形状。我们将从描述分布的实用工具开始:

def describe_distributions(distributions):
  print('\n'.join([str(d) for d in distributions]))

在本部分中,我们将探讨标量分布:事件形状为 [] 的分布。一个典型示例是由 rate 指定的泊松分布:

poisson_distributions = [
    tfd.Poisson(rate=1., name='One Poisson Scalar Batch'),
    tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons'),
    tfd.Poisson(rate=[[1., 10., 100.,], [2., 20., 200.]],
                name='Two-by-Three Poissons'),
    tfd.Poisson(rate=[1.], name='One Poisson Vector Batch'),
    tfd.Poisson(rate=[[1.]], name='One Poisson Expanded Batch')
]

describe_distributions(poisson_distributions)
tfp.distributions.Poisson("One_Poisson_Scalar_Batch", batch_shape=[], event_shape=[], dtype=float32)
tfp.distributions.Poisson("Three_Poissons", batch_shape=[3], event_shape=[], dtype=float32)
tfp.distributions.Poisson("Two_by_Three_Poissons", batch_shape=[2, 3], event_shape=[], dtype=float32)
tfp.distributions.Poisson("One_Poisson_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
tfp.distributions.Poisson("One_Poisson_Expanded_Batch", batch_shape=[1, 1], event_shape=[], dtype=float32)

泊松分布是一种标量分布,因此其事件形状始终为 []。如果我们指定更多比率,则这些比率将显示在批次形状中。最后一对示例十分有趣:只有一个比率,但由于该比率已嵌入具有非空形状的 NumPy 数组中,因此该形状成为批次形状。

标准正态分布也是标量。与泊松分布一样,其事件的形状也是 [],但我们将使用它来查看我们的第一个广播示例。使用 locscale 参数指定正态:

normal_distributions = [
    tfd.Normal(loc=0., scale=1., name='Standard'),
    tfd.Normal(loc=[0.], scale=1., name='Standard Vector Batch'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=1., name='Different Locs'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=[[1.], [5.]],
               name='Broadcasting Scale')
]

describe_distributions(normal_distributions)
tfp.distributions.Normal("Standard", batch_shape=[], event_shape=[], dtype=float32)
tfp.distributions.Normal("Standard_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
tfp.distributions.Normal("Different_Locs", batch_shape=[4], event_shape=[], dtype=float32)
tfp.distributions.Normal("Broadcasting_Scale", batch_shape=[2, 4], event_shape=[], dtype=float32)

上面有趣的示例是 Broadcasting Scale 分布。loc 参数的形状为 [4],而 scale 参数的形状为 [2, 1]。使用 Numpy 广播规则时,批次形状为 [2, 4]。一种定义 "Broadcasting Scale" 分布的等效方式(但不太简洁,因此不推荐使用)为:

describe_distributions(
    [tfd.Normal(loc=[[0., 1., 2., 3], [0., 1., 2., 3.]],
                scale=[[1., 1., 1., 1.], [5., 5., 5., 5.]])])
tfp.distributions.Normal("Normal", batch_shape=[2, 4], event_shape=[], dtype=float32)

我们可以看到广播符号为什么有用,尽管它也是麻烦和错误的来源。

对标量分布抽样

我们可以对分布做的两件主要事情:从分布中 sample,以及计算 log_prob。我们先探讨抽样。基本规则是,当我们从分布中抽样时,所生成张量的形状为 [sample_shape, batch_shape, event_shape],其中 batch_shapeevent_shapeDistribution 对象提供,而 sample_shape 由对 sample 的调用提供。对于标量分布,event_shape = [],因此从样本返回的张量的形状为 [sample_shape, batch_shape]。我们来试一下:

def describe_sample_tensor_shape(sample_shape, distribution):
    print('Sample shape:', sample_shape)
    print('Returned sample tensor shape:',
          distribution.sample(sample_shape).shape)

def describe_sample_tensor_shapes(distributions, sample_shapes):
    started = False
    for distribution in distributions:
      print(distribution)
      for sample_shape in sample_shapes:
        describe_sample_tensor_shape(sample_shape, distribution)
      print()

sample_shapes = [1, 2, [1, 5], [3, 4, 5]]
describe_sample_tensor_shapes(poisson_distributions, sample_shapes)
tfp.distributions.Poisson("One_Poisson_Scalar_Batch", batch_shape=[], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1,)
Sample shape: 2
Returned sample tensor shape: (2,)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5)

tfp.distributions.Poisson("Three_Poissons", batch_shape=[3], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 3)
Sample shape: 2
Returned sample tensor shape: (2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 3)

tfp.distributions.Poisson("Two_by_Three_Poissons", batch_shape=[2, 3], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Poisson("One_Poisson_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1)

tfp.distributions.Poisson("One_Poisson_Expanded_Batch", batch_shape=[1, 1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1, 1)
describe_sample_tensor_shapes(normal_distributions, sample_shapes)
tfp.distributions.Normal("Standard", batch_shape=[], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1,)
Sample shape: 2
Returned sample tensor shape: (2,)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5)

tfp.distributions.Normal("Standard_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1)

tfp.distributions.Normal("Different_Locs", batch_shape=[4], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 4)
Sample shape: 2
Returned sample tensor shape: (2, 4)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 4)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 4)

tfp.distributions.Normal("Broadcasting_Scale", batch_shape=[2, 4], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 4)
Sample shape: 2
Returned sample tensor shape: (2, 2, 4)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 4)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 4)

这就是关于 sample 的所有内容:返回的样本张量的形状为 [sample_shape, batch_shape, event_shape]

计算标量分布的 log_prob

现在,我们看一下有点棘手的 log_problog_prob 接受(非空)张量作为输入,此张量表示要计算分布的 log_prob 的位置。在最简单的情况下,此张量将具有 [sample_shape, batch_shape, event_shape] 形式的形状,其中 batch_shapeevent_shape 与分布的批次和事件形状匹配。再次回想一下,对于标量分布,event_shape = [],因此输入张量的形状为 [sample_shape, batch_shape]。在本例中,我们会重新得到形状为 [sample_shape, batch_shape] 的张量:

three_poissons = tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons')
three_poissons
<tfp.distributions.Poisson 'Three_Poissons' batch_shape=[3] event_shape=[] dtype=float32>
three_poissons.log_prob([[1., 10., 100.], [100., 10., 1]])  # sample_shape is [2].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -2.0785608,   -3.2223587],
       [-364.73938  ,   -2.0785608,  -95.39484  ]], dtype=float32)>
three_poissons.log_prob([[[[1., 10., 100.], [100., 10., 1.]]]])  # sample_shape is [1, 1, 2].
<tf.Tensor: shape=(1, 1, 2, 3), dtype=float32, numpy=
array([[[[  -1.       ,   -2.0785608,   -3.2223587],
         [-364.73938  ,   -2.0785608,  -95.39484  ]]]], dtype=float32)>

请注意,在第一个示例中,输入和输出的形状为 [2, 3],而在第二个示例中,输入和输出的形状为 [1, 1, 2, 3]

如果不是为了广播,那么这就是要介绍的全部内容。下面是考虑广播时的规则。我们从完全通用的角度对广播进行描述,并说明标量分布的简化:

  1. 定义 n = len(batch_shape) + len(event_shape)。(对于标量分布,len(event_shape)=0。)
  2. 如果输入张量 t 的维数小于 n,则通过在左侧添加大小为 1 的维度来填充其形状,直到其维数恰好为 n。调用生成的张量 t'
  3. 针对要为其计算 log_prob 的分布的 [batch_shape, event_shape] 广播 t'n 个最右侧维度。更详细地说:对于 t' 已经与分布匹配的维度,不执行任何操作,而对于 t' 具有单例的维度,将该单例复制适当的次数。任何其他情况都是错误。(对于标量分布,由于 event_shape = [],因此我们仅针对 batch_shape 进行广播。)
  4. 现在,我们终于可以计算 log_prob 了。所生成张量的形状为 [sample_shape, batch_shape],其中 sample_shape 被定义为 n 个最右侧维度左侧的 tt' 的任意维度:sample_shape = shape(t)[:-n]

如果您不知道它的含义是什么,那么可能有些混乱,因此我们来运行一些示例:

three_poissons.log_prob([10.])
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-16.104412 ,  -2.0785608, -69.05272  ], dtype=float32)>

张量 [10.](形状为 [1])在 batch_shape 为 3 的范围内广播,因此我们在值 10 处评估全部三个泊松分布的对数概率。

three_poissons.log_prob([[[1.], [10.]], [[100.], [1000.]]])
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[-1.0000000e+00, -7.6974149e+00, -9.5394836e+01],
        [-1.6104412e+01, -2.0785608e+00, -6.9052719e+01]],

       [[-3.6473938e+02, -1.4348087e+02, -3.2223587e+00],
        [-5.9131279e+03, -3.6195427e+03, -1.4069575e+03]]], dtype=float32)>

在上面的示例中,输入张量的形状为 [2, 2, 1],而分布对象的批次形状为 3。因此,对于 [2, 2] 示例维度中的每个维度,提供的单个值会广播到这三个泊松分布。

一种可能有用的思考方式:由于 three_poissons 具有 batch_shape = [2, 3],因此,对 log_prob 的调用必须获取一个最后维度为 1 或 3 的张量;其他所有情况都是错误。(NumPy 广播规则将标量的特例视为完全等同于形状为 [1] 的张量。)

我们通过使用更复杂的泊松分布 (batch_shape = [2, 3]) 来测试我们的想法:

poisson_2_by_3 = tfd.Poisson(
    rate=[[1., 10., 100.,], [2., 20., 200.]],
    name='Two-by-Three Poissons')
poisson_2_by_3.log_prob(1.)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([1.])  # Exactly equivalent to above, demonstrating the scalar special case.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 1., 1.], [1., 1., 1.]])  # Another way to write the same thing. No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 10., 100.]])  # Input is [1, 3] broadcast to [2, 3].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ -1.       ,  -2.0785608,  -3.2223587],
       [ -1.3068528,  -5.14709  , -33.90767  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 10., 100.], [1., 10., 100.]])  # Equivalent to above. No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ -1.       ,  -2.0785608,  -3.2223587],
       [ -1.3068528,  -5.14709  , -33.90767  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 1., 1.], [2., 2., 2.]])  # No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -14.701683 , -190.09653  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1.], [2.]])  # Equivalent to above. Input shape [2, 1] broadcast to [2, 3].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -14.701683 , -190.09653  ]], dtype=float32)>

上面的示例涉及在批次上广播,但样本形状为空。假设我们有一个值的集合,并且我们想要获得批次中每个点处每个值的对数概率。我们可以手动实现:

poisson_2_by_3.log_prob([[[1., 1., 1.], [1., 1., 1.]], [[2., 2., 2.], [2., 2., 2.]]])  # Input shape [2, 2, 3].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

或者,我们可以让广播处理最后一个批次维度:

poisson_2_by_3.log_prob([[[1.], [1.]], [[2.], [2.]]])  # Input shape [2, 2, 1].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

我们还可以(或许不太自然)让广播只处理第一个批次维度:

poisson_2_by_3.log_prob([[[1., 1., 1.]], [[2., 2., 2.]]])  # Input shape [2, 1, 3].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

或者,我们可以让广播同时处理两个批次维度:

poisson_2_by_3.log_prob([[[1.]], [[2.]]])  # Input shape [2, 1, 1].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

当我们只有两个需要的值时,上面的方式十分有效,但是,设想我们有一长串的值要在每个批次点处进行评估。为此,下面的符号可以在形状的右侧额外添加一个大小为 1 的维度,这样做非常有用:

poisson_2_by_3.log_prob(tf.constant([1., 2.])[..., tf.newaxis, tf.newaxis])
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

这是跨步切片符号的一个实例,非常值得了解。

为了保持完整性,回到 three_poissons,相同的示例如下所示:

three_poissons.log_prob([[1.], [10.], [50.], [100.]])
<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [ -16.104412 ,   -2.0785608,  -69.05272  ],
       [-149.47777  ,  -43.34851  ,  -18.219261 ],
       [-364.73938  , -143.48087  ,   -3.2223587]], dtype=float32)>
three_poissons.log_prob(tf.constant([1., 10., 50., 100.])[..., tf.newaxis])  # Equivalent to above.
<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [ -16.104412 ,   -2.0785608,  -69.05272  ],
       [-149.47777  ,  -43.34851  ,  -18.219261 ],
       [-364.73938  , -143.48087  ,   -3.2223587]], dtype=float32)>

多元分布

现在,我们转向具有非空事件形状的多元分布。我们看一下多项式分布。

multinomial_distributions = [
    # Multinomial is a vector-valued distribution: if we have k classes,
    # an individual sample from the distribution has k values in it, so the
    # event_shape is `[k]`.
    tfd.Multinomial(total_count=100., probs=[.5, .4, .1],
                    name='One Multinomial'),
    tfd.Multinomial(total_count=[100., 1000.], probs=[.5, .4, .1],
                    name='Two Multinomials Same Probs'),
    tfd.Multinomial(total_count=100., probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Same Counts'),
    tfd.Multinomial(total_count=[100., 1000.],
                    probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Different Everything')

]

describe_distributions(multinomial_distributions)
tfp.distributions.Multinomial("One_Multinomial", batch_shape=[], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Same_Probs", batch_shape=[2], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Same_Counts", batch_shape=[2], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Different_Everything", batch_shape=[2], event_shape=[3], dtype=float32)

请注意,在最后三个示例中,batch_shape 始终为 [2],但是我们可以使用广播来获得共享的 total_count 或共享的 probs(或者两者都没有),因为它们在后台被广播为具有相同的形状。

根据我们已知的信息,抽样非常简单:

describe_sample_tensor_shapes(multinomial_distributions, sample_shapes)
tfp.distributions.Multinomial("One_Multinomial", batch_shape=[], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 3)
Sample shape: 2
Returned sample tensor shape: (2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 3)

tfp.distributions.Multinomial("Two_Multinomials_Same_Probs", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Multinomial("Two_Multinomials_Same_Counts", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Multinomial("Two_Multinomials_Different_Everything", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

计算对数概率同样非常简单。我们以一个对角多元正态分布作为示例。(多项式分布对广播不是十分友好,因为对计数和概率的约束意味着广播通常会产生不可接受的值。)我们将使用 2 个 3 维分布组成的批次,它们的均值相同但标度不同(标准差):

two_multivariate_normals = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_identity_multiplier=[1., 2.])
two_multivariate_normals
<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[2] event_shape=[3] dtype=float32>

(请注意,尽管我们使用标度为恒等式倍数的分布,但这不是对它的限制;我们可以传递 scale 而不是 scale_identity_multiplier。)

现在,我们来评估每个批次点的均值和漂移均值处的对数概率:

two_multivariate_normals.log_prob([[[1., 2., 3.]], [[3., 4., 5.]]])  # Input has shape [2,1,3].
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-2.7568154, -4.836257 ],
       [-8.756816 , -6.336257 ]], dtype=float32)>

完全等效,我们可以使用 https://tensorflow.google.cn/api_docs/cc/class/tensorflow/ops/strided-slice在常量中间插入一个额外的 shape=1 维度:

two_multivariate_normals.log_prob(
    tf.constant([[1., 2., 3.], [3., 4., 5.]])[:, tf.newaxis, :])  # Equivalent to above.
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-2.7568154, -4.836257 ],
       [-8.756816 , -6.336257 ]], dtype=float32)>

另一方面,如果不插入额外的维度,则将 [1., 2., 3.] 传递给第一个批次点,将 [3., 4., 5.] 传递给第二个批次点:

two_multivariate_normals.log_prob(tf.constant([[1., 2., 3.], [3., 4., 5.]]))
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-2.7568154, -6.336257 ], dtype=float32)>

形状操作技术

Reshape 双射器

Reshape 双射器可用于改变分布的 event_shape 的形状。我们来看一个示例:

six_way_multinomial = tfd.Multinomial(total_count=1000., probs=[.3, .25, .2, .15, .08, .02])
six_way_multinomial
<tfp.distributions.Multinomial 'Multinomial' batch_shape=[] event_shape=[6] dtype=float32>

我们创建了一个事件形状为 [6] 的多项式。借助 Reshape 双射器,我们可以将其视为事件形状为 [2, 3] 的分布。

Bijector 表示 \({\mathbb R}^n\) 的一个开放子集上的可微分一对一函数。Bijectors 适合与 TransformedDistribution 结合使用,后者会根据基础分布 \(p(y)\) 和表示 \(Y = g(X)\) 的 Bijector 对分布 \(p(x)\) 建模。我们来看一下它的实际运行:

transformed_multinomial = tfd.TransformedDistribution(
    distribution=six_way_multinomial,
    bijector=tfb.Reshape(event_shape_out=[2, 3]))
transformed_multinomial
<tfp.distributions.TransformedDistribution 'reshapeMultinomial' batch_shape=[] event_shape=[2, 3] dtype=float32>
six_way_multinomial.log_prob([500., 100., 100., 150., 100., 50.])
<tf.Tensor: shape=(), dtype=float32, numpy=-178.22021>
transformed_multinomial.log_prob([[500., 100., 100.], [150., 100., 50.]])
<tf.Tensor: shape=(), dtype=float32, numpy=-178.22021>

这是 Reshape 双射器模型唯一能做的事情:它不能将事件维度转换为批次维度,反之亦然。

Independent 分布

Independent 分布用于将不一定相同的(也称为一批)独立分布的集合视为单个分布。更简单地说,Independent 允许将 batch_shape 中的维度转换为 event_shape 中的维度。我们将举例说明:

two_by_five_bernoulli = tfd.Bernoulli(
    probs=[[.05, .1, .15, .2, .25], [.3, .35, .4, .45, .5]],
    name="Two By Five Bernoulli")
two_by_five_bernoulli
<tfp.distributions.Bernoulli 'Two_By_Five_Bernoulli' batch_shape=[2, 5] event_shape=[] dtype=int32>

我们可以认为这是一个具有相关正面概率的 2 × 5 硬币数组。我们评估特定的任意一组 1 和 0 的概率:

pattern = [[1., 0., 0., 1., 0.], [0., 0., 1., 1., 1.]]
two_by_five_bernoulli.log_prob(pattern)
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[-2.9957323 , -0.10536052, -0.16251892, -1.609438  , -0.2876821 ],
       [-0.35667497, -0.4307829 , -0.9162907 , -0.7985077 , -0.6931472 ]],
      dtype=float32)>

我们可以使用 Independent 将其转换为两个不同的“5 次伯努利试验的集合”,如果我们想将给定模式中出现的“一行”抛硬币视为单个结果,这会十分有用:

two_sets_of_five = tfd.Independent(
    distribution=two_by_five_bernoulli,
    reinterpreted_batch_ndims=1,
    name="Two Sets Of Five")
two_sets_of_five
<tfp.distributions.Independent 'Two_Sets_Of_Five' batch_shape=[2] event_shape=[5] dtype=int32>

在数学上,我们通过将五次“独立”抛硬币的对数概率相加来计算每“组”五次抛硬币的对数概率,这就是分布名称的由来:

two_sets_of_five.log_prob(pattern)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-5.160732 , -3.1954036], dtype=float32)>

我们可以更进一步,使用 Independent 来创建一个分布,其中各个事件是一组 2 × 5 伯努利试验:

one_set_of_two_by_five = tfd.Independent(
    distribution=two_by_five_bernoulli, reinterpreted_batch_ndims=2,
    name="One Set Of Two By Five")
one_set_of_two_by_five.log_prob(pattern)
<tf.Tensor: shape=(), dtype=float32, numpy=-8.356134>

值得注意的是,从 sample 的角度来看,使用 Independent 不会发生任何改变:

describe_sample_tensor_shapes(
    [two_by_five_bernoulli,
     two_sets_of_five,
     one_set_of_two_by_five],
    [[3, 5]])
tfp.distributions.Bernoulli("Two_By_Five_Bernoulli", batch_shape=[2, 5], event_shape=[], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

tfp.distributions.Independent("Two_Sets_Of_Five", batch_shape=[2], event_shape=[5], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

tfp.distributions.Independent("One_Set_Of_Two_By_Five", batch_shape=[], event_shape=[2, 5], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

作为给读者的临别练习,我们建议从抽样和对数概率角度来考虑 Normal 分布的向量批次与 MultivariateNormalDiag 分布之间的差异和相似性。我们如何使用 Independent 从一个 Normal 批次构造一个 MultivariateNormalDiag?(请注意,MultivariateNormalDiag 实际上并未以这种方式实现。)