הפצות משותפות באופן אוטומטי: מדריך עדין

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

מבוא

TensorFlow הסתברות (פריון כולל) מציעה מספר JointDistribution ההפשטות שהופכות היסק הסתברותי קל בכך שהוא מאפשר למשתמש בקלות להביע מודל גרפי הסתברותית בצורה כמעט מתמטית; ההפשטה מייצרת שיטות לדגימה מהמודל ולהערכת ההסתברות ביומן של דגימות מהמודל. במדריך זה, אנו בודקים "autobatched" גרסאות, אשר פותחו לאחר המקורי JointDistribution הפשטות. ביחס להפשטות המקוריות, הלא-אוטומטיות, הגרסאות האוטומטיות פשוטות יותר לשימוש וארגונומיות יותר, מה שמאפשר לדגמים רבים לבוא לידי ביטוי עם פחות לוח. בקולאב זה, אנו חוקרים מודל פשוט בפרטי פרטים (אולי מייגעים), מבהירים את הבעיות שהאצט האוטומטי פותר, ו(בתקווה) מלמדים את הקורא יותר על מושגי צורת TFP לאורך הדרך.

לקראת כניסתה של autobatching, היו כמה גרסאות שונות של JointDistribution , המתאימים לסגנונות תחבירי שונה להבעת מודלים הסתברותיים: JointDistributionSequential , JointDistributionNamed , ו JointDistributionCoroutine . Auobatching קיים בתור Mixin, כך שיש לנו עכשיו AutoBatched וריאנטים של כל אלה. במדריך זה, אנו לחקור את ההבדלים בין JointDistributionSequential ו JointDistributionSequentialAutoBatched ; עם זאת, כל מה שאנו עושים כאן ישים על הווריאציות האחרות ללא שינויים למעשה.

תלות ותנאים מוקדמים

ייבוא ​​והגדרות

תנאי מוקדם: בעיית רגרסיה בייסיאנית

נשקול תרחיש רגרסיה בייסיאנית פשוטה מאוד:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

במודל זה, m ו b שאובים הנורמלים סטנדרטי, ואת התצפיות Y שאובים התפלגות נורמלית אשר מתכוון תלוי במשתנים אקראיים m ו b , וחלקם (אקראי, הידוע) למשתנים X . (למען הפשטות, בדוגמה זו, אנו מניחים שקנה ​​המידה של כל המשתנים האקראיים ידוע).

כדי לבצע היקש במודל זה, היינו צריכים לדעת הן את המשתנים: X ואת התצפיות Y , אבל לצורך מדריך זה, נצטרך רק X , אז אנחנו מגדירים דמה פשוט X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

בהסקה הסתברותית, לעתים קרובות אנו רוצים לבצע שתי פעולות בסיסיות:

  • sample : ציור דגימות מהמודל.
  • log_prob : חישוב הסתברות היומן של מדגם מהמודל.

תרומת המפתח של של הפריון הכולל JointDistribution פשטות (כמו גם גישות אחרות רבות לתכנות הסתברותית) היא לאפשר למשתמשים לכתוב מודל אחת לגשת לשני sample ו log_prob חישובים.

וציין כי יש לנו 7 נקודות בערכת הנתונים שלנו ( X.shape = (7,) ), אנו יכולים כעת לקבוע הדבר המבוקש עבור מצוין JointDistribution :

  • sample() צריך לייצר רשימה של Tensors שיש צורה [(), (), (7,) ], המקביל למדרון סקלר, הטיה סקלר, ותצפיות וקטור, בהתאמה.
  • log_prob(sample()) צריך לייצר סקלר: הסתברות היומן של מדרון, הטיה מסוימת, ותצפיות.
  • sample([5, 3]) צריכים לייצר רשימה של Tensors שיש צורה [(5, 3), (5, 3), (5, 3, 7)] , מייצג (5, 3) - קבוצה של דגימות המודל.
  • log_prob(sample([5, 3])) צריך לייצר Tensor עם צורה (5, 3).

