TensorFlow Distributions の形状を理解する

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード
import collections

import tensorflow as tf
tf.compat.v2.enable_v2_behavior()

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

基礎

TensorFlow Distributions の形状には関連する 3 つの重要な概念があります。

  • イベントの形状は、分布からの 1 つの抽出の形状を表します。抽出は次元間で依存する場合があります。スカラー分布の場合、イベントの形状は [] です。5 次元の MultivariateNormal の場合、イベントの形状は [5] です。
  • バッチの形状は、独立した、同一に分布されていない抽出である「バッチ」の分布を表します。
  • サンプルの形状は、 分布ファミリからの独立した、同一に分布されたバッチの抽出を表します。

イベントの形状とバッチの形状は Distribution オブジェクトのプロパティですが、サンプルの形状は sample または log_prob への特定の呼び出しに関連付けられています。

このノートブックでは、例を使ってこれらの概念を説明していくので、すぐに分からなくても、心配する必要はありません。

また、これらの概念の概要については、このブログ記事を参照してください。

TensorFlow Eager に関する注意

このノートブックは、すべて TensorFlow Eager を使用して記述されています。提示された概念は Eager に依存していませんが、Eager では、Distribution オブジェクトが Python で作成されるときに、分布バッチとイベントの形状が評価されます(したがって既知です)。一方、グラフ(非 Eager モード)では、グラフが実行されるまでイベントとバッチの形状が決定されていない分布を定義することができます。

スカラー分布

上記のように、Distribution オブジェクトではイベントとバッチの形状が定義されています。まず、分布を説明するユーティリティから始めます。

def describe_distributions(distributions):
  print('\n'.join([str(d) for d in distributions]))

このセクションでは、スカラー分布(イベントの形状が [] の分布)について説明します。典型的な例は、rate で指定されたポアソン分布です。

poisson_distributions = [
    tfd.Poisson(rate=1., name='One Poisson Scalar Batch'),
    tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons'),
    tfd.Poisson(rate=[[1., 10., 100.,], [2., 20., 200.]],
                name='Two-by-Three Poissons'),
    tfd.Poisson(rate=[1.], name='One Poisson Vector Batch'),
    tfd.Poisson(rate=[[1.]], name='One Poisson Expanded Batch')
]

