Distribusi Gabungan Batch Otomatis: Tutorial Lembut

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

pengantar

TensorFlow Probabilitas (TFP) menawarkan sejumlah JointDistribution abstraksi yang membuat inferensi probabilistik lebih mudah dengan memungkinkan pengguna untuk dengan mudah mengekspresikan model grafis probabilistik dalam bentuk matematika dekat-; abstraksi menghasilkan metode untuk pengambilan sampel dari model dan mengevaluasi probabilitas log sampel dari model. Dalam tutorial ini, kita meninjau "autobatched" varian, yang dikembangkan setelah asli JointDistribution abstraksi. Dibandingkan dengan abstraksi asli, non-autobatched, versi autobatched lebih mudah digunakan dan lebih ergonomis, memungkinkan banyak model diekspresikan dengan lebih sedikit boilerplate. Dalam colab ini, kami mengeksplorasi model sederhana dalam detail (mungkin membosankan), memperjelas masalah yang diselesaikan autobatching, dan (semoga) mengajari pembaca lebih banyak tentang konsep bentuk TFP di sepanjang jalan.

Sebelum pengenalan autobatching, ada varian yang berbeda dari JointDistribution , sesuai dengan gaya sintaksis yang berbeda untuk mengekspresikan model probabilistik: JointDistributionSequential , JointDistributionNamed , dan JointDistributionCoroutine . Auobatching ada sebagai mixin, jadi kita sekarang memiliki AutoBatched varian dari semua ini. Dalam tutorial ini, kita mengeksplorasi perbedaan antara JointDistributionSequential dan JointDistributionSequentialAutoBatched ; namun, semua yang kami lakukan di sini berlaku untuk varian lain tanpa perubahan pada dasarnya.

Dependensi & Prasyarat

Impor dan set up

Prasyarat: Masalah Regresi Bayesian

Kami akan mempertimbangkan skenario regresi Bayesian yang sangat sederhana:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

Dalam model ini, m dan b diambil dari normals standar, dan pengamatan Y diambil dari distribusi normal yang rata-rata tergantung pada variabel-variabel acak m dan b , dan beberapa (nonrandom, dikenal) kovariat X . (Untuk kesederhanaan, dalam contoh ini, kami menganggap skala semua variabel acak diketahui.)

Untuk melakukan inferensi dalam model ini, kami perlu tahu kedua kovariat X dan pengamatan Y , namun untuk tujuan tutorial ini, kita hanya perlu X , jadi kita mendefinisikan dummy sederhana X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

Dalam inferensi probabilistik, kita sering ingin melakukan dua operasi dasar:

  • sample : Menggambar sampel dari model.
  • log_prob : Komputasi probabilitas log dari sampel dari model.

Kontribusi kunci dari TFP ini JointDistribution abstraksi (serta dari banyak pendekatan lain untuk pemrograman probabilistik) adalah untuk memungkinkan pengguna untuk menulis sebuah model sekali dan memiliki akses ke kedua sample dan log_prob perhitungan.

Mencatat bahwa kita memiliki 7 poin di set data kami ( X.shape = (7,) ), kita sekarang dapat menyatakan desiderata untuk sangat baik JointDistribution :

  • sample() harus menghasilkan daftar Tensors memiliki bentuk [(), (), (7,) ], sesuai dengan kemiringan skalar, Bias skalar, dan pengamatan vektor, masing-masing.
  • log_prob(sample()) harus menghasilkan skalar: probabilitas log tertentu lereng, bias, dan pengamatan.
  • sample([5, 3]) harus menghasilkan daftar Tensors memiliki bentuk [(5, 3), (5, 3), (5, 3, 7)] , mewakili (5, 3) - batch sampel dari model.
  • log_prob(sample([5, 3])) harus menghasilkan Tensor dengan bentuk (5, 3).

Kita sekarang akan melihat suksesi JointDistribution model, melihat bagaimana untuk mencapai desiderata di atas, dan mudah-mudahan belajar sedikit lebih banyak tentang TFP membentuk sepanjang jalan.