כעת אנו נבחנו שורה של JointDistribution מודלים, לראות כיצד להשיג דבר המבוקש לעיל, ואנו מקווים ללמוד קצת יותר על הפריון הכולל צורות לאורך הדרך.

התראה ספוילר: הגישה המספק דבר המבוקש הנ"ל בלא מוכן מראש מוסף autobatching .

נסיון ראשון; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

זה פחות או יותר תרגום ישיר של המודל לקוד. השיפוע m והטיית b הוא פשוט. Y מוגדר באמצעות lambda -function: הדפוס הכללי הוא כי lambda -function של \(k\) ויכוחים בתוך JointDistributionSequential (JDS) משתמשת הקודם \(k\) הפצות במודל. שימו לב לסדר ה"הפוך".

אנחנו נתקשר sample_distributions , אשר מחזיר הן מדגם ואת "תת-הפצות" שבבסיס ששימשו להפקת המדגם. (אנחנו יכולים הניבו רק המדגם על ידי התקשרות sample ; בהמשך הדרכה זה יהיה נוח לקבל הפצות כמו גם.) המדגם שאנו מייצרים הוא בסדר:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

אבל log_prob מייצר תוצאה עם צורה רצויה:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

ודגימה מרובה לא עובדת:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

בואו ננסה להבין מה לא בסדר.

סקירה קצרה: אצווה וצורת אירוע

בשנת הפריון כוללת, סתם אזרח (לא JointDistribution ) התפלגות הסתברות בעל צורת אירוע צורה אצווה, והבנת ההבדל היא קריטית כדי שימוש יעיל של פריון כולל:

  • צורת אירוע מתארת ​​את הצורה של ציור בודד מההפצה; ההגרלה עשויה להיות תלויה במידות שונות. עבור התפלגויות סקלריות, צורת האירוע היא []. עבור MultivariateNormal 5 ממדי, צורת האירוע היא [5].
  • צורת אצווה מתארת ​​שרטוטים עצמאיים, לא מבוזרים זהים, הלא היא "אצווה" של הפצות. ייצוג אצווה של הפצות באובייקט Python יחיד היא אחת הדרכים המרכזיות שבהן TFP משיג יעילות בקנה מידה.

לענייננו, עובדה קריטית לזכור היא שאם אנחנו קוראים log_prob על מדגם יחיד מחלוק, התוצאה תמיד תהיה צורה כי גפרורים (כלומר, יש גם ממדים ימניים ביותר) בצורה תצווה.

לדיון מעמיק יותר של צורות, לראות את המדריך "TensorFlow הפצות צורות הבנה" .

מדוע לא log_prob(sample()) לייצר סקלר?

בואו להשתמש בידע שלנו בכושר אצווה ואירוע לחקור מה קורה עם log_prob(sample()) . הנה שוב המדגם שלנו:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

והנה ההפצות שלנו:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

ההסתברות ביומן מחושבת על ידי סיכום הסתברויות היומן של תת-התפלגויות ברכיבים (המתואמים) של החלקים:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

אז, אחד ברמה של הסבר היא כי חישוב הסתברות יומן מחזיר 7-טנסור כי רכיב המשנה השלישי של log_prob_parts הוא 7-מותח. אבל למה?

ובכן, אנו רואים כי האלמנט האחרון של dists , אשר תואמת את ההפצה שלנו מעל Y בניסוח mathematial, יש batch_shape של [7] . במילות אחרות, חלוק שלנו מעל Y היא קבוצה של 7 הנורמלים עצמאיים (עם אמצעים שונים, במקרה זה, באותו הסולם).

כעת אנו מבינים מה לא בסדר: ב JDS, חלוק מעל Y יש batch_shape=[7] , מדגם מן JDS מייצג scalars עבור m ו b ו "תצווה" של 7 הנורמלים עצמאיים. ו log_prob מחשב 7 יומן-הסתברויות נפרדת, שכל אחת מהן מייצגת את יומן ההסתברות לשליפת m ו b ו תצפית אחת Y[i] על כמה X[i] .