describe_distributions(poisson_distributions)
tfp.distributions.Poisson("One_Poisson_Scalar_Batch", batch_shape=[], event_shape=[], dtype=float32)
tfp.distributions.Poisson("Three_Poissons", batch_shape=[3], event_shape=[], dtype=float32)
tfp.distributions.Poisson("Two_by_Three_Poissons", batch_shape=[2, 3], event_shape=[], dtype=float32)
tfp.distributions.Poisson("One_Poisson_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
tfp.distributions.Poisson("One_Poisson_Expanded_Batch", batch_shape=[1, 1], event_shape=[], dtype=float32)

ポアソン分布はスカラー分布であるため、そのイベントの形状は常に [] です。より多くのレートを指定すると、これらはバッチ形式で表示されます。例の最後のペアは興味深いものです。レートは 1 つだけですが、そのレートは空でない形状の numpy 配列に埋め込まれているため、その形状がバッチ形状になります。

標準の正規分布もスカラーです。イベントの形状は、ポアソンの場合と同じように [] ですが、ブロードキャストの最初の例で見ていきます。正規分布は、loc および scale パラメーターを使用して指定されます。

normal_distributions = [
    tfd.Normal(loc=0., scale=1., name='Standard'),
    tfd.Normal(loc=[0.], scale=1., name='Standard Vector Batch'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=1., name='Different Locs'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=[[1.], [5.]],
               name='Broadcasting Scale')
]

describe_distributions(normal_distributions)
tfp.distributions.Normal("Standard", batch_shape=[], event_shape=[], dtype=float32)
tfp.distributions.Normal("Standard_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
tfp.distributions.Normal("Different_Locs", batch_shape=[4], event_shape=[], dtype=float32)
tfp.distributions.Normal("Broadcasting_Scale", batch_shape=[2, 4], event_shape=[], dtype=float32)

上記の Broadcasting Scale 分布は興味深い例です。loc パラメーターは [4] の形状、scale パラメーターは [2, 1] の形状をもちます。Numpy ブロードキャストルールを使用すると、バッチ形状は [2, 4] になります。 "Broadcasting Scale" 分布を定義するための同等の(ただし、あまりエレガントではなく、推奨されない)方法は次のとおりです。

describe_distributions(
    [tfd.Normal(loc=[[0., 1., 2., 3], [0., 1., 2., 3.]],
                scale=[[1., 1., 1., 1.], [5., 5., 5., 5.]])])
tfp.distributions.Normal("Normal", batch_shape=[2, 4], event_shape=[], dtype=float32)

以上のようにブロードキャストの表記は頭痛やバグの原因にもなりますが便利です。

スカラー分布のサンプリング

分布で実行できる主なことは samplelog_prob の 2 つです。まず、サンプリングについて見ていきましょう。基本的なルールは、分布からサンプリングする場合、結果のテンソルは形状 [sample_shape, batch_shape, event_shape] になります。batch_shapeevent_shapeDistribution オブジェクトにより提供され、sample_shape は、sample の呼び出しにより提供されます。スカラー分布の場合、event_shape = [] であるため、サンプルから返されるテンソルの形状は [sample_shape, batch_shape] になります。では、試してみましょう。

def describe_sample_tensor_shape(sample_shape, distribution):
    print('Sample shape:', sample_shape)
    print('Returned sample tensor shape:',
          distribution.sample(sample_shape).shape)

def describe_sample_tensor_shapes(distributions, sample_shapes):
    started = False
    for distribution in distributions:
      print(distribution)
      for sample_shape in sample_shapes:
        describe_sample_tensor_shape(sample_shape, distribution)
      print()

sample_shapes = [1, 2, [1, 5], [3, 4, 5]]
describe_sample_tensor_shapes(poisson_distributions, sample_shapes)
tfp.distributions.Poisson("One_Poisson_Scalar_Batch", batch_shape=[], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1,)
Sample shape: 2
Returned sample tensor shape: (2,)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5)

tfp.distributions.Poisson("Three_Poissons", batch_shape=[3], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 3)
Sample shape: 2
Returned sample tensor shape: (2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 3)

tfp.distributions.Poisson("Two_by_Three_Poissons", batch_shape=[2, 3], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Poisson("One_Poisson_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1)

tfp.distributions.Poisson("One_Poisson_Expanded_Batch", batch_shape=[1, 1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1, 1)
describe_sample_tensor_shapes(normal_distributions, sample_shapes)
tfp.distributions.Normal("Standard", batch_shape=[], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1,)
Sample shape: 2
Returned sample tensor shape: (2,)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5)

tfp.distributions.Normal("Standard_Vector_Batch", batch_shape=[1], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 1)
Sample shape: 2
Returned sample tensor shape: (2, 1)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 1)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 1)

tfp.distributions.Normal("Different_Locs", batch_shape=[4], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 4)
Sample shape: 2
Returned sample tensor shape: (2, 4)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 4)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 4)

tfp.distributions.Normal("Broadcasting_Scale", batch_shape=[2, 4], event_shape=[], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 4)
Sample shape: 2
Returned sample tensor shape: (2, 2, 4)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 4)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 4)

sample についての説明は以上です。返されたサンプルテンソルの形状は [sample_shape, batch_shape, event_shape] です。

スカラー分布の log_prob の計算

次に、log_prob を見てみましょう。これは少し注意する必要があります。log_prob は、分布の log_prob を計算する場所を表す(空でない)テンソルを入力として受け取ります。最も単純なケースでは、このテンソルは [sample_shape, batch_shape, event_shape] の形式になります。batch_shapeevent_shape は 分布のバッチおよびイベントの形状に一致します。スカラー分布の場合は、event_shape = [] なので、入力テンソルの形状は [sample_shape, batch_shape] です。この場合、[sample_shape, batch_shape] 形状のテンソルが返されます。

