TensorFlow.org'da görüntüleyin | Google Colab'da çalıştırın | Kaynağı GitHub'da görüntüleyin | Not defterini indir |
Tanıtım
TensorFlow Olasılık (TFP) bir sunmaktadır JointDistribution
kolay kolay bir kullanım yakın bir matematiksel formda olasılıklı bir grafik modeli ifade izin vererek olasılıksal sonuç çıkarabilir Özetlenmesi; soyutlama, modelden örnekleme ve modelden örneklerin log olasılığını değerlendirme yöntemleri üretir. Bu eğitimde, orijinal sonra geliştirildi "autobatched" varyantları, gözden JointDistribution
soyutlamalar. Orijinal, otomatik toplu olmayan soyutlamalara kıyasla, otomatik toplu sürümlerin kullanımı daha basit ve daha ergonomiktir, bu da birçok modelin daha az standart ile ifade edilmesini sağlar. Bu ortak çalışmada, otomatik yığınlamanın çözdüğü sorunları netleştirerek ve (umarız) okuyucuya yol boyunca TFP şekil kavramları hakkında daha fazla şey öğreterek basit bir modeli (belki de sıkıcı) ayrıntılarla keşfediyoruz.
Autobatching getirilmesinden önce, birkaç farklı varyantları yoktu JointDistribution
: olasılık modelleri ifade etmek için farklı sözdizimsel stilleri karşılık gelen JointDistributionSequential
, JointDistributionNamed
ve JointDistributionCoroutine
. Şu anda elimizde bu yüzden Auobatching, bir mixin olarak var AutoBatched
Bütün bunların varyantları. Bu eğitimde, arasındaki farkları araştırmak JointDistributionSequential
ve JointDistributionSequentialAutoBatched
; ancak burada yaptığımız her şey, esasen hiçbir değişiklik olmaksızın diğer varyantlara uygulanabilir.
Bağımlılıklar ve Ön Koşullar
İthalat ve kurulumlar
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
Önkoşul: Bir Bayesian Regresyon Problemi
Çok basit bir Bayes regresyon senaryosunu ele alacağız:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
Bu modelde, m
ve b
, standart normallerden çizilir ve gözlemler Y
olan ortalama rasgele değişkenlere bağlıdır normal dağılımdan çizilir m
ve b
, ve bazı (rastgele olmayan, bilinen) değişkenleri belirlemek X
. (Basit olması için bu örnekte, tüm rastgele değişkenlerin ölçeğinin bilindiğini varsayıyoruz.)
Bu modeldeki çıkarım gerçekleştirmek için, her iki ortak değişkenlere bilmek gerekiyordu X
ve gözlemleri Y
, ama bu yazının amaçları için, sadece gerekir X
basit bir kukla tanımlamak, böylece X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
arzu edilen
Olasılıksal çıkarımda, genellikle iki temel işlemi gerçekleştirmek isteriz:
-
sample
: bir model örnekleri çizimi. -
log_prob
: modelinden bir numunenin günlük olasılığını hesaplanıyor.
TFP en önemli katkısı JointDistribution
soyutlama (yanı sıra olasılık programlama yönelik pek çok yaklaşım) kullanıcıların kez bir model yazıp hem erişmesini sağlamaktır sample
ve log_prob
hesaplamaları.
Bizim veri kümesindeki (7 puana sahip olduğunu kaydeden X.shape = (7,)
), şimdi mükemmel bir için aranılan vasıfları ifade edebiliriz JointDistribution
:
-
sample()
bir listesini üretmek gerekirTensors
şekle sahip olan[(), (), (7,)
, sırasıyla, skaler eğim, skaler önyargı ve vektör gözlemler tekabül]. -
log_prob(sample())
belirli bir eğim, duruş ve gözlemler günlük olasılığını: bir skalar üretmelidir. -
sample([5, 3])
bir listesini üretmek gerekirTensors
şekle sahip olan[(5, 3), (5, 3), (5, 3, 7)]
, bir temsil(5, 3)
- numune toplu işletmeye modeli. -
log_prob(sample([5, 3]))
, bir üretmelidirTensor
şekli (5, 3) ile yıkanmıştır.
Şimdi bir arkaya bakacağız JointDistribution
, modeller yukarıdaki aranılan vasıfları elde etmek nasıl görmek ve TFP yol boyunca şekillendiren hakkında umarım biraz daha öğreniyoruz.
Spoiler uyarısı: eklenen Demirbaş olmadan tatmin yukarıdaki istekler olduğunu yaklaşım autobatching .
İlk girişim; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
Bu, modelin koda aşağı yukarı doğrudan çevirisidir. Şev m
ve önyargı b
basittir. Y
bir ile tanımlanmakta olup, lambda
taşımasının avantajlı: Genel deseni olmasıdır lambda
bölgesinin taşımasının avantajlı \(k\) bir bağımsız değişken JointDistributionSequential
(JDS) Önceki kullanan \(k\) modelinde dağılımları. "Ters" sıraya dikkat edin.
Biz arayacağım sample_distributions
, döner bir örnek ve örnek oluşturmak için kullanıldı yatan "alt dağılımları" hem. (Biz arayarak sadece örnek üretilen olabilirdi sample
, daha sonra öğretici o dağılımları olması uygun olacaktır da.) Ürettiğimiz örnek gayet:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Ancak log_prob
istenmeyen bir şekle sahip bir sonuç üretir:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
Ve çoklu örnekleme çalışmıyor:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Neyin yanlış gittiğini anlamaya çalışalım.
Kısa Bir İnceleme: Parti ve Etkinlik Şekli
TFP, sıradan (bir JointDistribution
) olasılık dağılımı bir olay şeklinde ve bir yığın şekline ve farkı anlamak TFP etkin kullanımı için çok önemlidir var
- Olay şekli, dağılımdan tek bir çekilişin şeklini tanımlar; beraberlik boyutlara bağlı olabilir. Skaler dağılımlar için olay şekli []'dir. 5 boyutlu MultivariateNormal için olay şekli [5]'dir.
- Toplu şekil, bağımsız, aynı şekilde dağıtılmamış çekilişleri, yani bir dağılım "toplu"sunu tanımlar. Bir dağıtım grubunu tek bir Python nesnesinde temsil etmek, TFP'nin ölçekte verimliliğe ulaşmasının temel yollarından biridir.
Bizim için, akılda tutulması gereken kritik bir gerçektir diyoruz eğer olmasıdır log_prob
bir dağılımından tek bir numune üzerinde, sonuç her zaman maçlar (yani en sağdaki boyutlar olarak vardır) bir şekle toplu bir şekle sahip olacaktır.
Şekillerin daha derinlemesine tartışma için, bkz "Anlama TensorFlow Dağılımları Şekiller" öğretici .
Neden mu değil log_prob(sample())
bir Skaler üretin?
Diyelim neler olup bittiğini keşfetmek için toplu ve olay şeklinin bilgimizi kullanmak log_prob(sample())
. İşte yine örneğimiz:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Ve işte dağılımlarımız:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
Günlük olasılığı, parçaların (eşleşen) öğelerindeki alt dağılımların günlük olasılıklarının toplanmasıyla hesaplanır:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
Yani, açıklama bir düzey üçüncü alt bileşen nedeniyle günlük olasılık hesaplama 7 Tensörü dönen olmasıdır log_prob_parts
7 tensör olduğunu. Ama neden?
Biz de son elemanı görüyoruz dists
üzerinde eden dağılımına karşılık gelmektedir, Y
mathematial formülasyonda, bir sahiptir batch_shape
ait [7]
. Diğer bir deyişle, fazla dağıtım Y
(bu durumda, aynı ölçek içinde, farklı araçlar ile) 7 bağımsız normaller bir şeyin.
Şimdi neyin yanlış olduğunu anlamaya: JDS de, üzerinde dağıtım Y
sahiptir batch_shape=[7]
, JDS bir örnek için skalarlar temsil m
ve b
ve 7 bağımsız normaller bir "toplu". ve log_prob
çekme log olasılığını temsil her biri 7 ayrı bir giriş olasılıkları, hesaplar m
ve b
ve tek bir gözlem Y[i]
bazı X[i]
.
Sabitleme log_prob(sample())
ile Independent
Hatırlatma; dists[2]
sahip event_shape=[]
ve batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
TFP en kullanarak Independent
olay boyutlarına toplu boyutları dönüştürür dağıtımdır, biz bir dağılım halinde bu dönüştürme event_shape=[7]
ve batch_shape=[]
(biz yeniden adlandıracaksınız y_dist_i
bunun bir dağıtım çünkü Y
ile, _i
ayakta bizim için de Independent
) sarma:
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
Şimdi, log_prob
7 vektörün skaler geçerli:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
Kapakları altında Independent
toplu üzerinde özetliyor:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Ve gerçekten de, yeni bir inşa için kullanabilir jds_i
( i
tekrar açılımı Independent
) nerede log_prob
bir sayıl döndürür:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
Birkaç not:
-
jds_i.log_prob(s)
ile aynı değildirtf.reduce_sum(jds.log_prob(s))
. İlki, ortak dağılımın "doğru" log olasılığını üretir. 7-Tensörün üzerinde ikinci miktarlar, her bir eleman olan bir günlük olasılık toplamıdırm
,b
, ve log olasılık tek bir öğeY
bu overcounts böylece,m
veb
. (log_prob(m) + log_prob(b) + log_prob(Y)
TFP TF ve numpy en yayın kurallarını takip eder, çünkü bir durum atma yerine bir sonuç verir;., Bir vektör, bir skaler ilave bir vektör boyutlu bir sonuç üretir) - Bu özel durumda, biz sorunu çözüldü ve kullanan aynı sonucu elde olabilirdi
MultivariateNormalDiag
yerineIndependent(Normal(...))
.MultivariateNormalDiag
bir vektör değerli dağılımıdır (yani, zaten vektör olay şekline sahiptir). IndeeedMultivariateNormalDiag
olabilir (ama değil) bir kompozisyon olarak uygulananIndependent
veNormal
. O bir vektör verilen hatırlamak için faydalıdırV
, numunen1 = Normal(loc=V)
, ven2 = MultivariateNormalDiag(loc=V)
ayırt edilemez; Bu dağılımlar beween fark olmasıdırn1.log_prob(n1.sample())
bir vektör ve birn2.log_prob(n2.sample())
bir skalerdir.
Çoklu Örnekler?
Birden çok örnek çizmek hala çalışmıyor:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Nedenini düşünelim. Dediğimiz zaman jds_i.sample([5, 3])
, ilk için numune çizmek gerekir m
ve b
, şekil, her (5, 3)
. Sonra, bir inşa etmeye deneyeceğiz Normal
aracılığıyla dağıtımı:
tfd.Normal(loc=m*X + b, scale=1.)
Ama eğer m
şekline sahiptir (5, 3)
ve X
şekline sahiptir 7
, biz onları birlikte çarpın olamaz ve gerçekten bu hata Bizler vurma geçerli:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Bu sorunu çözmek için, bitti dağıtım özellikleri neler düşünelim Y
sahiptir olması. Adlandırdığımız ettiyseniz jds_i.sample([5, 3])
, ardından bildiğimiz m
ve b
hem şekle sahip olacaktır (5, 3)
. Bir çağrı Ne biçim için gereken sample
üzerinde Y
dağıtım üretmek? Uygun cevap (5, 3, 7)
: Her bir parti noktası sağlamak için, aynı boyutta bir örnek istiyorum X
. Bunu, TensorFlow'un yayın yeteneklerini kullanarak ve ekstra boyutlar ekleyerek başarabiliriz:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
Hem bir eksen ekleme m
ve b
, yeni JDS tanımlayabilir destekler çok sayıda numune olduğu:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
Ek bir kontrol olarak, tek bir parti noktası için günlük olasılığının daha önce sahip olduğumuzla eşleştiğini doğrulayacağız:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Kazanmak İçin Otomatik Yığınlama
Harika! Şimdi JointDistribution bir sürüme sahip olmasını kolları tüm bizim istekler: log_prob
döner kullanımına bir sayıl sayesinde tfd.Independent
ve çoklu numuneler ekstra eksenlerini ekleyerek yayın sabit şimdi çalışır.
Ya sana daha kolay, daha iyi bir yol olduğunu söyleseydim? Orada olduğunu ve denir JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
Bu nasıl çalışıyor? Eğer teşebbüs olabilir iken kodunu okumak derin bir anlayış için, en kullanım durumları için yeterli kısa bir özetini vereceğiz:
- Hatırlama bizim ilk sorun için dağıtım olmasıydı
Y
vardıbatch_shape=[7]
veevent_shape=[]
ve kullandığımızIndependent
bir olay boyutuna toplu boyut dönüştürmek. JDSAB, bileşen dağılımlarının toplu şekillerini yok sayar; bunun yerine olduğu varsayılır modelinin bir genel özelliği olarak toplu şekil davranır[]
(ayarlayarak, aksi belirtilmediği sürecebatch_ndims > 0
). Etki elle yukarıda yaptığımız gibi, olay boyutlara bileşen dağıtımlarının tüm toplu boyutlarını dönüştürme tfd.Independent kullanmaya eşdeğerdir. - İkinci sorun, şekil masaj için bir ihtiyaç olduğu
m
veb
bunlar ile uygun bir şekilde yayın böyleceX
birden fazla numune oluştururken. JDSAB ile, bir model tek örneği oluşturmak için yazmaya ve "asansör" tüm model TensorFlow en kullanarak birden örnekleri oluşturmak için vectorized_map . (Bu özellik Jax'in için biçime benzer olan VMAP .)
Daha ayrıntılı olarak toplu şekil sorunu keşfetmek, bizim orijinal "kötülük" ortak dağıtım toplu şekiller karşılaştırabilirsiniz jds
, bizim toplu sabit dağılımlardan jds_i
ve jds_ia
ve bizim autobatched jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
Biz orijinal görüyoruz jds
farklı toplu şekillerle subdistributions vardır. jds_i
ve jds_ia
aynı (boş) bir toplu şekli ile subdistributions oluşturarak sabitlenir. jds_ab
yalnızca tek bir (boş) bir toplu bir şekle sahiptir.
Belirterek It değerinde JointDistributionSequentialAutoBatched
ücretsiz bazı ek genelliği sunmaktadır. Biz covariates yapmak varsayalım X
(ve dolaylı olarak, gözlemler Y
) iki boyutlu:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
Bizim JointDistributionSequentialAutoBatched
herhangi bir değişiklik (biz şekli nedeniyle modelini yeniden tanımlamak gerekir çalışır X
tarafından önbelleğe jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
Öte yandan, bizim özenle hazırlanmış JointDistributionSequential
artık çalışır:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
Bunu düzeltmek için, biz ikinci eklemek olurdu tf.newaxis
hem m
ve b
şekil ve artış maç reinterpreted_batch_ndims
çağrısında 2'ye Independent
. Bu durumda, otomatik harmanlama makinelerinin şekil sorunlarıyla ilgilenmesine izin vermek daha kısa, daha kolay ve daha ergonomiktir.
Bir kez daha, bu defter kesfedilmeyi ederken dikkat JointDistributionSequentialAutoBatched
, diğer varyantları JointDistribution
eşdeğer olması AutoBatched
. (Kullanıcıları için JointDistributionCoroutine
, JointDistributionCoroutineAutoBatched
katkısı olmadığını belirtmek artık gerek Root
düğümleri, siz hiç kullanmadım eğer JointDistributionCoroutine
. Güvenle bu ifadeyi göz ardı edebilirsiniz)
Sonuç Düşünceleri
Bu defterin, biz tanıttı JointDistributionSequentialAutoBatched
ve ayrıntılı olarak basit örnek üzerinden çalıştı. Umarım TFP şekilleri ve otomatik yığınlama hakkında bir şeyler öğrenmişsinizdir!