Auto-Batched Joint Distributions: A Gentle Tutorial

Introduction

TensorFlow Probability (TFP) offers a number of JointDistribution abstractions that make probabilistic inference easier by allowing a user to easily express a probabilistic graphical model in a near-mathematical form; the abstraction generates methods for sampling from the model and evaluating the log probability of samples from the model. In this tutorial, we review "autobatched" variants, which were developed after the original JointDistribution abstractions. Relative to the original, non-autobatched abstractions, the autobatched versions are simpler to use and more ergonomic, allowing many models to be expressed with less boilerplate. In this colab, we explore a simple model in (perhaps tedious) detail, making clear the problems autobatching solves, and (hopefully) teaching the reader more about TFP shape concepts along the way.

Prior to the introduction of autobatching, there were a few different variants of JointDistribution, corresponding to different syntactic styles for expressing probabilistic models: JointDistributionSequential, JointDistributionNamed, and JointDistributionCoroutine. Autobatching exists as a mixin, so we now have AutoBatched variants of all of these. In this tutorial, we explore the differences between JointDistributionSequential and JointDistributionSequentialAutoBatched; however, everything we do here is applicable to the other variants with essentially no changes.

Dependencies & Prerequisites

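Import and set up

The code in this tutorial assumes the usual aliases: NumPy imported as np, TensorFlow as tf, TensorFlow Probability as tfp, and tfd = tfp.distributions.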

Prerequisite: A Bayesian Regression Problem

We'll consider a very simple Bayesian regression scenario:

\[ \begin{aligned} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{aligned} \]

In this model, m and b are drawn from standard normals, and the observations Y are drawn from a normal distribution whose mean depends on the random variables m and b, and some (nonrandom, known) covariates X. (For simplicity, in this example, we assume the scale of all random variables is known.)

To perform inference in this model, we'd need to know both the covariates X and the observations Y, but for the purposes of this tutorial, we'll only need X, so we define a simple dummy X:

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
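
To make the generative story concrete, here is a minimal NumPy-only sketch of drawing one joint sample by hand, following the three lines of the model in order. It does not use TFP; the function name sample_regression_model and the seed are illustrative assumptions, not part of the tutorial's code.

```python
import numpy as np

def sample_regression_model(X, rng):
    """Draw one (m, b, Y) triple from the model by ancestral sampling."""
    m = rng.normal(loc=0.0, scale=1.0)         # scalar slope
    b = rng.normal(loc=0.0, scale=1.0)         # scalar bias
    Y = rng.normal(loc=m * X + b, scale=1.0)   # one observation per covariate
    return m, b, Y

X = np.arange(7)
rng = np.random.default_rng(0)  # arbitrary seed, chosen for illustration
m, b, Y = sample_regression_model(X, rng)
print(np.shape(m), np.shape(b), Y.shape)  # () () (7,)
```

The shapes printed here are exactly the per-sample shapes that the desiderata below ask a JointDistribution to produce.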

Desiderata

In probabilistic inference, we often want to perform two basic operations:

  • sample: Drawing samples from the model.
  • log_prob: Computing the log probability of a sample from the model.

The key contribution of TFP's JointDistribution abstractions (as well as of many other approaches to probabilistic programming) is to allow users to write a model once and have access to both sample and log_prob computations.

Noting that we have 7 points in our data set (X.shape = (7,)), we can now state the desiderata for an excellent JointDistribution:

  • sample() should produce a list of Tensors having shape [(), (), (7,)], corresponding to the scalar slope, scalar bias, and vector observations, respectively.
  • log_prob(sample()) should produce a scalar: the log probability of a particular slope, bias, and observations.
  • sample([5, 3]) should produce a list of Tensors having shape [(5, 3), (5, 3), (5, 3, 7)], representing a (5, 3)-batch of samples from the model.
  • log_prob(sample([5, 3])) should produce a Tensor with shape (5, 3).

We'll now look at a succession of JointDistribution models, see how to achieve the above desiderata, and hopefully learn a little more about TFP shapes along the way.

Spoiler alert: The approach that satisfies the above desiderata without added boilerplate is autobatching.

First Attempt: JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

This is more or less a direct translation of the model into code. The slope m and bias b are straightforward. Y is defined using a lambda-function: the general pattern is that a lambda-function of \(k\) arguments in a JointDistributionSequential (JDS) uses the previous \(k\) distributions in the model. Note the "reverse" order.
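
The "reverse" order can be illustrated with a toy argument resolver in plain Python. This is only a sketch of the convention described above, not TFP's implementation; resolve_args and the string placeholders are hypothetical names introduced for illustration.

```python
import inspect

def resolve_args(fn, previous_values):
    """Pick the arguments for a k-argument callable from earlier values.

    The most recently defined value is passed first, mirroring the
    "reverse" order described above.
    """
    k = len(inspect.signature(fn).parameters)
    return list(reversed(previous_values))[:k]

# Values defined so far, in model order: m first, then b.
previous_values = ["m_value", "b_value"]

y_fn = lambda b, m: (b, m)  # same signature as the Y entry in the model
args = resolve_args(y_fn, previous_values)
print(args)  # ['b_value', 'm_value']

# A one-argument callable would receive only the most recent value.
print(resolve_args(lambda b: b, previous_values))  # ['b_value']
```

Writing the lambda as lambda b, m rather than lambda m, b is what makes each argument name line up with the value it receives.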

We'll call sample_distributions, which returns both a sample and the underlying "sub-distributions" that were used to generate the sample. (We could have produced just the sample by calling sample; later in the tutorial it will be convenient to have the distributions as well.) The sample we produce is fine:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

But log_prob produces a result with an undesired shape:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

And multiple sampling doesn't work:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Let's try to understand what's going wrong.

A Brief Review: Batch and Event Shape

In TFP, an ordinary (not a JointDistribution) probability distribution has an event shape and a batch shape, and understanding the difference is crucial to effective use of TFP:

  • Event shape describes the shape of a single draw from the distribution; the draw may be dependent across dimensions. For scalar distributions, the event shape is []. For a 5-dimensional MultivariateNormal, the event shape is [5].
  • Batch shape describes independent, not identically distributed draws, aka a "batch" of distributions. Representing a batch of distributions in a single Python object is one of the key ways TFP achieves efficiency at scale.

For our purposes, a critical fact to keep in mind is that if we call log_prob on a single sample from a distribution, the result will always have a shape that matches (i.e., has as rightmost dimensions) the batch shape.

For a more in-depth discussion of shapes, see the "Understanding TensorFlow Distributions Shapes" tutorial.
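
The batch/event distinction can be mimicked in plain NumPy. The sketch below assumes a hand-written normal_log_prob helper (a hypothetical stand-in for a Normal distribution's log_prob, not a TFP function) and shows how the same seven numbers yield a 7-vector when read as a batch and a scalar when read as one event.

```python
import numpy as np

def normal_log_prob(x, loc, scale=1.0):
    """Elementwise log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

loc = np.arange(7.0)   # seven different means
x = loc + 0.5          # one value per mean

# Read as a batch of 7 scalar distributions: one log-density per member.
batch_lp = normal_log_prob(x, loc)
print(batch_lp.shape)  # (7,)

# Read as a single 7-dimensional event: sum over the event dimension.
event_lp = batch_lp.sum(axis=-1)
print(event_lp.shape)  # ()
```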

Why Doesn't log_prob(sample()) Produce a Scalar?

Let's use our knowledge of batch and event shape to explore what's happening with log_prob(sample()). Here's our sample again:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

And here are our distributions:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

The log probability is computed by summing the log probabilities of the sub-distributions at the (matched) elements of the parts:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

So, one level of explanation is that the log probability calculation is returning a 7-Tensor because the third subcomponent of log_prob_parts is a 7-Tensor. But why?

Well, we see that the last element of dists, which corresponds to our distribution over Y in the mathematical formulation, has a batch_shape of [7]. In other words, our distribution over Y is a batch of 7 independent normals (with different means and, in this case, the same scale).

We now understand what's wrong: in JDS, the distribution over Y has batch_shape=[7], so a sample from the JDS represents scalars for m and b and a "batch" of 7 independent normals, and log_prob computes 7 separate log probabilities, each of which represents the log probability of drawing m and b and a single observation Y[i] at some X[i].

Fixing log_prob(sample()) with Independent

Recall that dists[2] has event_shape=[] and batch_shape=[7]:

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

By using TFP's Independent metadistribution, which converts batch dimensions to event dimensions, we can convert this into a distribution with event_shape=[7] and batch_shape=[] (we'll rename it y_dist_i because it's a distribution on Y, with the _i standing in for our Independent wrapping):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Now, the log_prob of a 7-vector is a scalar:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Under the covers, Independent sums over the batch:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

And indeed, we can use this to construct a new jds_i (the i again stands for Independent) where log_prob returns a scalar:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

A couple of notes:

  • jds_i.log_prob(s) is not the same as tf.reduce_sum(jds.log_prob(s)). The former produces the "correct" log probability of the joint distribution. The latter sums over a 7-Tensor, each element of which is the sum of the log probability of m, b, and a single element of the log probability of Y, so it overcounts m and b. (log_prob(m) + log_prob(b) + log_prob(Y) returns a result rather than throwing an exception because TFP follows TF and NumPy's broadcasting rules; adding a scalar to a vector produces a vector-sized result.)
  • In this particular case, we could have solved the problem and achieved the same result using MultivariateNormalDiag instead of Independent(Normal(...)). MultivariateNormalDiag is a vector-valued distribution (i.e., it already has vector event-shape). Indeed, MultivariateNormalDiag could be (but isn't) implemented as a composition of Independent and Normal. It's worthwhile to remember that given a vector V, samples from n1 = Normal(loc=V) and n2 = MultivariateNormalDiag(loc=V) are indistinguishable; the difference between these distributions is that n1.log_prob(n1.sample()) is a vector and n2.log_prob(n2.sample()) is a scalar.
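
The overcounting in the first note can be checked with plain arithmetic. The sketch below copies the numbers printed in log_prob_parts above into NumPy (no TFP required); the variable names are illustrative.

```python
import numpy as np

# Values copied from the log_prob_parts output above.
lp_m = -2.3113134
lp_b = -1.1357536
lp_y = np.array([-1.0306933, -1.2304904, -1.2959809, -1.200658,
                 -1.1276014, -0.9897899, -1.0334952])

# What jds.log_prob computes: the scalars broadcast against the 7-vector.
broadcast_lp = lp_m + lp_b + lp_y      # shape (7,)

# What jds_i.log_prob computes: Y's terms are summed into one number first.
joint_lp = lp_m + lp_b + lp_y.sum()    # scalar

# Summing the broadcast result counts lp_m + lp_b seven times instead of
# once, so it is off by six extra copies.
overcount = broadcast_lp.sum() - joint_lp
print(np.isclose(overcount, 6 * (lp_m + lp_b)))  # True
print(round(float(joint_lp), 4))                 # -11.3558
```

The scalar agrees with the jds_i.log_prob(sample) value shown earlier, and the first entry of broadcast_lp agrees with the first entry of jds.log_prob(sample).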

Multiple Samples?

Drawing multiple samples still doesn't work:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Let's think about why. When we call jds_i.sample([5, 3]), we'll first draw samples for m and b, each with shape (5, 3). Next, we're going to try to construct a Normal distribution via:

tfd.Normal(loc=m*X + b, scale=1.)

But if m has shape (5, 3) and X has shape (7,), we can't multiply them together, and indeed this is the error we're hitting:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

To resolve this issue, let's think about what properties the distribution over Y has to have. If we've called jds_i.sample([5, 3]), then we know m and b will both have shape (5, 3). What shape should a call to sample on the Y distribution produce? The obvious answer is (5, 3, 7): for each batch point, we want a sample with the same size as X. We can achieve this by using TensorFlow's broadcasting capabilities, adding extra dimensions:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Adding an axis to both m and b, we can define a new JDS that supports multiple samples:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

As an extra check, we'll verify that the log probability for a single batch point matches what we had before:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching For The Win

Excellent! We now have a version of JointDistribution that handles all our desiderata: log_prob returns a scalar thanks to the use of tfd.Independent, and multiple samples work now that we fixed broadcasting by adding extra axes.

What if I told you there was an easier, better way? There is, and it's called JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

How does this work? While you could attempt to read the code for a deep understanding, we'll give a brief overview which is sufficient for most use cases:

  • Recall that our first problem was that our distribution for Y had batch_shape=[7] and event_shape=[], and we used Independent to convert the batch dimension to an event dimension. JDSAB ignores the batch shapes of component distributions; instead it treats batch shape as an overall property of the model, which is assumed to be [] (unless specified otherwise by setting batch_ndims > 0). The effect is equivalent to using tfd.Independent to convert all batch dimensions of component distributions into event dimensions, as we did manually above.
  • Our second problem was a need to massage the shapes of m and b so that they could broadcast appropriately with X when creating multiple samples. With JDSAB, you write a model to generate a single sample, and we "lift" the entire model to generate multiple samples using TensorFlow's vectorized_map. (This feature is analogous to JAX's vmap.)
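
The idea of "lifting" a single-sample function can be sketched in NumPy with an explicit loop. This is a conceptual illustration only: lift, single_log_prob, and normal_log_prob are hypothetical helpers, and the loop merely stands in for what a vectorizing transform achieves without looping.

```python
import numpy as np

X = np.arange(7)

def normal_log_prob(x, loc, scale=1.0):
    """Elementwise log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

def single_log_prob(m, b, y):
    """Joint log-density for ONE sample: scalar m, scalar b, y of shape (7,)."""
    return (normal_log_prob(m, 0.0)
            + normal_log_prob(b, 0.0)
            + normal_log_prob(y, m * X + b).sum())

def lift(fn, batch_ndims):
    """Naively map fn over the leading batch_ndims axes of every argument."""
    def lifted(*args):
        batch_shape = args[0].shape[:batch_ndims]
        out = np.empty(batch_shape)
        for idx in np.ndindex(*batch_shape):
            out[idx] = fn(*(a[idx] for a in args))
        return out
    return lifted

rng = np.random.default_rng(0)  # arbitrary seed
m = rng.normal(size=(5, 3))
b = rng.normal(size=(5, 3))
y = rng.normal(size=(5, 3, 7))

batched_lp = lift(single_log_prob, batch_ndims=2)(m, b, y)
print(batched_lp.shape)  # (5, 3)

# The hand-broadcast version (extra axes on m and b) gives the same numbers.
manual_lp = (normal_log_prob(m, 0.0) + normal_log_prob(b, 0.0)
             + normal_log_prob(y, m[..., None] * X + b[..., None]).sum(axis=-1))
print(np.allclose(batched_lp, manual_lp))  # True
```

Note that single_log_prob never mentions batch dimensions; all shape bookkeeping lives in lift, which is the division of labor autobatching provides.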

Exploring the batch shape issue in more detail, we can compare the batch shapes of our original "bad" joint distribution jds, our batch-fixed distributions jds_i and jds_ia, and our autobatched jds_ab:

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

We see that the original jds has subdistributions with different batch shapes. jds_i and jds_ia fix this by creating subdistributions with the same (empty) batch shape. jds_ab has only a single (empty) batch shape.

It's worth noting that JointDistributionSequentialAutoBatched offers some additional generality for free. Suppose we make the covariates X (and, implicitly, the observations Y) two-dimensional:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Our JointDistributionSequentialAutoBatched works with no changes (we need to redefine the model because the shape of X is cached by jds_ab.log_prob):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

On the other hand, our carefully crafted JointDistributionSequential no longer works:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

To fix this, we'd have to add a second tf.newaxis to both m and b to match the shape, and increase reinterpreted_batch_ndims to 2 in the call to Independent. In this case, letting the auto-batching machinery handle the shape issues is shorter, easier, and more ergonomic.
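
The shape arithmetic behind that manual fix can be checked in NumPy alone, with np.newaxis playing the role of tf.newaxis. The zero arrays below are placeholders standing in for a (5, 3)-batch of slopes and biases; this sketch only verifies shapes and is not a TFP model.

```python
import numpy as np

X = np.arange(14).reshape((2, 7))
m = np.zeros((5, 3))  # stand-in for a (5, 3)-batch of slopes
b = np.zeros((5, 3))  # stand-in for a (5, 3)-batch of biases

# One added axis is not enough: (5, 3, 1) cannot broadcast against (2, 7).
try:
    m[..., np.newaxis] * X
except ValueError as e:
    print("broadcast failed:", e)

# Two added axes line up with both dimensions of X.
loc = m[..., np.newaxis, np.newaxis] * X + b[..., np.newaxis, np.newaxis]
print(loc.shape)  # (5, 3, 2, 7)

# Summing over the last two axes (the analogue of reinterpreting two batch
# dimensions as event dimensions) recovers the (5, 3) batch shape.
print(loc.sum(axis=(-2, -1)).shape)  # (5, 3)
```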

Once again, we note that while this notebook explored JointDistributionSequentialAutoBatched, the other variants of JointDistribution have equivalent AutoBatched versions. (For users of JointDistributionCoroutine, JointDistributionCoroutineAutoBatched has the additional benefit that you no longer need to specify Root nodes; if you've never used JointDistributionCoroutine you can safely ignore this statement.)

Concluding Thoughts

In this notebook, we introduced JointDistributionSequentialAutoBatched and worked through a simple example in detail. Hopefully you learned something about TFP shapes and about autobatching!