תיקון log_prob(sample()) עם Independent

נזכיר כי dists[2] יש event_shape=[] ו batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

באמצעות של הפריון הכולל Independent metadistribution, אשר ממיר ממדים יצוו לממדי אירוע, אנחנו יכולים להמיר את זה לתוך הפצה עם event_shape=[7] ו batch_shape=[] (נצטרך לשנות את זה y_dist_i בגלל שזה הפצה על Y , עם _i העמיד עבור שלנו Independent הגלישה):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

עכשיו, log_prob של וקטור 7 הוא סקלר:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

מתחת לשמיכות, Independent סכומים מעל אצווה:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

ואכן, אנו יכולים להשתמש בזה כדי לבנות חדש jds_i (את i שוב מייצג Independent ) שבו log_prob חוזר סקלר:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

כמה הערות:

  • jds_i.log_prob(s) הוא לא אותו הדבר כמו tf.reduce_sum(jds.log_prob(s)) . הראשון מייצר את הסתברות היומן ה"נכונה" של ההתפלגות המשותפת. הסכומים האחרונים מעל 7-מותח, כל אלמנט של אשר הוא סכום ההסתברות יומן של m , b , ו אלמנט בודד של הסתברות יומן של Y , אז זה overcounts m ו b . ( log_prob(m) + log_prob(b) + log_prob(Y) מחזירה תוצאה ולא לזרוק חריג כי הפריון הכולל כדלקמן TF ותקנון שידור של numpy;. הוספת סקלר כדי וקטור מייצר תוצאה וקטור בגודל)
  • במקרה הספציפי הזה, היינו יכולים לפתור את הבעיה והשיג אותה תוצאה באמצעות MultivariateNormalDiag במקום Independent(Normal(...)) . MultivariateNormalDiag היא הפצה מוערכת וקטור (כלומר, זה כבר יש וקטור צורת אירוע). Indeeed MultivariateNormalDiag יכול להיות (אך לא) מיושם בתור הרכב Independent ו Normal . כדאי לזכור כי נתון וקטור V , דגימות n1 = Normal(loc=V) , ו n2 = MultivariateNormalDiag(loc=V) הם נבדלים; ההבדל beween הפצות אלה היא כי n1.log_prob(n1.sample()) הוא וקטור n2.log_prob(n2.sample()) הוא סקלר.

מספר דוגמאות?

ציור של מספר דוגמאות עדיין לא עובד:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

בואו נחשוב למה. כאשר אנו קוראים jds_i.sample([5, 3]) , ראשית ניצור לצייר דגימות m ו b , כל אחד עם צורה (5, 3) . בשלב הבא, אנחנו הולכים לנסות לבנות Normal הפצה באמצעות:

tfd.Normal(loc=m*X + b, scale=1.)

אבל אם m יש צורה (5, 3) ו X יש צורה 7 , אנחנו לא יכולים להכפיל אותם יחד, ואכן זו היא השגיאה אנו להכות במקרים הבאים:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

כדי לפתור בעיה זו, בואו נחשוב על מה המאפיינים חלוק מעל Y יש לקיים. אם אנחנו כבר נקרא jds_i.sample([5, 3]) , אז אנחנו יודעים m ו b יהיה הן צורה (5, 3) . מה צורה צריכה קריאה sample על Y תוצרת ההפצה? התשובה המתבקשת היא (5, 3, 7) : עבור כל נקודה אצווה, אנחנו רוצים מדגם עם בגודל זהה X . אנו יכולים להשיג זאת על ידי שימוש ביכולות השידור של TensorFlow, הוספת ממדים נוספים:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

הוספת ציר לשני m ו b , אנו יכולים להגדיר JDS חדש תומך במספר דוגמאות:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

כבדיקה נוספת, נוודא שההסתברות ביומן עבור נקודת אצווה בודדת תואמת למה שהיה לנו קודם:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

אצווה אוטומטית בשביל הניצחון

