Module: tfp.experimental.distribute

Experimental module for distributed log-probability calculations.

Classes

class JointDistributionCoroutine: A sharding-aware JointDistributionCoroutine.

class JointDistributionNamed: A sharding-aware JointDistributionNamed.

class JointDistributionSequential: A sharding-aware JointDistributionSequential.

class Sharded: A meta-distribution meant for use in an SPMD distributed context.

Functions

make_pbroadcast_function(...): Constructs a function that broadcasts inputs over named axes.

make_psum_function(...): Constructs a function that sums inputs over named axes.

make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
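
The idea behind an all-reduce over log prob terms can be sketched in plain Python. This is only an illustration under stated assumptions and does not use the tfp.experimental.distribute API: the "devices" are simulated with lists, and normal_log_prob, simulated_psum, and the data values are hypothetical names and numbers introduced for this example.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate normal distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def simulated_psum(per_shard_values):
    """Simulates an all-reduce sum: every shard receives the same total."""
    total = math.fsum(per_shard_values)
    return [total for _ in per_shard_values]


# Hypothetical data split evenly across two simulated "devices".
data = [0.1, -0.4, 1.2, 0.7]
shards = [data[:2], data[2:]]

# Each shard computes the log prob of its local data only.
local_log_probs = [
    math.fsum(normal_log_prob(x) for x in shard) for shard in shards
]

# After the all-reduce, every shard holds the full-data log prob.
global_log_probs = simulated_psum(local_log_probs)

# Reference value computed without any sharding.
unsharded_log_prob = math.fsum(normal_log_prob(x) for x in data)
```

Because the data points are independent, the full log prob is the sum of the per-shard log probs, so each entry of global_log_probs should agree with unsharded_log_prob up to floating-point rounding. In a real SPMD setting the sum would presumably be carried out by a collective over a named axis rather than by a Python list.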