tfp.experimental.distribute.make_pbroadcast_function

Constructs a function that broadcasts inputs over named axes.

Given a function fn, make_pbroadcast_function returns a new function that applies pbroadcast to its inputs according to the axis names provided in in_axes and out_axes. For each term of the output of fn, any input that lacks some of that term's output axes is pbroadcast over the missing axes before the term is computed.
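The broadcast rule can be sketched in plain Python. The helper `pbroadcast_plan` below is hypothetical and is not part of TFP. It assumes flat dicts in place of arbitrary nested structures, with each leaf either an iterable of axis names or None for "no named axes". For every (output, input) pair it lists the axes present on the output but missing from the input, which are the axes that input would be pbroadcast over.

```python
from typing import Dict, Iterable, List, Optional, Tuple

# Hypothetical illustration only; not the TFP implementation.
Axes = Optional[Iterable[str]]


def pbroadcast_plan(
    in_axes: Dict[str, Axes], out_axes: Dict[str, Axes]
) -> Dict[Tuple[str, str], List[str]]:
    """Map each (output, input) pair to the axes the input is pbroadcast over."""
    plan = {}
    for out_name, o_axes in out_axes.items():
        o = set(o_axes or ())
        for in_name, i_axes in in_axes.items():
            # Axes present on the output term but absent from the input.
            plan[(out_name, in_name)] = sorted(o - set(i_axes or ()))
    return plan


# w is unsharded, x is sharded over "data"; y carries "data", total carries none.
plan = pbroadcast_plan(
    in_axes={"w": None, "x": ["data"]},
    out_axes={"y": ["data"], "total": None},
)
```

Here only `w` needs a pbroadcast over `"data"` before `y` is computed; `x` already carries that axis, and `total` has no named axes to broadcast over.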

Args:

fn: A callable to be transformed to have pbroadcasts at its inputs.
in_axes: A structure of axis names that should match the structure of the inputs to fn. If the set of input axes for an input value does not match the output axes of a particular output value, the gradient of that output value with respect to the input value is psum-ed over the axes present in the output but not in the input.
out_axes: A structure of axis names that should match the structure of the output of fn. The inputs to fn are pbroadcast before each output term is computed, according to that term's output axes.
out_dtype: A structure of dtypes that matches the output of fn.
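The gradient behavior described for in_axes can be illustrated with a small pure-Python simulation of one named axis. This is a conceptual sketch, not TFP code. The hypothetical functions `pbroadcast` and `pbroadcast_transpose` stand in for the forward pass (replicate an unsharded value to every shard) and its transpose (psum the per-shard cotangents). The example assumes fn(w, x) = w * x with w unsharded, x sharded over a single axis, and an all-ones cotangent on each output shard.

```python
from typing import List


def pbroadcast(x: float, axis_size: int) -> List[float]:
    # Forward pass (simulated): identity, the value is replicated to every shard.
    return [x] * axis_size


def pbroadcast_transpose(cotangents: List[float]) -> float:
    # Backward pass (simulated): psum the per-shard cotangents over the axis.
    return sum(cotangents)


w = 3.0
x_shards = [1.0, 2.0, 4.0]  # one value per shard of the named axis

# w has no named axes but the output does, so w is pbroadcast first.
w_per_shard = pbroadcast(w, len(x_shards))
y_shards = [wi * xi for wi, xi in zip(w_per_shard, x_shards)]

# Assume an all-ones cotangent on every output shard.
y_cotangents = [1.0] * len(y_shards)

# Per-shard gradient of w_i * x_i with respect to w_i is ct_i * x_i.
local_grads = [ct * xi for ct, xi in zip(y_cotangents, x_shards)]

# The transpose of pbroadcast sums them, giving the gradient for the single w.
grad_w = pbroadcast_transpose(local_grads)
```

The summed result equals the gradient of sum(w * x_i) computed without any sharding, which is why an input that is missing some output axes needs its gradient psum-ed over exactly those axes.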

Returns:

A new function that applies pbroadcasts to the inputs of the original function.