Spoiler alert: Pendekatan yang memenuhi yang desiderata atas tanpa menambahkan boilerplate adalah autobatching .

Percobaan pertama; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Ini kurang lebih merupakan terjemahan langsung dari model ke dalam kode. Kemiringan m dan bias b adalah mudah. Y didefinisikan menggunakan sebuah lambda -fungsi: pola umum adalah bahwa lambda -fungsi dari \(k\) argumen dalam JointDistributionSequential (JDS) menggunakan sebelumnya \(k\) distribusi dalam model. Perhatikan urutan "terbalik".

Kami akan memanggil sample_distributions , yang kembali baik sampel dan mendasari "sub-distribusi" yang digunakan untuk menghasilkan sampel. (Kita bisa diproduksi hanya sampel dengan menelepon sample ; nanti di tutorial itu akan mudah untuk memiliki distribusi juga.) Sampel kami memproduksi baik-baik saja:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Tapi log_prob menghasilkan hasil dengan bentuk yang tidak diinginkan:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

Dan beberapa pengambilan sampel tidak berfungsi:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Mari kita coba memahami apa yang salah.

Ulasan Singkat: Bentuk Batch dan Acara

Dalam TFP, yang biasa (bukan JointDistribution ) distribusi probabilitas memiliki bentuk acara dan bentuk batch, dan memahami perbedaan adalah penting untuk penggunaan efektif TFP:

  • Bentuk acara menggambarkan bentuk undian tunggal dari distribusi; pengundian mungkin tergantung lintas dimensi. Untuk distribusi skalar, bentuk kejadiannya adalah []. Untuk MultivariatNormal 5 dimensi, bentuk kejadiannya adalah [5].
  • Bentuk batch menggambarkan undian yang independen dan tidak terdistribusi secara identik, alias "kumpulan" distribusi. Mewakili sekumpulan distribusi dalam satu objek Python adalah salah satu cara utama TFP mencapai efisiensi dalam skala besar.

Untuk tujuan kita, fakta penting untuk diingat adalah bahwa jika kita sebut log_prob pada sampel tunggal dari distribusi, hasilnya selalu akan memiliki bentuk yang cocok (yaitu, memiliki sebagai dimensi paling kanan) bentuk batch.

Untuk pembahasan lebih mendalam tentang bentuk, lihat yang "Memahami TensorFlow Distribusi Shapes" tutorial .

Mengapa Apakah tidak log_prob(sample()) Menghasilkan skalar?

Mari kita gunakan pengetahuan kita tentang batch dan acara bentuk untuk mengeksplorasi apa yang terjadi dengan log_prob(sample()) . Ini contoh kami lagi:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Dan berikut adalah distribusi kami:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

Probabilitas log dihitung dengan menjumlahkan probabilitas log dari sub-distribusi pada elemen (yang cocok) dari bagian:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Jadi, satu tingkat dari penjelasan adalah bahwa perhitungan log probabilitas adalah mengembalikan 7-Tensor karena subkomponen ketiga log_prob_parts adalah 7-Tensor. Tapi kenapa?

Nah, kita melihat bahwa elemen terakhir dari dists , yang sesuai dengan distribusi kami lebih Y dalam perumusan mathematial, memiliki batch_shape dari [7] . Dengan kata lain, distribusi kami lebih Y adalah batch 7 normals independen (dengan cara yang berbeda dan, dalam hal ini, skala yang sama).

Kami sekarang mengerti apa yang salah: di JDS, distribusi lebih Y memiliki batch_shape=[7] , sampel dari JDS merupakan skalar untuk m dan b dan "batch" dari 7 normals independen. dan log_prob menghitung 7 terpisah log-probabilitas, yang masing-masing mewakili probabilitas log menggambar m dan b dan pengamatan tunggal Y[i] di beberapa X[i] .

