توزیع های مشترک خودکار: یک آموزش ملایم

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

معرفی

TensorFlow احتمال (TFP) ارائه می دهد تعدادی از JointDistribution انتزاعی است که استنتاج احتمالی آسان تر با اجازه دادن به کاربر به راحتی بیان یک مدل گرافیکی احتمالاتی در یک فرم نزدیک به ریاضی؛ انتزاع روش هایی را برای نمونه برداری از مدل و ارزیابی احتمال ورود نمونه ها از مدل ایجاد می کند. در این آموزش، ما انواع "autobatched"، که پس از اصلی توسعه داده شد بررسی JointDistribution انتزاعی. نسبت به انتزاعات اصلی و بدون بچ خودکار، استفاده از نسخه‌های اتوبچ ساده‌تر و ارگونومیک‌تر است و به بسیاری از مدل‌ها اجازه می‌دهد با دیگ بخار کمتری بیان شوند. در این مجموعه، ما یک مدل ساده را با جزئیات (شاید خسته‌کننده) بررسی می‌کنیم، مشکلاتی را که بچینگ خودکار حل می‌کند، روشن می‌کنیم و (امیدواریم) مفاهیم شکل TFP را در طول مسیر به خواننده آموزش دهیم.

قبل از معرفی از autobatching شد، چند نوع مختلف از وجود دارد JointDistribution ، مربوط به سبک های مختلف نحوی برای ابراز مدل احتمالاتی: JointDistributionSequential ، JointDistributionNamed و JointDistributionCoroutine . Auobatching به عنوان یک MIXIN وجود دارد، پس ما در حال حاضر AutoBatched انواعی از همه از این. در این آموزش، ما کشف تفاوت های بین JointDistributionSequential و JointDistributionSequentialAutoBatched ؛ با این حال، هر کاری که ما در اینجا انجام می‌دهیم، برای انواع دیگر بدون هیچ تغییری قابل اجرا است.

وابستگی ها و پیش نیازها

واردات و راه اندازی

پیش نیاز: مشکل رگرسیون بیزی

ما یک سناریوی رگرسیون بیزی بسیار ساده را در نظر خواهیم گرفت:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

در این مدل، m و b از نرمال استاندارد کشیده شده، و مشاهدات Y از یک توزیع نرمال که میانگین بستگی به متغیرهای تصادفی کشیده m و b ، و برخی (غیر تصادفی، شناخته می شود) متغیرهای کمکی X . (برای سادگی، در این مثال، فرض می‌کنیم که مقیاس همه متغیرهای تصادفی مشخص است.)

برای انجام استنتاج در این مدل، ما نیاز به دانستن هر دو متغیر X و مشاهدات Y ، اما برای اهداف این آموزش، ما فقط نیاز X ، بنابراین ساختگی ساده تعریف کنیم X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

در استنتاج احتمالی، ما اغلب می خواهیم دو عملیات اساسی را انجام دهیم:

  • sample : طراحی نمونه ها از مدل.
  • log_prob : محاسبات احتمال ورود به سیستم از یک نمونه از مدل.

سهم کلیدی بهره وری کل عوامل را JointDistribution انتزاعی (و همچنین بسیاری از روش های دیگر به برنامه نویسی احتمالی) است که به کاربران اجازه می دهد به ارسال یک مدل یک بار و دسترسی به هر دو sample و log_prob محاسبات.

با توجه به اینکه ما باید 7 امتیاز در مجموعه داده ما ( X.shape = (7,) )، ما هم اکنون می توانید حالت مطلوبی برای عالی JointDistribution :

  • sample() باید یک لیست از تولید Tensors داشتن شکل [(), (), (7,) به ترتیب]، مربوط به شیب اسکالر، تعصب اسکالر، و مشاهدات بردار،.
  • log_prob(sample()) احتمال ورود به سیستم از یک خاص شیب، تعصب، و مشاهدات: باید یک اسکالر تولید کند.
  • sample([5, 3]) باید یک لیست از تولید Tensors داشتن شکل [(5, 3), (5, 3), (5, 3, 7)] ، به نمایندگی از (5, 3) - دسته ای از نمونه ها از مدل.
  • log_prob(sample([5, 3])) باید یک تولید Tensor با شکل (5 و 3).

