自動バッチ同時分布:穏やかなチュートリアル

TensorFlow.orgで表示GoogleColabで実行GitHubでソースを表示 ノートブックをダウンロード

序章

TensorFlow確率(TFP)は、多くの提供JointDistribution容易に近い数学的形式で確率的グラフィカルモデルを表現するためにユーザを可能にすることによって容易に確率推論を行う抽象化します。抽象化により、モデルからサンプリングし、モデルからのサンプルの対数確率を評価するためのメソッドが生成されます。このチュートリアルでは、オリジナルの後に開発された「autobatched」の変種、見直しJointDistribution抽象化を。元の自動バッチ処理されていない抽象化と比較して、自動バッチ処理されたバージョンは使用が簡単で人間工学的であるため、多くのモデルをより少ない定型文で表現できます。このコラボでは、(おそらく退屈な)詳細で単純なモデルを探索し、自動バッチ処理が解決する問題を明らかにし、(うまくいけば)途中でTFP形状の概念について読者にもっと教えます。

autobatchingの導入に先立ち、のいくつかの異なるバリエーションがあったJointDistribution :確率モデルを表現するためのさまざまな構文スタイルに対応する、 JointDistributionSequentialJointDistributionNamed 、およびJointDistributionCoroutine 。 Auobatchingはミックスインとして存在するので、我々は今持っているAutoBatchedこれらの全ての変異体。このチュートリアルでは、間の違い探るJointDistributionSequentialJointDistributionSequentialAutoBatched 。ただし、ここで行うことはすべて、基本的に変更なしで他のバリアントに適用できます。

依存関係と前提条件

インポートと設定

前提条件:ベイズ回帰問題

非常に単純なベイズ回帰シナリオを検討します。

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

このモデルでは、 mおよびb 、標準的な法線から引き出され、そして観察Yその平均確率変数に依存する正規分布から引き出されるmb 、及びいくつかの(非ランダム、公知)が共変量X 。 (簡単にするために、この例では、すべての確率変数のスケールが既知であると想定しています。)

このモデルで推論を実行するために、我々は両方共変量知っておく必要があるだろうXとの観測Y 、このチュートリアルの目的のために、我々は唯一の必要がありますX 、我々は、単純なダミーの定義ので、 X

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

デシデラタ

確率的推論では、2つの基本的な操作を実行したいことがよくあります。

  • sample :モデルからサンプルを描きます。
  • log_prob :モデルからサンプルのログ確率を計算します。

TFPのの重要な貢献JointDistribution抽象化(同様の確率的プログラミングには多くの他のアプローチは)ユーザーは、一度モデルを作成し、両方にアクセス持つことができるようにすることですsamplelog_prob計算を。

