Otomatik Toplu Ortak Dağıtımlar: Nazik Bir Eğitim

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

Tanıtım

TensorFlow Olasılık (TFP) bir sunmaktadır JointDistribution kolay kolay bir kullanım yakın bir matematiksel formda olasılıklı bir grafik modeli ifade izin vererek olasılıksal sonuç çıkarabilir Özetlenmesi; soyutlama, modelden örnekleme ve modelden örneklerin log olasılığını değerlendirme yöntemleri üretir. Bu eğitimde, orijinal sonra geliştirildi "autobatched" varyantları, gözden JointDistribution soyutlamalar. Orijinal, otomatik toplu olmayan soyutlamalara kıyasla, otomatik toplu sürümlerin kullanımı daha basit ve daha ergonomiktir, bu da birçok modelin daha az standart ile ifade edilmesini sağlar. Bu ortak çalışmada, otomatik yığınlamanın çözdüğü sorunları netleştirerek ve (umarız) okuyucuya yol boyunca TFP şekil kavramları hakkında daha fazla şey öğreterek basit bir modeli (belki de sıkıcı) ayrıntılarla keşfediyoruz.

Autobatching getirilmesinden önce, birkaç farklı varyantları yoktu JointDistribution : olasılık modelleri ifade etmek için farklı sözdizimsel stilleri karşılık gelen JointDistributionSequential , JointDistributionNamed ve JointDistributionCoroutine . Şu anda elimizde bu yüzden Auobatching, bir mixin olarak var AutoBatched Bütün bunların varyantları. Bu eğitimde, arasındaki farkları araştırmak JointDistributionSequential ve JointDistributionSequentialAutoBatched ; ancak burada yaptığımız her şey, esasen hiçbir değişiklik olmaksızın diğer varyantlara uygulanabilir.

Bağımlılıklar ve Ön Koşullar

İthalat ve kurulumlar

Önkoşul: Bir Bayesian Regresyon Problemi

Çok basit bir Bayes regresyon senaryosunu ele alacağız:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

Bu modelde, m ve b , standart normallerden çizilir ve gözlemler Y olan ortalama rasgele değişkenlere bağlıdır normal dağılımdan çizilir m ve b , ve bazı (rastgele olmayan, bilinen) değişkenleri belirlemek X . (Basit olması için bu örnekte, tüm rastgele değişkenlerin ölçeğinin bilindiğini varsayıyoruz.)

Bu modeldeki çıkarım gerçekleştirmek için, her iki ortak değişkenlere bilmek gerekiyordu X ve gözlemleri Y , ama bu yazının amaçları için, sadece gerekir X basit bir kukla tanımlamak, böylece X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

arzu edilen

Olasılıksal çıkarımda, genellikle iki temel işlemi gerçekleştirmek isteriz:

  • sample : bir model örnekleri çizimi.
  • log_prob : modelinden bir numunenin günlük olasılığını hesaplanıyor.

TFP en önemli katkısı JointDistribution soyutlama (yanı sıra olasılık programlama yönelik pek çok yaklaşım) kullanıcıların kez bir model yazıp hem erişmesini sağlamaktır sample ve log_prob hesaplamaları.

Bizim veri kümesindeki (7 puana sahip olduğunu kaydeden X.shape = (7,) ), şimdi mükemmel bir için aranılan vasıfları ifade edebiliriz JointDistribution :

  • sample() bir listesini üretmek gerekir Tensors şekle sahip olan [(), (), (7,) , sırasıyla, skaler eğim, skaler önyargı ve vektör gözlemler tekabül].
  • log_prob(sample()) belirli bir eğim, duruş ve gözlemler günlük olasılığını: bir skalar üretmelidir.
  • sample([5, 3]) bir listesini üretmek gerekir Tensors şekle sahip olan [(5, 3), (5, 3), (5, 3, 7)] , bir temsil (5, 3) - numune toplu işletmeye modeli.
  • log_prob(sample([5, 3])) , bir üretmelidir Tensor şekli (5, 3) ile yıkanmıştır.