three_poissons = tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons')
three_poissons
<tfp.distributions.Poisson 'Three_Poissons' batch_shape=[3] event_shape=[] dtype=float32>
three_poissons.log_prob([[1., 10., 100.], [100., 10., 1]])  # sample_shape is [2].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -2.0785608,   -3.2223587],
       [-364.73938  ,   -2.0785608,  -95.39484  ]], dtype=float32)>
three_poissons.log_prob([[[[1., 10., 100.], [100., 10., 1.]]]])  # sample_shape is [1, 1, 2].
<tf.Tensor: shape=(1, 1, 2, 3), dtype=float32, numpy=
array([[[[  -1.       ,   -2.0785608,   -3.2223587],
         [-364.73938  ,   -2.0785608,  -95.39484  ]]]], dtype=float32)>

最初の例では、入力と出力の形状が [2, 3] であり、2 番目の例では形状が [1, 1, 2, 3] であることに注意してください。

ブロードキャストがない場合はそれだけです。ブロードキャストを考慮する場合のルールは次のとおりです。これは一般的な説明であり、スカラー分布は簡略化されていることに注意してください。

  1. n = len(batch_shape) + len(event_shape) を定義します。(スカラー分布の場合は、len(event_shape)=0。)
  2. 入力テンソル t の次元が n 未満の場合、正確に n 次元になるまで、左側にサイズ 1 の次元を追加して形状をパッディングします。
  3. t' の右端の次元 nlog_prob 計算している分布の [batch_shape, event_shape] に対してブロードキャストします。詳しく説明すると、t' がすでに分布と一致している次元の場合は何もせず、t' の次元がシングルトンの場合は、そのシングルトンを適切な数で複製します。その他の場合はエラーです。(スカラー分布の場合、event_shape = [] であるため、 batch_shape に対してのみブロードキャストします。)
  4. これで、log_prob を計算できるようになりました。結果のテンソルの形状は、[sample_shape, batch_shape] です。sample_shape は、右端の次元 n の左側にある t または t' の任意の次元として定義されます(sample_shape = shape(t)[:-n])。

これが何を意味するのかわからないと混乱するかもしれないので、いくつかの例を見てみましょう。

three_poissons.log_prob([10.])
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-16.104412 ,  -2.0785608, -69.05272  ], dtype=float32)>

テンソル [10.] (形状 [1])は 3 つのbatch_shape でブロードキャストされるため、値 10 での 3 つのポワソンの対数確率をすべて評価します。

three_poissons.log_prob([[[1.], [10.]], [[100.], [1000.]]])
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[-1.0000000e+00, -7.6974149e+00, -9.5394836e+01],
        [-1.6104412e+01, -2.0785608e+00, -6.9052719e+01]],

       [[-3.6473938e+02, -1.4348087e+02, -3.2223587e+00],
        [-5.9131279e+03, -3.6195427e+03, -1.4069575e+03]]], dtype=float32)>

上記の例では、入力テンソルの形状は [2, 2, 1] ですが、分布オブジェクトの形状は 3 です。したがって、[2, 2] サンプル次元のそれぞれについて、提供された単一の値は、3 つのポワソンのそれぞれにブロードキャストします。

これは役に立つ考え方です。three_poissons には batch_shape = [2, 3] があるため、log_prob の呼び出しには最後の次元が 1 または 3 のテンソルが必要です。それ以外はエラーです。(numpy ブロードキャストルールは、スカラーの特殊なケースを、形状 [1] のテンソルと完全に同等であるものとして扱います。)

では、batch_shape = [2, 3] を使用して、より複雑なポアソン分布を使用して試してみましょう。

poisson_2_by_3 = tfd.Poisson(
    rate=[[1., 10., 100.,], [2., 20., 200.]],
    name='Two-by-Three Poissons')
poisson_2_by_3.log_prob(1.)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([1.])  # Exactly equivalent to above, demonstrating the scalar special case.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 1., 1.], [1., 1., 1.]])  # Another way to write the same thing. No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -17.004269 , -194.70169  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 10., 100.]])  # Input is [1, 3] broadcast to [2, 3].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ -1.       ,  -2.0785608,  -3.2223587],
       [ -1.3068528,  -5.14709  , -33.90767  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 10., 100.], [1., 10., 100.]])  # Equivalent to above. No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ -1.       ,  -2.0785608,  -3.2223587],
       [ -1.3068528,  -5.14709  , -33.90767  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1., 1., 1.], [2., 2., 2.]])  # No broadcasting.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -14.701683 , -190.09653  ]], dtype=float32)>