ما در حال حاضر را در یک توالی از نگاه JointDistribution مدل، ببینید که چگونه برای رسیدن به مطلوبی بالا، و امیدوارم یادگیری کمی بیشتر در مورد بهره وری کل عوامل شکل در طول راه.

هشدار اسپویلر: رویکرد که ارضا مطلوبی بالا بدون تکیهکلامهای اضافه شده است autobatching .

اولین تلاش؛ JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

این کم و بیش ترجمه مستقیم مدل به کد است. شیب m و تعصب b سر راست است. Y تعریف شده است با استفاده از یک lambda تابع: الگوی کلی است که یک lambda تابع از \(k\) استدلال در یک JointDistributionSequential (JDS) با استفاده از قبلی \(k\) توزیع در مدل. به ترتیب "معکوس" توجه کنید.

ما پاسخ sample_distributions ، که بازده هر دو یک نمونه و زمینه ای "زیر توزیع" که برای تولید نمونه استفاده شد. (ما فقط می تواند نمونه با تماس تولید کرده اند sample ؛ بعد از آن در آموزش آن را راحت خواهد بود که توزیع نیز هست.) نمونه ما تولید خوب است:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

اما log_prob نتیجه با شکل ناخواسته تولید:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

و نمونه گیری چندگانه کار نمی کند:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

بیایید سعی کنیم بفهمیم چه چیزی اشتباه می شود.

بررسی مختصر: شکل دسته و رویداد

در بهره وری کل عوامل، یک (نه معمولی JointDistribution ) توزیع احتمال دارای شکل رویداد و یک شکل دسته ای، و درک تفاوت به استفاده موثر از بهره وری کل عوامل بسیار مهم است:

  • شکل رویداد شکل یک قرعه کشی را از توزیع توصیف می کند. قرعه کشی ممکن است به ابعاد مختلف بستگی داشته باشد. برای توزیع های اسکالر، شکل رویداد [] است. برای یک MultivariateNormal 5 بعدی، شکل رویداد [5] است.
  • شکل دسته‌ای ترسیم‌های مستقل و بدون توزیع یکسان را توصیف می‌کند، که به «دسته‌ای» توزیع‌ها معروف است. نمایش دسته ای از توزیع ها در یک شی پایتون یکی از راه های کلیدی TFP برای دستیابی به کارایی در مقیاس است.

برای اهداف ما، یک واقعیت مهم را به خاطر داشته باشید این است که اگر ما پاسخ log_prob تنها یک نمونه از توزیع، نتیجه این خواهد همیشه یک شکل که مسابقات (به عنوان مثال، به عنوان ابعاد سمت راست است) شکل دسته ای داشته باشد.

برای بحث بیشتر در عمق از اشکال، و «شناخت TensorFlow توزیع اشکال" آموزش .

چرا log_prob(sample()) تولید یک اسکالر؟

اجازه دهید با استفاده دانش ما از دسته ای و رویداد شکل برای کشف آنچه که با اتفاق می افتد log_prob(sample()) . نمونه ما دوباره اینجاست:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

و در اینجا توزیع های ما وجود دارد:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

احتمال گزارش با جمع کردن احتمالات گزارش توزیع‌های فرعی در عناصر (همسان) قطعات محاسبه می‌شود:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

بنابراین، یک سطح از توضیح این است که محاسبه احتمال ورود به سیستم در حال بازگشت 7 تانسور زیرا فرعی سوم log_prob_parts یک تانسور 7 است. اما چرا؟

خب، ما می بینیم که آخرین عنصر dists ، که مربوط به توزیع ما بیش از Y در تدوین mathematial، یک batch_shape از [7] . به عبارت دیگر، توزیع ما بیش از Y دسته ای از 7 نرمال مستقل (با معنای مختلف و، در این مورد، با همان مقیاس).

ما در حال حاضر درک چه چیزی اشتباه: در JDS، توزیع بیش از Y است batch_shape=[7] ، یک نمونه از JDS نشان دهنده بندی برای m و b و یک "دسته" از 7 نرمال مستقل است. و log_prob محاسبه 7 جدا ورود احتمالات، هر کدام نشان دهنده احتمال ورود به سیستم از رسم m و b و مشاهده تک Y[i] در برخی از X[i] .