Memperbaiki log_prob(sample()) dengan Independent

Ingat bahwa dists[2] memiliki event_shape=[] dan batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

Dengan menggunakan TFP ini Independent metadistribusi, yang mengubah dimensi batch untuk dimensi acara, kita dapat mengkonversi ini menjadi distribusi dengan event_shape=[7] dan batch_shape=[] (kami akan mengganti nama y_dist_i karena distribusi pada Y , dengan _i berdiri untuk kami Independent pembungkus):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Sekarang, log_prob dari 7-vektor adalah skalar:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Di bawah selimut, Independent jumlah lebih batch:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Dan memang, kita dapat menggunakan ini untuk membangun baru jds_i (yang i lagi singkatan Independent ) di mana log_prob mengembalikan skalar:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Beberapa catatan:

  • jds_i.log_prob(s) tidak sama dengan tf.reduce_sum(jds.log_prob(s)) . Yang pertama menghasilkan probabilitas log yang "benar" dari distribusi gabungan. Jumlah yang terakhir selama 7-Tensor, setiap elemen yang merupakan jumlah dari probabilitas log m , b , dan satu elemen dari probabilitas log Y , sehingga overcounts m dan b . ( log_prob(m) + log_prob(b) + log_prob(Y) mengembalikan hasilnya daripada membuang pengecualian karena TFP berikut TF dan aturan penyiaran NumPy ini;. Menambahkan skalar untuk vektor menghasilkan hasil vektor berukuran)
  • Dalam kasus ini, kita bisa memecahkan masalah dan mencapai hasil yang sama menggunakan MultivariateNormalDiag bukan Independent(Normal(...)) . MultivariateNormalDiag adalah distribusi vektor-dihargai (yaitu, sudah memiliki vektor-bentuk acara). Indeeed suatu berkat MultivariateNormalDiag bisa (tetapi tidak) dilaksanakan sebagai komposisi Independent dan Normal . Hal ini bermanfaat untuk diingat bahwa diberi vektor V , sampel dari n1 = Normal(loc=V) , dan n2 = MultivariateNormalDiag(loc=V) tidak dapat dibedakan; perbedaan beween distribusi ini adalah bahwa n1.log_prob(n1.sample()) adalah vektor dan n2.log_prob(n2.sample()) adalah skalar.

Beberapa Sampel?

Menggambar banyak sampel masih tidak berfungsi:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Mari kita pikirkan alasannya. Ketika kita sebut jds_i.sample([5, 3]) , kami akan pertama mengambil contoh untuk m dan b , masing-masing dengan bentuk (5, 3) . Berikutnya, kita akan mencoba untuk membangun sebuah Normal distribusi melalui:

tfd.Normal(loc=m*X + b, scale=1.)

Tetapi jika m memiliki bentuk (5, 3) dan X memiliki bentuk 7 , kita tidak bisa berkembang biak mereka bersama-sama, dan memang ini adalah kesalahan kita memukul sedang:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Untuk mengatasi masalah ini, mari kita berpikir tentang sifat-sifat apa distribusi lebih Y harus memiliki. Jika kita sudah menelepon jds_i.sample([5, 3]) , maka kita tahu m dan b keduanya akan memiliki bentuk (5, 3) . Apa bentuk harus panggilan untuk sample pada Y menghasilkan distribusi? Jawaban yang jelas adalah (5, 3, 7) : untuk setiap titik batch, kita ingin sampel dengan ukuran yang sama dengan X . Kita dapat mencapainya dengan menggunakan kemampuan penyiaran TensorFlow, menambahkan dimensi ekstra:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Menambahkan sumbu kedua m dan b , kita dapat mendefinisikan JDS baru yang mendukung beberapa sampel:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Sebagai pemeriksaan tambahan, kami akan memverifikasi bahwa probabilitas log untuk satu titik batch cocok dengan yang kami miliki sebelumnya:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching Untuk Kemenangan