Şimdi bir arkaya bakacağız JointDistribution , modeller yukarıdaki aranılan vasıfları elde etmek nasıl görmek ve TFP yol boyunca şekillendiren hakkında umarım biraz daha öğreniyoruz.

Spoiler uyarısı: eklenen Demirbaş olmadan tatmin yukarıdaki istekler olduğunu yaklaşım autobatching .

İlk girişim; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Bu, modelin koda aşağı yukarı doğrudan çevirisidir. Şev m ve önyargı b basittir. Y bir ile tanımlanmakta olup, lambda taşımasının avantajlı: Genel deseni olmasıdır lambda bölgesinin taşımasının avantajlı \(k\) bir bağımsız değişken JointDistributionSequential (JDS) Önceki kullanan \(k\) modelinde dağılımları. "Ters" sıraya dikkat edin.

Biz arayacağım sample_distributions , döner bir örnek ve örnek oluşturmak için kullanıldı yatan "alt dağılımları" hem. (Biz arayarak sadece örnek üretilen olabilirdi sample , daha sonra öğretici o dağılımları olması uygun olacaktır da.) Ürettiğimiz örnek gayet:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Ancak log_prob istenmeyen bir şekle sahip bir sonuç üretir:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

Ve çoklu örnekleme çalışmıyor:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Neyin yanlış gittiğini anlamaya çalışalım.

Kısa Bir İnceleme: Parti ve Etkinlik Şekli

TFP, sıradan (bir JointDistribution ) olasılık dağılımı bir olay şeklinde ve bir yığın şekline ve farkı anlamak TFP etkin kullanımı için çok önemlidir var

  • Olay şekli, dağılımdan tek bir çekilişin şeklini tanımlar; beraberlik boyutlara bağlı olabilir. Skaler dağılımlar için olay şekli []'dir. 5 boyutlu MultivariateNormal için olay şekli [5]'dir.
  • Toplu şekil, bağımsız, aynı şekilde dağıtılmamış çekilişleri, yani bir dağılım "toplu"sunu tanımlar. Bir dağıtım grubunu tek bir Python nesnesinde temsil etmek, TFP'nin ölçekte verimliliğe ulaşmasının temel yollarından biridir.

Bizim için, akılda tutulması gereken kritik bir gerçektir diyoruz eğer olmasıdır log_prob bir dağılımından tek bir numune üzerinde, sonuç her zaman maçlar (yani en sağdaki boyutlar olarak vardır) bir şekle toplu bir şekle sahip olacaktır.

Şekillerin daha derinlemesine tartışma için, bkz "Anlama TensorFlow Dağılımları Şekiller" öğretici .

Neden mu değil log_prob(sample()) bir Skaler üretin?

Diyelim neler olup bittiğini keşfetmek için toplu ve olay şeklinin bilgimizi kullanmak log_prob(sample()) . İşte yine örneğimiz:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Ve işte dağılımlarımız:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

Günlük olasılığı, parçaların (eşleşen) öğelerindeki alt dağılımların günlük olasılıklarının toplanmasıyla hesaplanır:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Yani, açıklama bir düzey üçüncü alt bileşen nedeniyle günlük olasılık hesaplama 7 Tensörü dönen olmasıdır log_prob_parts 7 tensör olduğunu. Ama neden?

Biz de son elemanı görüyoruz dists üzerinde eden dağılımına karşılık gelmektedir, Y mathematial formülasyonda, bir sahiptir batch_shape ait [7] . Diğer bir deyişle, fazla dağıtım Y (bu durumda, aynı ölçek içinde, farklı araçlar ile) 7 bağımsız normaller bir şeyin.

Şimdi neyin yanlış olduğunu anlamaya: JDS de, üzerinde dağıtım Y sahiptir batch_shape=[7] , JDS bir örnek için skalarlar temsil m ve b ve 7 bağımsız normaller bir "toplu". ve log_prob çekme log olasılığını temsil her biri 7 ayrı bir giriş olasılıkları, hesaplar m ve b ve tek bir gözlem Y[i] bazı X[i] .

