tfp.experimental.distribute.make_sharded_log_prob_parts

Constructs a log prob parts function that all-reduces over terms.

Given a log_prob_parts function, this function will return a new one that includes all-reduce sums over terms according to the is_sharded property. It will also add all-reduce sums for the gradient of sharded terms w.r.t. unsharded terms.
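The forward behavior can be illustrated with a minimal pure-Python simulation. Each "device" is just an entry in a list, and summing across that list stands in for an all-reduce. This is not TFP's implementation or API. The model (w ~ Normal(0, 1) unsharded, x_i ~ Normal(w, 1) sharded) and the helpers `normal_log_prob`, `local_log_prob_parts` and `simulate_sharded_log_prob_parts` are hypothetical, chosen only to show why sharded terms need a cross-device sum while unsharded terms do not.

```python
import math


def normal_log_prob(x, loc, scale=1.0):
    # Log density of a univariate Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def local_log_prob_parts(value):
    # Hypothetical model: w ~ Normal(0, 1) is unsharded (replicated on every
    # device); each x_i ~ Normal(w, 1), and the x's are split across devices.
    # Returns one log density per term; their sum is the log density of
    # (w, local shard of x), i.e. only a *locally* correct value.
    w, xs = value["w"], value["x"]
    return {"w": normal_log_prob(w, 0.0),
            "x": sum(normal_log_prob(x, w) for x in xs)}


def simulate_sharded_log_prob_parts(parts_fn, is_sharded, per_device_values):
    # Stand-in for the collective: evaluate parts_fn on every simulated
    # device, then sum each sharded term across devices (an all-reduce).
    # Unsharded terms keep the value each device computed locally.
    local = [parts_fn(v) for v in per_device_values]
    sharded_totals = {name: sum(parts[name] for parts in local)
                      for name, sharded in is_sharded.items() if sharded}
    return [{name: sharded_totals.get(name, parts[name]) for name in parts}
            for parts in local]


w = 0.3
full_data = [1.0, -0.5, 2.0, 0.1]
devices = [{"w": w, "x": full_data[:2]}, {"w": w, "x": full_data[2:]}]
reduced = simulate_sharded_log_prob_parts(
    local_log_prob_parts, {"w": False, "x": True}, devices)
# After the simulated all-reduce, every device's parts sum to the
# full-data log density.
full = local_log_prob_parts({"w": w, "x": full_data})
```

The unsharded term is deliberately left alone: it is already identical on every device, so summing it across devices would count the prior once per device.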

log_prob_parts_fn: a callable that takes in a structured value and returns a structure of log densities, one for each term; when summed, these give a locally correct log density.
axis_names: a structure of values that matches the input and output of log_prob_parts_fn. Each value in axis_names is either None, a string name of a mapped axis in the JAX backend or any non-None value in the TF backend, or an iterable thereof corresponding to multiple sharding axes. If a term's axis_name is not None, the returned function will add all-reduce sum(s) for that term in the log prob calculation. If it is None, the returned function will have an all-reduce sum over the gradient of sharded terms w.r.t. the unsharded value.
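The gradient rule for unsharded values can be sketched the same way. The pure-Python simulation below assumes the hypothetical model w ~ Normal(0, 1), x_i ~ Normal(w, 1) with analytic gradients, and an axis_names structure of {"w": None, "x": "data"}, where "data" is an illustrative axis name. It only reproduces the arithmetic (each device must see the gradient of the sharded term w.r.t. w summed over all shards); it does not show how TFP wires this into autodiff, and `local_grads_wrt_w` and `simulate_grad_all_reduce` are hypothetical helpers.

```python
def local_grads_wrt_w(w, xs):
    # Analytic gradients w.r.t. the unsharded w on one simulated device:
    #   d/dw log N(w; 0, 1)         = -w
    #   d/dw sum_i log N(x_i; w, 1) = sum_i (x_i - w)   (local shard only)
    return {"w": -w, "x": sum(x - w for x in xs)}


def simulate_grad_all_reduce(w, per_device_xs, axis_names):
    # An axis name of None marks an unsharded term: its gradient is already
    # identical on every device. A non-None name marks a sharded term: its
    # gradient w.r.t. w is summed across devices (an all-reduce).
    local = [local_grads_wrt_w(w, xs) for xs in per_device_xs]
    per_device_totals = []
    for grads in local:
        total = 0.0
        for name, g in grads.items():
            if axis_names[name] is None:
                total += g
            else:
                total += sum(other[name] for other in local)
        per_device_totals.append(total)
    return per_device_totals


axis_names = {"w": None, "x": "data"}  # "data" is a hypothetical axis name
w = 0.3
shards = [[1.0, -0.5], [2.0, 0.1]]
grads = simulate_grad_all_reduce(w, shards, axis_names)
# Each device gets the full-data gradient, approximately 1.1.
```

Without the reduction, the first device would compute only -0.3 + (0.7 - 0.8), roughly -0.4, because it sees the likelihood gradient of its own shard.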

Returns: a new log prob parts function that can be run inside of a strategy.