我々は、データセット(7点を持っていることは注目にX.shape = (7,)我々は今、優秀のための要望を述べることができるJointDistribution

  • sample()のリストを生成する必要がありTensors形状を有する[(), (), (7,)はそれぞれ、スカラー勾配、スカラーバイアス、およびベクトル観測値に対応します]。
  • log_prob(sample()) 、特定の傾きの対数確率、バイアス、および観察:スカラーを生成しなければなりません。
  • sample([5, 3])リスト生成すべきTensors形状を有する[(5, 3), (5, 3), (5, 3, 7)]を表す、 (5, 3) -からのサンプルのバッチをモデル。
  • log_prob(sample([5, 3]))生成すべきであるTensor形状(5、3)。

現在の連続を見てみましょうJointDistribution 、モデル上記の要望を実現する方法を見て、うまくいけば、TFPが道に沿って形状についてもう少し学びます。

スポイラー警告:追加の定型なし満たす上記の要望がされていることをアプローチautobatching

最初の試み; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

これは、多かれ少なかれ、モデルをコードに直接変換したものです。勾配mとバイアスb簡単です。 Y使用して定義されたlambda -functionを:一般的なパターンは、ということであるlambdaの-function \(k\) 内の引数JointDistributionSequential (JDS)は以前使用 \(k\) モデルに分布。 「逆」の順序に注意してください。

我々は呼ぶことにしますsample_distributions 、リターンサンプルおよびサンプルを生成するために使用された基礎となる「サブディストリビューション」の両方を。 (私たちは、呼び出すことによって、単にサンプルを生産している可能性がsample 、後のチュートリアルでは、同様の分布を持っていると便利になります。)私たちが生産サンプルは結構です。

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

しかしlog_prob望ましくない形で結果を生成します。

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

また、複数のサンプリングは機能しません。

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

何が悪いのかを理解してみましょう。

簡単なレビュー:バッチとイベントの形状

TFPにおいては、通常の(ないJointDistribution )確率分布は、イベント形状バッチ形状を有し、その差を理解するTFPの効果的な利用に重要です。

  • イベント形状は、分布からの単一の描画の形状を表します。ドローは次元間で依存する場合があります。スカラー分布の場合、イベントの形状は[]です。 5次元のMultivariateNormalの場合、イベントの形状は[5]です。
  • バッチ形状は、独立した、同一に分散されていないドロー、別名「バッチ」の分布を表します。単一のPythonオブジェクトでディストリビューションのバッチを表すことは、TFPが大規模な効率を達成するための重要な方法の1つです。

我々の目的のために、心に留めておくべき重要な事実は、我々が呼び出した場合ということですlog_prob分布からのサンプルで、結果は常に一致するバッチ形状を(すなわち、右端の寸法として持っている)という形になります。

形状のより詳細な議論については、 「理解TensorFlow分布シェイプ」のチュートリアルを

しないのはなぜlog_prob(sample())スカラをプロデュース?

さんがで何が起こっているのか探求するバッチやイベント形状の知識を使ってみましょうlog_prob(sample())これが私たちのサンプルです:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

そして、ここに私たちのディストリビューションがあります:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

対数確率は、パーツの(一致した)要素でのサブ分布の対数確率を合計することによって計算されます。

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

したがって、説明の1つのレベルは、第3副成分ので、対数確率の計算は、7テンソルを返すことであるlog_prob_parts 7テンソルです。しかし、なぜ?

まあ、我々は最後の要素ことを確認dists超える物流に対応し、 Y mathematial製剤では、持っているbatch_shape[7]換言すれば、上物流Y (異なる手段と、この場合、同じスケールで)7面の独立法線のバッチです。

私たちは今、間違っているものを理解:JDSで、オーバー分布Yありbatch_shape=[7] 、JDSからのサンプルはのためのスカラー表しmbと7面の独立した法線の「バッチ」。そしてlog_prob描画の対数確率を表しそれぞれが7別個の対数確率、計算mbとシングル観測Y[i]一部にX[i]

固定log_prob(sample())Independent

ことを思い出しdists[2]有するevent_shape=[]batch_shape=[7]

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

TFPの使用してIndependentイベントの寸法にバッチ寸法に変換metadistributionを、我々は持つ分布にこれを変換することができますevent_shape=[7]batch_shape=[]私達はそれに名前を変更しますy_dist_i 、それは上の配布なのでYと、 _i立っ私たちのためにあるIndependent )ラッピング:

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

さて、 log_prob 7 -ベクトルのスカラーは次のとおりです。

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

カバーの下では、 Independentバッチオーバー合計:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

そして実際、我々は新しい構築するためにこれを使用することができますjds_ii再びの略Independentlog_prob 、スカラーを返します:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

いくつかのメモ:

  • jds_i.log_prob(s)と同じではないtf.reduce_sum(jds.log_prob(s))前者は、同時分布の「正しい」対数確率を生成します。 7-テンソル、ログの確率の和となっている各要素上後者の和mb 、との対数確率の単一元素Yので、overcounts mおよびb 。 ( log_prob(m) + log_prob(b) + log_prob(Y) 、むしろTFPがTFとnumpyのの放送規則に従うために例外をスローするよりも、結果を返し、ベクトルにスカラーを追加するベクトルの大きさの結果を生成します。)
  • この特定のケースでは、我々は問題を解決し、使用して同じ結果を達成している可能性がMultivariateNormalDiagの代わりに、 Independent(Normal(...)) MultivariateNormalDiagベクトル値の分布(すなわち、それは既にベクターイベント形状を有している)です。 Indeeed MultivariateNormalDiag可能性があります(ただしれていない)の組成物として実装されIndependentしてNormal 。これは、ベクトル所与のことを覚えておくことは価値があるV 、からサンプルn1 = Normal(loc=V)およびn2 = MultivariateNormalDiag(loc=V)区別できません。これらの分布beween差があることn1.log_prob(n1.sample())ベクトルでありn2.log_prob(n2.sample())スカラーです。

複数のサンプル?

複数のサンプルを描画しても機能しません。

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

理由を考えてみましょう。我々は呼び出すときjds_i.sample([5, 3]) 、我々は最初のためのサンプルを描くよmb 、形状のそれぞれ(5, 3)次に、我々は構築しようとしているNormalを経由して配信を:

tfd.Normal(loc=m*X + b, scale=1.)

しかし、場合m形状がある(5, 3)およびX形状がある7 、我々は乗算彼ら一緒にすることはできません、と確かにこれは我々がヒットしているエラーです。

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

この問題を解決するには、聞かせてのは、以上の分布何性質を考えるY有していなければなりません。私たちが呼ばれた場合jds_i.sample([5, 3])その後、私たちは知っているmb両方の形になります(5, 3)呼び出しは何形状べきsample上のY分布プロデュース?明白な答えは(5, 3, 7)各バッチポイントのために、私たちは同じサイズのサンプルたいX 。これは、TensorFlowのブロードキャスト機能を使用して、次のような次元を追加することで実現できます。

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

両方に軸を追加mb 、我々は、複数のサンプルをサポートする新たなJDSを定義することができます。

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

追加のチェックとして、単一のバッチポイントのログ確率が以前のログ確率と一致することを確認します。

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

勝利のためのオートバッチ処理

優れた!私たちは今、JointDistributionのバージョンを持っているのハンドルすべての私たちの要望: log_probリターンの使用にスカラおかげtfd.Independent 、および複数のサンプルは、我々は、余分な軸を追加することにより、放送固定することを動作するようになりました。

もっと簡単で良い方法があると言ったらどうしますか?あり、それは呼ばれていますJointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

これはどのように作動しますか?あなたがしようとする可能性がありますが、コードを読んで深く理解するために、我々は、ほとんどのユースケースのために十分である簡潔な概要を与えるでしょう。

  • 私たちの最初の問題は、私たちの配布ということであったことを思い出しY持っていたbatch_shape=[7]してevent_shape=[]そして我々が使用しIndependentたイベントディメンションにバッチ寸法を変換します。 JDSABは、コンポーネント分布のバッチ形状を無視します。代わりに、それがであると仮定されたモデルの全体的な性質として、バッチ形状を扱う[]設定することにより、特に断りのない限りbatch_ndims > 0 )。効果は、我々は、手動で上記したように、イベント寸法に成分分布のすべてのバッチの大きさを変換するtfd.Independentを使用するのと同じです。
  • 私たちの第二の問題は、形状のマッサージの必要性だったmb彼らが適切に放送することができるようにX複数のサンプルを作成するとき。 JDSABを使用すると、単一のサンプルを生成するためのモデルを作成し、私たち「リフト」モデル全体がTensorFlowの使用して複数のサンプルを生成するvectorized_mapを。 (この機能は、JAXのにanalagousあるVMAP 。)

より詳細にバッチ形状の問題を探る、私たちは私たちのオリジナルの「悪い」共同配送のバッチ形状比較することができjds 、当社の一括固定分布jds_ijds_ia 、そして私たちのautobatched jds_ab

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

私たちは、元のことを見jds異なるバッチ形状のsubdistributionsを持っています。 jds_ijds_ia同じ(空の)バッチ形状でsubdistributionsを作成することによってこの問題を解決します。 jds_ab単一の(空の)バッチの形状を有しています。

ことは注目にそれの価値JointDistributionSequentialAutoBatched自由のためのいくつかの追加の汎用性を提供しています。私たちは共変量作ると仮定X (および、暗黙のうち、観測Y )2次元を:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

私たちのJointDistributionSequentialAutoBatched変更なし(の形状ので、我々はモデルを再定義する必要があると連携Xでキャッシュされるjds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

一方、私たちは慎重に細工ないJointDistributionSequentialもはや機能します。

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

これを修正するために、我々は、第二追加する必要があるだろうtf.newaxis両方にmb形状に合わせて、増加reinterpreted_batch_ndims呼び出しで2にIndependent 。この場合、自動バッチ処理機械に形状の問題を処理させることは、より短く、より簡単で、より人間工学的です。

もう一度、私たちはこのノートPCは、探求しながら、ご注意JointDistributionSequentialAutoBatched 、他の変種JointDistribution同等持っAutoBatched 。 (のユーザーの場合JointDistributionCoroutineJointDistributionCoroutineAutoBatched追加の利点を持っているあなたに指定する必要のなくなっRootノードを、あなたは使ったことがない場合JointDistributionCoroutine 。あなたが安全にこの文を無視することができます)

結論

このノートブックでは、我々は導入JointDistributionSequentialAutoBatchedして詳細に単純な例を働きました。うまくいけば、TFPの形状と自動バッチ処理について何かを学んだことでしょう。