Sabitleme log_prob(sample()) ile Independent

Hatırlatma; dists[2] sahip event_shape=[] ve batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

TFP en kullanarak Independent olay boyutlarına toplu boyutları dönüştürür dağıtımdır, biz bir dağılım halinde bu dönüştürme event_shape=[7] ve batch_shape=[] (biz yeniden adlandıracaksınız y_dist_i bunun bir dağıtım çünkü Y ile, _i ayakta bizim için de Independent ) sarma:

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Şimdi, log_prob 7 vektörün skaler geçerli:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Kapakları altında Independent toplu üzerinde özetliyor:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Ve gerçekten de, yeni bir inşa için kullanabilir jds_i ( i tekrar açılımı Independent ) nerede log_prob bir sayıl döndürür:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Birkaç not:

  • jds_i.log_prob(s) ile aynı değildir tf.reduce_sum(jds.log_prob(s)) . İlki, ortak dağılımın "doğru" log olasılığını üretir. 7-Tensörün üzerinde ikinci miktarlar, her bir eleman olan bir günlük olasılık toplamıdır m , b , ve log olasılık tek bir öğe Y bu overcounts böylece, m ve b . ( log_prob(m) + log_prob(b) + log_prob(Y) TFP TF ve numpy en yayın kurallarını takip eder, çünkü bir durum atma yerine bir sonuç verir;., Bir vektör, bir skaler ilave bir vektör boyutlu bir sonuç üretir)
  • Bu özel durumda, biz sorunu çözüldü ve kullanan aynı sonucu elde olabilirdi MultivariateNormalDiag yerine Independent(Normal(...)) . MultivariateNormalDiag bir vektör değerli dağılımıdır (yani, zaten vektör olay şekline sahiptir). Indeeed MultivariateNormalDiag olabilir (ama değil) bir kompozisyon olarak uygulanan Independent ve Normal . O bir vektör verilen hatırlamak için faydalıdır V , numune n1 = Normal(loc=V) , ve n2 = MultivariateNormalDiag(loc=V) ayırt edilemez; Bu dağılımlar beween fark olmasıdır n1.log_prob(n1.sample()) bir vektör ve bir n2.log_prob(n2.sample()) bir skalerdir.

Çoklu Örnekler?

Birden çok örnek çizmek hala çalışmıyor:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Nedenini düşünelim. Dediğimiz zaman jds_i.sample([5, 3]) , ilk için numune çizmek gerekir m ve b , şekil, her (5, 3) . Sonra, bir inşa etmeye deneyeceğiz Normal aracılığıyla dağıtımı:

tfd.Normal(loc=m*X + b, scale=1.)

Ama eğer m şekline sahiptir (5, 3) ve X şekline sahiptir 7 , biz onları birlikte çarpın olamaz ve gerçekten bu hata Bizler vurma geçerli:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Bu sorunu çözmek için, bitti dağıtım özellikleri neler düşünelim Y sahiptir olması. Adlandırdığımız ettiyseniz jds_i.sample([5, 3]) , ardından bildiğimiz m ve b hem şekle sahip olacaktır (5, 3) . Bir çağrı Ne biçim için gereken sample üzerinde Y dağıtım üretmek? Uygun cevap (5, 3, 7) : Her bir parti noktası sağlamak için, aynı boyutta bir örnek istiyorum X . Bunu, TensorFlow'un yayın yeteneklerini kullanarak ve ekstra boyutlar ekleyerek başarabiliriz:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Hem bir eksen ekleme m ve b , yeni JDS tanımlayabilir destekler çok sayıda numune olduğu:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Ek bir kontrol olarak, tek bir parti noktası için günlük olasılığının daha önce sahip olduğumuzla eşleştiğini doğrulayacağız:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Kazanmak İçin Otomatik Yığınlama

