

Turns a (potentially nested) structure of distributions into a single distribution.

Args:
structure_of_distributions: An instance of tfd.Distribution, or a nested structure (tuple, list, dict, etc.) in which all leaves are tfd.Distribution instances.
batch_ndims: Optional integer Tensor giving the number of leftmost batch dimensions shared across all members of the input structure. If specified, the returned joint distribution is an autobatched distribution with this batch rank, and all remaining dimensions are absorbed into the event.
validate_args: Python bool. Whether the joint distribution should validate its inputs with asserts. This imposes a runtime cost. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
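The batch_ndims rule can be illustrated as pure shape bookkeeping. The sketch below is a hypothetical helper, split_autobatched_shapes, which is not part of TensorFlow Probability. It assumes each part is described by a (batch_shape, event_shape) pair, keeps the leftmost batch_ndims batch dimensions as the shared joint batch shape, and folds any remaining batch dimensions into that part's event shape. As a simplification, it requires the shared dimensions to match exactly; the library itself may handle broadcasting differently.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def split_autobatched_shapes(
    parts: Dict[str, Tuple[Shape, Shape]], batch_ndims: int
) -> Tuple[Shape, Dict[str, Shape]]:
    """Hypothetical helper: given each part's (batch_shape, event_shape),
    return the joint batch shape and each part's resulting event shape."""
    joint_batch = None
    part_events: Dict[str, Shape] = {}
    for name, (batch_shape, event_shape) in parts.items():
        if batch_ndims > len(batch_shape):
            raise ValueError(
                f"{name}: batch rank {len(batch_shape)} < batch_ndims={batch_ndims}")
        leading = tuple(batch_shape[:batch_ndims])
        if joint_batch is None:
            joint_batch = leading
        elif leading != joint_batch:
            # Simplification: shared batch dims must match exactly.
            raise ValueError(f"{name}: leading batch dims {leading} != {joint_batch}")
        # Batch dims beyond batch_ndims are absorbed into the event,
        # ahead of the part's own event dims.
        part_events[name] = tuple(batch_shape[batch_ndims:]) + tuple(event_shape)
    return (joint_batch if joint_batch is not None else ()), part_events


# Usage: part "x" has batch shape (3, 2) and scalar events;
# part "y" has batch shape (3,) and event shape (4,).
batch, events = split_autobatched_shapes(
    {"x": ((3, 2), ()), "y": ((3,), (4,))}, batch_ndims=1)
print(batch)   # (3,)
print(events)  # {'x': (2,), 'y': (4,)}
```

With batch_ndims=0, every dimension moves into the event, so the joint batch shape is () and the part event shapes become (3, 2) and (3, 4).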

Returns:
distribution: An instance of tfd.Distribution such that distribution.sample() is equivalent to tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions). If structure_of_distributions is a structure (as opposed to a single Distribution instance), the result is a JointDistribution with the corresponding structure.
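The sampling contract above can be mimicked in plain Python. This is a toy sketch, not TensorFlow Probability code: ToyDistribution, ToyJoint, and map_structure are hypothetical stand-ins for tfd.Distribution, the returned joint distribution, and tf.nest.map_structure, and they handle only plain dict, list, and tuple nests (not namedtuples). Both paths draw every leaf once and rebuild the same nesting. In this toy, the two results match exactly only because each consumes one seeded generator in the same leaf order.

```python
import random


class ToyDistribution:
    """Hypothetical stand-in for tfd.Distribution with a seeded sample()."""

    def __init__(self, value_fn):
        self._value_fn = value_fn

    def sample(self, rng: random.Random):
        return self._value_fn(rng)


def map_structure(fn, structure):
    """Minimal stand-in for tf.nest.map_structure over dict/list/tuple nests."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)  # A leaf: apply fn directly.


class ToyJoint:
    """Hypothetical joint: sample() draws each leaf and keeps the nesting."""

    def __init__(self, structure):
        self.structure = structure

    def sample(self, rng: random.Random):
        return map_structure(lambda d: d.sample(rng), self.structure)


structure = {
    "a": ToyDistribution(lambda r: r.gauss(0.0, 1.0)),
    "b": [ToyDistribution(lambda r: r.random()),
          ToyDistribution(lambda r: r.randint(1, 6))],
}
joint_draw = ToyJoint(structure).sample(random.Random(0))
rng = random.Random(0)
leafwise_draw = map_structure(lambda d: d.sample(rng), structure)
print(joint_draw == leafwise_draw)  # True
```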

Raises:
TypeError: If any leaves of the input structure are not tfd.Distribution instances.
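A leaf check of this kind can be sketched as a recursive walk over the nest. The function check_leaves and the class ToyDistribution below are hypothetical illustrations, not the library's implementation. The sketch treats dict, list, and tuple as containers and raises TypeError at the first leaf that is not a ToyDistribution, reporting the path of keys and indices that leads to it.

```python
class ToyDistribution:
    """Hypothetical stand-in for tfd.Distribution."""

    def __init__(self, value_fn=None):
        self._value_fn = value_fn


def check_leaves(structure, path=()):
    """Raise TypeError if any leaf of a dict/list/tuple nest is not a ToyDistribution."""
    if isinstance(structure, dict):
        for key, value in structure.items():
            check_leaves(value, path + (key,))
    elif isinstance(structure, (list, tuple)):
        for index, value in enumerate(structure):
            check_leaves(value, path + (index,))
    elif not isinstance(structure, ToyDistribution):
        raise TypeError(
            f"Leaf at path {path} is {type(structure).__name__}, not a distribution.")


check_leaves({"a": ToyDistribution(), "b": [ToyDistribution()]})  # Passes silently.
try:
    check_leaves({"a": ToyDistribution(), "b": [1.0]})
except TypeError as err:
    print(err)  # Leaf at path ('b', 0) is float, not a distribution.
```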