poisson_2_by_3.log_prob([[1.], [2.]])  # Equivalent to above. Input shape [2, 1] broadcast to [2, 3].
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [  -1.3068528,  -14.701683 , -190.09653  ]], dtype=float32)>

上記の例では、バッチを介したブロードキャストを見ていきましたが、サンプルの形状は空でした。値のコレクションがあり、バッチの各ポイントで各値の対数確率を取得する場合は、以下のように手動で実行できます。

poisson_2_by_3.log_prob([[[1., 1., 1.], [1., 1., 1.]], [[2., 2., 2.], [2., 2., 2.]]])  # Input shape [2, 2, 3].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

または、ブロードキャストに最後のバッチ次元を処理させることもできます。

poisson_2_by_3.log_prob([[[1.], [1.]], [[2.], [2.]]])  # Input shape [2, 2, 1].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

また、やや不自然ですがブロードキャストに最初のバッチ次元のみを処理させることもできます。

poisson_2_by_3.log_prob([[[1., 1., 1.]], [[2., 2., 2.]]])  # Input shape [2, 1, 3].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

または、ブロードキャストに両方のバッチ次元を処理させることもできます。

poisson_2_by_3.log_prob([[[1.]], [[2.]]])  # Input shape [2, 1, 1].
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

上記は、必要な値が 2 つしかない場合は問題ありませんでした。しかし、すべてのバッチポイントで評価する値のリストが長い場合は、次の表記を使用します。形状の右側にサイズ 1 の余分な次元を追加すると、非常に便利です。

poisson_2_by_3.log_prob(tf.constant([1., 2.])[..., tf.newaxis, tf.newaxis])
<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[  -1.       ,   -7.697415 ,  -95.39484  ],
        [  -1.3068528,  -17.004269 , -194.70169  ]],

       [[  -1.6931472,   -6.087977 ,  -91.48282  ],
        [  -1.3068528,  -14.701683 , -190.09653  ]]], dtype=float32)>

これはストライドスライス表記のインスタンスであり、知っておく価値があります。

完全を期すために three_poissons に戻ると、同じ例は次のようになります。

three_poissons.log_prob([[1.], [10.], [50.], [100.]])
<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [ -16.104412 ,   -2.0785608,  -69.05272  ],
       [-149.47777  ,  -43.34851  ,  -18.219261 ],
       [-364.73938  , -143.48087  ,   -3.2223587]], dtype=float32)>
three_poissons.log_prob(tf.constant([1., 10., 50., 100.])[..., tf.newaxis])  # Equivalent to above.
<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[  -1.       ,   -7.697415 ,  -95.39484  ],
       [ -16.104412 ,   -2.0785608,  -69.05272  ],
       [-149.47777  ,  -43.34851  ,  -18.219261 ],
       [-364.73938  , -143.48087  ,   -3.2223587]], dtype=float32)>

多変量分布

ここでは、空でないイベント形状を持つ多変量分布を見ていきます。まず、多項分布を見てみましょう。