Harika! Şimdi JointDistribution bir sürüme sahip olmasını kolları tüm bizim istekler: log_prob döner kullanımına bir sayıl sayesinde tfd.Independent ve çoklu numuneler ekstra eksenlerini ekleyerek yayın sabit şimdi çalışır.

Ya sana daha kolay, daha iyi bir yol olduğunu söyleseydim? Orada olduğunu ve denir JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Bu nasıl çalışıyor? Eğer teşebbüs olabilir iken kodunu okumak derin bir anlayış için, en kullanım durumları için yeterli kısa bir özetini vereceğiz:

  • Hatırlama bizim ilk sorun için dağıtım olmasıydı Y vardı batch_shape=[7] ve event_shape=[] ve kullandığımız Independent bir olay boyutuna toplu boyut dönüştürmek. JDSAB, bileşen dağılımlarının toplu şekillerini yok sayar; bunun yerine olduğu varsayılır modelinin bir genel özelliği olarak toplu şekil davranır [] (ayarlayarak, aksi belirtilmediği sürece batch_ndims > 0 ). Etki elle yukarıda yaptığımız gibi, olay boyutlara bileşen dağıtımlarının tüm toplu boyutlarını dönüştürme tfd.Independent kullanmaya eşdeğerdir.
  • İkinci sorun, şekil masaj için bir ihtiyaç olduğu m ve b bunlar ile uygun bir şekilde yayın böylece X birden fazla numune oluştururken. JDSAB ile, bir model tek örneği oluşturmak için yazmaya ve "asansör" tüm model TensorFlow en kullanarak birden örnekleri oluşturmak için vectorized_map . (Bu özellik Jax'in için biçime benzer olan VMAP .)

Daha ayrıntılı olarak toplu şekil sorunu keşfetmek, bizim orijinal "kötülük" ortak dağıtım toplu şekiller karşılaştırabilirsiniz jds , bizim toplu sabit dağılımlardan jds_i ve jds_ia ve bizim autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Biz orijinal görüyoruz jds farklı toplu şekillerle subdistributions vardır. jds_i ve jds_ia aynı (boş) bir toplu şekli ile subdistributions oluşturarak sabitlenir. jds_ab yalnızca tek bir (boş) bir toplu bir şekle sahiptir.

Belirterek It değerinde JointDistributionSequentialAutoBatched ücretsiz bazı ek genelliği sunmaktadır. Biz covariates yapmak varsayalım X (ve dolaylı olarak, gözlemler Y ) iki boyutlu:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Bizim JointDistributionSequentialAutoBatched herhangi bir değişiklik (biz şekli nedeniyle modelini yeniden tanımlamak gerekir çalışır X tarafından önbelleğe jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

Öte yandan, bizim özenle hazırlanmış JointDistributionSequential artık çalışır:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Bunu düzeltmek için, biz ikinci eklemek olurdu tf.newaxis hem m ve b şekil ve artış maç reinterpreted_batch_ndims çağrısında 2'ye Independent . Bu durumda, otomatik harmanlama makinelerinin şekil sorunlarıyla ilgilenmesine izin vermek daha kısa, daha kolay ve daha ergonomiktir.

Bir kez daha, bu defter kesfedilmeyi ederken dikkat JointDistributionSequentialAutoBatched , diğer varyantları JointDistribution eşdeğer olması AutoBatched . (Kullanıcıları için JointDistributionCoroutine , JointDistributionCoroutineAutoBatched katkısı olmadığını belirtmek artık gerek Root düğümleri, siz hiç kullanmadım eğer JointDistributionCoroutine . Güvenle bu ifadeyi göz ardı edebilirsiniz)

Sonuç Düşünceleri

Bu defterin, biz tanıttı JointDistributionSequentialAutoBatched ve ayrıntılı olarak basit örnek üzerinden çalıştı. Umarım TFP şekilleri ve otomatik yığınlama hakkında bir şeyler öğrenmişsinizdir!