رفع log_prob(sample()) با Independent

به یاد بیاورید که dists[2] است event_shape=[] و batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

با استفاده از بهره وری کل عوامل را Independent برتر میگوییم، که ابعاد دسته ای به ابعاد رویداد تبدیل، ما می توانیم این را به یک توزیع با تبدیل event_shape=[7] و batch_shape=[] (ما آن را تغییر نام دهید y_dist_i این دلیل که یک توزیع روی است Y ، با _i ایستاده در برای ما Independent بسته بندی):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

در حال حاضر، log_prob یک بردار 7 یک اسکالر باشد:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

تحت پوشش می دهد، Independent مبالغ بیش از دسته:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

و در واقع، ما می توانیم این را به ساخت یک جدید استفاده کنید jds_i (به i دوباره برای مخفف Independent ) که در آن log_prob یک اسکالر می گرداند:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

چند نکته:

  • jds_i.log_prob(s) است همان نیست tf.reduce_sum(jds.log_prob(s)) . اولی احتمال لاگ صحیح توزیع مشترک را تولید می کند. مبالغ دوم بیش از یک تانسور 7، هر عنصر از که در مجموع از احتمال ورود به سیستم از است m ، b ، و یک عنصر از احتمال ورود به سیستم از Y ، پس از آن overcounts m و b . ( log_prob(m) + log_prob(b) + log_prob(Y) در نتیجه به جای پرتاب یک استثنا به دلیل بهره وری کل عوامل زیر TF و قوانین پخش نامپای را برمی گرداند. اضافه کردن یک اسکالر به یک بردار تولید نتیجه بردار به اندازه)
  • در این مورد خاص، ما می تواند مشکل حل و نتیجه را با استفاده از دست MultivariateNormalDiag جای Independent(Normal(...)) . MultivariateNormalDiag یک توزیع برداری مقدار است (یعنی آن را در حال حاضر دارای بردار رویداد شکل). Indeeed MultivariateNormalDiag می تواند (اما نه) اجرا به عنوان یک ترکیب Independent و Normal . آن را ارزشمند به یاد داشته باشید که با توجه به بردار V ، نمونه ها از n1 = Normal(loc=V) و n2 = MultivariateNormalDiag(loc=V) غیر قابل تشخیص هستند. تفاوت beween این توزیع است که n1.log_prob(n1.sample()) است یک بردار و n2.log_prob(n2.sample()) یک اسکالر است.

چند نمونه؟

کشیدن چندین نمونه هنوز کار نمی کند:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

بیایید به این فکر کنیم که چرا. هنگامی که ما در پاسخ jds_i.sample([5, 3]) ، ما اولین نمونه برای جلب m و b ، هر کدام با شکل (5, 3) . بعد، ما قصد داریم به تلاش برای ساخت یک Normal توزیع شده از طریق:

tfd.Normal(loc=m*X + b, scale=1.)

اما اگر m است شکل (5, 3) و X شکل 7 ، ما نمی توانیم آنها را با هم ضرب، و در واقع این خطا ما ضربه هستید:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

برای حل این مسئله، اجازه دهید در مورد آنچه از خواص توزیع بیش از فکر می کنم Y است به. اگر ما به نام ام jds_i.sample([5, 3]) ، پس ما می دانیم m و b هر دو شکل باید (5, 3) . به چه شکل باید یک تماس به sample در Y تولید توزیع؟ پاسخ واضح است (5, 3, 7) : برای هر نقطه دسته ای، ما می خواهیم یک نمونه با همان اندازه به عنوان X . ما می‌توانیم با استفاده از قابلیت‌های پخش TensorFlow و افزودن ابعاد اضافی به این مهم دست یابیم:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

اضافه کردن یک محور به هر دو m و b ، ما می توانیم یک JDS جدید را تعریف است که پشتیبانی از نمونه های متعدد:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

به عنوان یک بررسی اضافی، بررسی می‌کنیم که احتمال گزارش برای یک نقطه دسته منفرد با آنچه قبلاً داشتیم مطابقت دارد:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