multinomial_distributions = [
    # Multinomial is a vector-valued distribution: if we have k classes,
    # an individual sample from the distribution has k values in it, so the
    # event_shape is `[k]`.
    tfd.Multinomial(total_count=100., probs=[.5, .4, .1],
                    name='One Multinomial'),
    tfd.Multinomial(total_count=[100., 1000.], probs=[.5, .4, .1],
                    name='Two Multinomials Same Probs'),
    tfd.Multinomial(total_count=100., probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Same Counts'),
    tfd.Multinomial(total_count=[100., 1000.],
                    probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Different Everything')

]

describe_distributions(multinomial_distributions)
tfp.distributions.Multinomial("One_Multinomial", batch_shape=[], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Same_Probs", batch_shape=[2], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Same_Counts", batch_shape=[2], event_shape=[3], dtype=float32)
tfp.distributions.Multinomial("Two_Multinomials_Different_Everything", batch_shape=[2], event_shape=[3], dtype=float32)

最後の 3 つの例では、batch_shape は常に [2] でしたが、ブロードキャストを使用して、共有する total_count または共有する probs 使用できます(または、使用しないこともできます)。内部では同じ形状になるようにブロードキャストされるためです。

既知の事柄を考慮すると、サンプリングは簡単です。

describe_sample_tensor_shapes(multinomial_distributions, sample_shapes)
tfp.distributions.Multinomial("One_Multinomial", batch_shape=[], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 3)
Sample shape: 2
Returned sample tensor shape: (2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 3)

tfp.distributions.Multinomial("Two_Multinomials_Same_Probs", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Multinomial("Two_Multinomials_Same_Counts", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

tfp.distributions.Multinomial("Two_Multinomials_Different_Everything", batch_shape=[2], event_shape=[3], dtype=float32)
Sample shape: 1
Returned sample tensor shape: (1, 2, 3)
Sample shape: 2
Returned sample tensor shape: (2, 2, 3)
Sample shape: [1, 5]
Returned sample tensor shape: (1, 5, 2, 3)
Sample shape: [3, 4, 5]
Returned sample tensor shape: (3, 4, 5, 2, 3)

対数確率の計算も同様に簡単です。対角多変量正規分布の例を見てみましょう。(カウントと確率の制約により、ブロードキャストは許容できない値を生成することが多いため、多項分布はブロードキャストにあまり適していません。)平均は同じですがスケール(標準偏差)が異なる 2 つの 3 次元分布のバッチを使用します。

two_multivariate_normals = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_diag=tf.ones([2, 3]) * [[1.], [2.]])
two_multivariate_normals
<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[2] event_shape=[3] dtype=float32>

次に、各バッチポイントの平均とシフトされた平均での対数確率を評価します。

two_multivariate_normals.log_prob([[[1., 2., 3.]], [[3., 4., 5.]]])  # Input has shape [2,1,3].
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-2.7568154, -4.836257 ],
       [-8.756816 , -6.336257 ]], dtype=float32)>

まったく同じように、[https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice](ストライドスライス表記)を使用して、定数の中央に追加の形状 = 1 次元を挿入できます。

two_multivariate_normals.log_prob(
    tf.constant([[1., 2., 3.], [3., 4., 5.]])[:, tf.newaxis, :])  # Equivalent to above.
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-2.7568154, -4.836257 ],
       [-8.756816 , -6.336257 ]], dtype=float32)>

一方、余分な次元を追加しない場合は、[1., 2., 3.] を最初のバッチポイントに渡し、[3., 4., 5.] を 2 番目のバッチポイントに渡します。

two_multivariate_normals.log_prob(tf.constant([[1., 2., 3.], [3., 4., 5.]]))
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-2.7568154, -6.336257 ], dtype=float32)>

形状変換テクニック

Reshape Bijector

Reshape Bijector を使用すると、分布の event_shape の形状を変換できます。以下に例を示します。

six_way_multinomial = tfd.Multinomial(total_count=1000., probs=[.3, .25, .2, .15, .08, .02])
six_way_multinomial
<tfp.distributions.Multinomial 'Multinomial' batch_shape=[] event_shape=[6] dtype=float32>

[6] のイベント形状を持つ多項分布を作成しました。Reshape Bijector を使用すると、これを [2, 3] のイベント形状を持つ分布として扱うことができます。

Bijector は、\({\mathbb R}^n\) の開集合上の微分可能な 1 対 1 の関数を表します。Bijectors は、TransformedDistribution と組み合わせて使用されます。これは、基本分布 \(p(x)\) および\(Y = g(X)\) を表す Bijector に関して分布 \(p(y)\) をモデル化します。では、実際に見てみましょう。

transformed_multinomial = tfd.TransformedDistribution(
    distribution=six_way_multinomial,
    bijector=tfb.Reshape(event_shape_out=[2, 3]))
