Module: tfp.experimental.distribute

Experimental module for distributed log-probability calculations.

Classes

class JointDistributionCoroutine: A sharding-aware JointDistributionCoroutine.

class JointDistributionNamed: A sharding-aware JointDistributionNamed.

class JointDistributionSequential: A sharding-aware JointDistributionSequential.

class ShardedIndependent: A version of tfd.Independent that folds device id into its randomness.
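The idea of folding the device id into the randomness can be illustrated outside TFP. Below is a minimal JAX sketch, not the library's implementation. It assumes `jax.vmap` with an `axis_name` as a single-machine stand-in for `pmap` across real devices, and the axis name `"devices"` and shard count are hypothetical. Every shard receives the same key, and `jax.random.fold_in` mixes in the shard index so each shard draws different noise.

```python
import jax
import jax.numpy as jnp

AXIS = "devices"  # hypothetical axis name; stands in for a pmap device axis
NUM_SHARDS = 4    # number of simulated devices (assumption for this sketch)

def sample_local_shard(key, shard_size):
    # Every shard starts from the same key; folding in the shard index
    # gives each shard an independent stream of randomness.
    shard_id = jax.lax.axis_index(AXIS)
    local_key = jax.random.fold_in(key, shard_id)
    return jax.random.normal(local_key, (shard_size,))

key = jax.random.PRNGKey(0)
# Replicate the same key on every simulated device.
keys = jnp.broadcast_to(key, (NUM_SHARDS,) + key.shape)
samples = jax.vmap(lambda k: sample_local_shard(k, 3), axis_name=AXIS)(keys)
print(samples.shape)  # (4, 3)
```

Without the `fold_in` step, every row of `samples` would be identical, because each shard would reuse the replicated key unchanged.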

class ShardedSample: A version of tfd.Sample that shards its output across devices.
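The following sketch shows what sharding an i.i.d. sample across devices means, using plain JAX rather than the TFP class. Its assumptions: `jax.vmap` with an `axis_name` simulates devices, the axis name and per-shard size are hypothetical, and each shard draws its own block from a standard normal. Because the draws are independent, the log prob of the full sample is the sum of the per-shard log probs. That sum can therefore be computed locally and combined with an all-reduce (`jax.lax.psum`).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

AXIS = "devices"  # hypothetical axis name
NUM_SHARDS = 4    # simulated devices (assumption)
LOCAL_SIZE = 5    # samples drawn per shard (assumption)

def local_sample_and_log_prob(key):
    # Each shard draws its own block of i.i.d. N(0, 1) samples with a shard-specific key.
    local_key = jax.random.fold_in(key, jax.lax.axis_index(AXIS))
    x = jax.random.normal(local_key, (LOCAL_SIZE,))
    # Sum the local log probs, then all-reduce the sum across shards.
    total_lp = jax.lax.psum(norm.logpdf(x).sum(), AXIS)
    return x, total_lp

key = jax.random.PRNGKey(0)
keys = jnp.broadcast_to(key, (NUM_SHARDS,) + key.shape)
shards, total_lp = jax.vmap(local_sample_and_log_prob, axis_name=AXIS)(keys)

# Concatenating the per-shard blocks recovers the full sample.
global_sample = shards.reshape(-1)
print(global_sample.shape)  # (20,)
# Every shard holds the same reduced value, equal to the full sample's log prob.
print(jnp.allclose(total_lp, norm.logpdf(global_sample).sum()))
```

No single device ever materializes the full sample. The only cross-device communication is one scalar all-reduce.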

Functions

make_sharded_log_prob_parts(...): Constructs a log prob parts function that all-reduces over terms.
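To show why an all-reduce over log prob terms must distinguish sharded from replicated terms, here is a hedged JAX sketch. It does not reproduce the real `make_sharded_log_prob_parts` signature. The helper `reduce_log_prob_parts`, the `is_sharded` flags as written here, the axis name, and the toy model are all hypothetical. The toy model has a replicated global mean `mu` with a N(0, 1) prior and data `x ~ N(mu, 1)` split across shards, with `jax.vmap` plus `axis_name` standing in for devices. Only the data term is all-reduced. The prior term is already identical on every shard, and summing it would count it once per shard. The sketch covers forward values only and makes no claim about how the library treats gradients.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

AXIS = "devices"  # hypothetical axis name
NUM_SHARDS = 4    # simulated devices (assumption)

def reduce_log_prob_parts(parts, is_sharded, axis_name):
    # Hypothetical helper: psum only the parts computed from sharded data;
    # replicated parts are identical on every shard and are returned unchanged.
    return [jax.lax.psum(p, axis_name) if s else p for p, s in zip(parts, is_sharded)]

def local_log_prob_parts(mu, x_local):
    lp_mu = norm.logpdf(mu)                    # prior on the replicated parameter
    lp_x = norm.logpdf(x_local, loc=mu).sum()  # likelihood of this shard's data
    return [lp_mu, lp_x]

def sharded_log_prob_parts(mu, x_local):
    parts = local_log_prob_parts(mu, x_local)
    return reduce_log_prob_parts(parts, is_sharded=[False, True], axis_name=AXIS)

mu = jnp.float32(0.3)
data = jnp.arange(12.0).reshape(NUM_SHARDS, 3) / 10.0  # 12 observations, 3 per shard
lp_mu, lp_x = jax.vmap(sharded_log_prob_parts, in_axes=(None, 0), axis_name=AXIS)(mu, data)

# Reference: the same model evaluated on one device without sharding.
ref_mu = norm.logpdf(mu)
ref_x = norm.logpdf(data.reshape(-1), loc=mu).sum()
print(jnp.allclose(lp_mu, ref_mu), jnp.allclose(lp_x, ref_x))
```

Marking the prior as sharded as well would make every device report `NUM_SHARDS * ref_mu` for that term, overweighting the prior relative to the data.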