بچینگ خودکار برای پیروزی

عالی! ما در حال حاضر یک نسخه از JointDistribution دارند که دسته همه ما مطلوبی: log_prob بازده لطف اسکالر به استفاده از tfd.Independent ، و نمونه های متعدد کار می کنند که ما ثابت پخش با اضافه کردن محور اضافی.

اگر به شما بگویم راه ساده تر و بهتری وجود دارد چه؟ است وجود دارد، و آن را به نام JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

این چطوری کار میکنه؟ در حالی که شما می تواند تلاش برای خواندن کد برای درک عمیق، ما یک مرور کوتاه است که برای بسیاری از موارد استفاده کافی را:

  • به یاد بیاورید که اولین مشکل ما این بود که توزیع ما برای Y حال batch_shape=[7] و event_shape=[] ، و ما استفاده Independent برای تبدیل دسته ای بعد به بعد رویداد. JDSAB اشکال دسته ای توزیع اجزا را نادیده می گیرد. به جای آن رفتار شکل دسته ای به عنوان یک اموال کلی این مدل، که فرض می شود [] (مگر اینکه در غیر این صورت با تنظیم مشخص batch_ndims > 0 ). اثر معادل با استفاده از tfd.Independent برای تبدیل تمام ابعاد دسته ای از توزیع جزء به ابعاد رویداد، به عنوان ما به صورت دستی بالا انجام داد.
  • مشکل دوم ما نیاز به ماساژ اشکال بود m و b به طوری که آنها می تواند مناسب با پخش X در هنگام ایجاد نمونه های متعدد. با JDSAB، شما ارسال یک مدل برای تولید تنها یک نمونه، و ما "بلند" کل مدل برای تولید نمونه های متعدد با استفاده TensorFlow است vectorized_map . (این قابلیت مثابه به JAX است vmap .)

بررسی مسئله شکل دسته ای در جزئیات بیشتر، ما می توانیم شکل دسته ای از اصلی "بد" ما توزیع مشترک مقایسه jds ، دسته ای ثابت ما توزیع jds_i و jds_ia و autobatched ما jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

ما می بینیم که اصلی jds است subdistributions با اشکال مختلف دسته ای. jds_i و jds_ia این تعمیر با ایجاد subdistributions با همان (خالی) شکل دسته ای. jds_ab است تنها یک (خالی) شکل دسته ای.

ارزش آن را اشاره کرد که JointDistributionSequentialAutoBatched ارائه می دهد برخی کلیت اضافی به صورت رایگان. فرض کنید ما را متغیرهای کمکی X (و به طور ضمنی، مشاهدات Y ) دو بعدی:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

ما JointDistributionSequentialAutoBatched بدون هیچ تغییری (ما نیاز به تعریف مجدد مدل به دلیل شکل کار می کند X است کش jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

از سوی دیگر، ما با دقت گردد JointDistributionSequential کار می کند دیگر:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

برای حل این مشکل، ما می خواهم که برای اضافه کردن یک دوم tf.newaxis به هر دو m و b مطابقت با شکل و افزایش reinterpreted_batch_ndims تا 2 در پاسخ به Independent . در این مورد، اجازه دادن به ماشین‌های بچینگ خودکار برای رسیدگی به مشکلات شکل کوتاه‌تر، آسان‌تر و ارگونومیک‌تر است.

یک بار دیگر، توجه داشته باشید که در حالی که این نوت بوک مورد بررسی JointDistributionSequentialAutoBatched ، انواع دیگری از JointDistribution دارند معادل AutoBatched . (برای کاربران JointDistributionCoroutine ، JointDistributionCoroutineAutoBatched تا به سود بیشتری که شما دیگر نیازی به مشخص Root گره؛ اگر شما استفاده می شود هرگز JointDistributionCoroutine . شما با خیال راحت می تواند این بیانیه چشم پوشی)

اندیشه های پایانی

در این نوت بوک، ما معرفی JointDistributionSequentialAutoBatched و از طریق یک مثال ساده در جزئیات کار کرده است. امیدواریم چیزی در مورد اشکال TFP و در مورد بچینگ خودکار یاد گرفته باشید!