Bagus sekali! Kami sekarang memiliki versi JointDistribution yang menangani semua kami desiderata: log_prob kembali berkat skalar untuk penggunaan tfd.Independent , dan beberapa sampel bekerja sekarang bahwa kita tetap penyiaran dengan menambahkan sumbu ekstra.

Bagaimana jika saya memberi tahu Anda bahwa ada cara yang lebih mudah dan lebih baik? Ada, dan itu disebut JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Bagaimana cara kerjanya? Meskipun Anda bisa mencoba untuk membaca kode untuk pemahaman yang mendalam, kami akan memberikan gambaran singkat yang cukup untuk sebagian besar kasus penggunaan:

  • Ingat bahwa masalah pertama kami adalah bahwa distribusi kami untuk Y memiliki batch_shape=[7] dan event_shape=[] , dan kami digunakan Independent untuk mengkonversi dimensi batch untuk dimensi acara. JDSAB mengabaikan bentuk batch dari distribusi komponen; bukannya memperlakukan bentuk batch properti keseluruhan model, yang diasumsikan [] (kecuali ditentukan lain dengan menetapkan batch_ndims > 0 ). Efeknya adalah setara dengan menggunakan tfd.Independent untuk mengkonversi semua dimensi batch distribusi komponen ke dimensi acara, seperti yang kita lakukan secara manual di atas.
  • Masalah kedua kami adalah kebutuhan untuk memijat bentuk m dan b sehingga mereka bisa menyiarkan secara tepat dengan X saat membuat beberapa sampel. Dengan JDSAB, Anda menulis model untuk menghasilkan sampel tunggal, dan kami "mengangkat" seluruh model untuk menghasilkan beberapa sampel menggunakan TensorFlow ini vectorized_map . (Fitur ini analog dengan JAX ini VMAP .)

Menjelajahi masalah bentuk batch yang lebih detail, kita bisa membandingkan bentuk batch kami asli "buruk" distribusi gabungan jds , kami distribusi batch-tetap jds_i dan jds_ia , dan autobatched kami jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Kita melihat bahwa asli jds memiliki subdistributions dengan bentuk batch yang berbeda. jds_i dan jds_ia memperbaiki ini dengan menciptakan subdistributions dengan (kosong) bentuk batch yang sama. jds_ab hanya memiliki satu (kosong) bentuk batch.

Itu perlu dicatat bahwa JointDistributionSequentialAutoBatched menawarkan beberapa umum tambahan gratis. Misalkan kita membuat kovariat X (dan, secara implisit, pengamatan Y ) dua dimensi:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Kami JointDistributionSequentialAutoBatched bekerja dengan tidak ada perubahan (kita perlu mendefinisikan kembali model karena bentuk X -cache oleh jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

Di sisi lain, kami hati-hati dibuat JointDistributionSequential tidak lagi bekerja:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Untuk memperbaiki hal ini, kita harus menambahkan kedua tf.newaxis untuk kedua m dan b sesuai dengan bentuk, dan peningkatan reinterpreted_batch_ndims ke 2 dalam panggilan untuk Independent . Dalam hal ini, membiarkan mesin batch otomatis menangani masalah bentuk lebih pendek, lebih mudah, dan lebih ergonomis.

Sekali lagi, kami mencatat bahwa sementara notebook ini dieksplorasi JointDistributionSequentialAutoBatched , varian lain dari JointDistribution memiliki setara AutoBatched . (Untuk pengguna JointDistributionCoroutine , JointDistributionCoroutineAutoBatched memiliki manfaat tambahan yang Anda tidak perlu lagi untuk menentukan Root node, jika Anda belum pernah menggunakan JointDistributionCoroutine . Anda dapat dengan aman mengabaikan pernyataan ini)

Kesimpulan

Dalam notebook ini, kami memperkenalkan JointDistributionSequentialAutoBatched dan bekerja melalui contoh sederhana secara rinci. Semoga Anda belajar sesuatu tentang bentuk TFP dan tentang autobatching!