transformed_multinomial
<tfp.distributions.TransformedDistribution 'reshapeMultinomial' batch_shape=[] event_shape=[2, 3] dtype=float32>
six_way_multinomial.log_prob([500., 100., 100., 150., 100., 50.])
<tf.Tensor: shape=(), dtype=float32, numpy=-178.21973>
transformed_multinomial.log_prob([[500., 100., 100.], [150., 100., 50.]])
<tf.Tensor: shape=(), dtype=float32, numpy=-178.21973>

これは、Reshape Bijector が実行できる唯一のことです。イベント次元をバッチ次元に、またはバッチ次元をイベント次元に変換することはできません。

Independent 分布

Independent 分布は、独立した、必ずしも同一ではない分布(バッチ)のコレクションを単一の分布として扱うために使用されます。より簡潔に言えば、Independent を使用すると、batch_shape の次元を event_shape の次元に変換できます。次に例を示します。

two_by_five_bernoulli = tfd.Bernoulli(
    probs=[[.05, .1, .15, .2, .25], [.3, .35, .4, .45, .5]],
    name="Two By Five Bernoulli")
two_by_five_bernoulli
<tfp.distributions.Bernoulli 'Two_By_Five_Bernoulli' batch_shape=[2, 5] event_shape=[] dtype=int32>

これは、表の確率が関連付けられた 2x5 のコインの配列として考えることができます。特定の任意の 1 と 0 のセットの確率を評価します。

pattern = [[1., 0., 0., 1., 0.], [0., 0., 1., 1., 1.]]
two_by_five_bernoulli.log_prob(pattern)
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[-2.9957323 , -0.10536051, -0.16251892, -1.609438  , -0.2876821 ],
       [-0.35667497, -0.4307829 , -0.91629076, -0.79850775, -0.6931472 ]],
      dtype=float32)>

Independent を使用すると、これを 2 つの異なる「5 つのベルヌーイのセット」に変換できます。これは、特定のパターンで出現するコイントスの「行」を単一の結果と見なす場合に役立ちます。

two_sets_of_five = tfd.Independent(
    distribution=two_by_five_bernoulli,
    reinterpreted_batch_ndims=1,
    name="Two Sets Of Five")
two_sets_of_five
<tfp.distributions.Independent 'Two_Sets_Of_Five' batch_shape=[2] event_shape=[5] dtype=int32>

数学的には、5 つの「セット」ごとの対数確率を計算しています。セット内の 5 つの「独立した」コイントスの対数確率を合計するため、分布は「independent」と呼ばれます。

two_sets_of_five.log_prob(pattern)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-5.160732 , -3.1954036], dtype=float32)>

さらに、Independent を使用して、個々のイベントが 2x5 のベルヌーイのセットである分布を作成できます。

one_set_of_two_by_five = tfd.Independent(
    distribution=two_by_five_bernoulli, reinterpreted_batch_ndims=2,
    name="One Set Of Two By Five")
one_set_of_two_by_five.log_prob(pattern)
<tf.Tensor: shape=(), dtype=float32, numpy=-8.356134>

sample の観点では、Independent を使用しても何も変更されないことに注意してください。

describe_sample_tensor_shapes(
    [two_by_five_bernoulli,
     two_sets_of_five,
     one_set_of_two_by_five],
    [[3, 5]])
tfp.distributions.Bernoulli("Two_By_Five_Bernoulli", batch_shape=[2, 5], event_shape=[], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

tfp.distributions.Independent("Two_Sets_Of_Five", batch_shape=[2], event_shape=[5], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

tfp.distributions.Independent("One_Set_Of_Two_By_Five", batch_shape=[], event_shape=[2, 5], dtype=int32)
Sample shape: [3, 5]
Returned sample tensor shape: (3, 5, 2, 5)

最後の演習として、サンプリングと対数確率の観点から、Normal 分布のベクトルバッチと MultivariateNormalDiag 分布の相違点と類似点を検討することをお勧めします。Independent を使用して、Normal のバッチから MultivariateNormalDiag を構築するにはどうすればよいでしょうか?(MultivariateNormalDiag は、実際にはこの方法で実装されていません。)