מְעוּלֶה! עכשיו יש לנו גרסה של JointDistribution שמטפל בכל דברים רצויים שלנו: log_prob מחזיר תודה סקלר לשימוש tfd.Independent , ודוגמאות רבות לעבוד עכשיו כי אנחנו קבוע לשדר ידי הוספת צירים מיותרים.

מה אם הייתי אומר לך שיש דרך קלה וטובה יותר? יש, וזה נקרא JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

איך זה עובד? בעוד אתה יכול לנסות לקרוא את הקוד עבור הבנה עמוקה, אנו נספק סקירה קצרה אשר מספקים ברוב מקרי שימוש:

  • נזכיר כי הבעיה הראשונה שלנו היתה כי ההפצה שלנו עבור Y היו batch_shape=[7] ו event_shape=[] , והיינו Independent כדי להמיר את הממד אצווה כדי מימד אירוע. JDSAB מתעלם מצורות האצווה של הפצות רכיבים; במקום זה מתייחס לצורה אצווה כנכס הכוללת של המודל, אשר ההנחה היא להיות [] (אלא אם צוין אחרת על ידי הגדרת batch_ndims > 0 ). ההשפעה כמוהו כשימוש tfd.Independent להמיר את כל הממדים אצווה של הפצות רכיב לממדים האירוע, כפי שעשינו למעלה באופן ידני.
  • שני הבעיה שלנו הייתה צורך לעסות את הצורות של m ו b כדי שיוכלו לשדר כראוי עם X בעת יצירת דגימות מרובות. עם JDSAB, אתה כותב מודל ליצירת מדגם יחיד, ואנחנו "להרים" את המודל כולו כדי ליצור דגימות מרובות באמצעות של TensorFlow vectorized_map . (תכונה זו אנלוגית של JAX vmap .)

היכרות עם הנושא בצורה אצווה ביתר פירוט, אנחנו יכולים להשוות את הצורות אצווה של "רע" המקורי שלנו הפצה משותפת jds , הפצות אצווה קבוע שלנו jds_i ו jds_ia , ו autobatched שלנו jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

אנו רואים כי המקורית jds יש subdistributions עם צורות שונות יצווה. jds_i ו jds_ia לתקן זאת על ידי יצירת subdistributions עם אותה צורה אצווה (ריק). jds_ab יש רק צורה אצווה אחת (ריק).

שווה זה וציין כי JointDistributionSequentialAutoBatched מציעה הכללה נוספת בחינם. נניח שאנחנו עושים את למשתנים X (ו, במשתמע, התצפיות Y ) דו-ממדית:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

שלנו JointDistributionSequentialAutoBatched עובד ללא שינויים (אנחנו צריכים להגדיר מחדש את המודל בגלל בצורת X במטמון, על ידי jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

מצד השני, שלנו מעוצב בקפידה JointDistributionSequential כבר לא עובד:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

כדי לתקן זאת, היינו צריכים להוסיף שנייה tf.newaxis לשני m ו b להתאים את הצורה, ואת הגידול reinterpreted_batch_ndims כדי 2 ב הקריאה Independent . במקרה זה, מתן אפשרות למכונות האצווה האוטומטית לטפל בבעיות הצורה היא קצרה יותר, קלה יותר וארגונומית יותר.

שוב, נציין כי בעוד המחברת הזאת בחנו JointDistributionSequentialAutoBatched , גרסאות אחרות של JointDistribution יש מקבילה AutoBatched . (עבור משתמשי JointDistributionCoroutine , JointDistributionCoroutineAutoBatched יתרון נוסף כי אתה כבר לא צריך לציין Root צמתים; אם מעולם לא השתמשת JointDistributionCoroutine . אתה רשאי להתעלם הצהרה זו)

מחשבות מסכמות

במחברת זו, הצגנו JointDistributionSequentialAutoBatched ועבד באמצעות דוגמה פשוטה בפירוט. אני מקווה שלמדת משהו על צורות TFP ועל אצווה אוטומטית!