Constructs a function that broadcasts inputs over named axes.
tfp.experimental.distribute.make_psum_function(fn, in_axes, out_axes, out_dtype)
Given a function fn, make_psum_function returns a new function that psums the outputs of fn over the axis names provided in out_axes. It also adds psums to the vector-Jacobian product of the outputs of fn with respect to its inputs. Whenever an output has axes that are not present in an input (as specified by in_axes), the gradient flowing back to that input is psum-ed over those missing axes.
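As an illustration only, the pure-Python sketch below simulates these semantics without TensorFlow or JAX. It makes three assumptions:

- There is a single named axis, 'data', spanning three simulated devices, which are represented as list entries.
- A hypothetical helper, psum_over_devices, stands in for a real collective psum.
- The input w is replicated (it has no axes), while each x_i is a shard along 'data', so the output of fn carries the 'data' axis.

```python
# Illustrative simulation only, not TFP code: "devices" are list entries
# along one named axis, 'data'. psum_over_devices is a hypothetical
# stand-in for a real collective psum.

def psum_over_devices(per_device_values):
    """Sum a value across all devices; every device receives the total."""
    total = sum(per_device_values)
    return [total] * len(per_device_values)


def fn(w, x_i):
    """Per-device computation: w is replicated, x_i is this device's shard."""
    return w * x_i


def local_grad_w(w, x_i):
    """d fn / d w on one device, ignoring all other devices."""
    return x_i


w = 2.0                     # replicated input: its in_axes entry has no axes
x_shards = [1.0, 3.0, 5.0]  # sharded input: one value per device on 'data'

# Forward pass: the output has axis 'data', so it is psum-ed over 'data'.
local_out = [fn(w, x_i) for x_i in x_shards]
psummed_out = psum_over_devices(local_out)

# Backward pass: 'data' appears in the output but not in w's axes, so the
# per-device gradients with respect to w are psum-ed as well.
local_grads = [local_grad_w(w, x_i) for x_i in x_shards]
corrected_grads = psum_over_devices(local_grads)

print(psummed_out)      # [18.0, 18.0, 18.0]
print(corrected_grads)  # [9.0, 9.0, 9.0]
```

Without the gradient psum, each replica of w would see only its local gradient x_i. With it, each replica sees 9.0, the gradient of the summed output w * (1 + 3 + 5) with respect to w.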
Args:
  fn: A callable to be transformed so that psums are applied to its outputs and to the gradients with respect to its inputs.
  in_axes: A structure of axis names matching the structure of the inputs to fn. If the axes of an input value do not match the axes of a particular output value, the gradient of that output with respect to that input is psum-ed over the axes present in the output but not in the input.
  out_axes: A structure of axis names matching the structure of the output of fn. Each output of fn is psum-ed over its respective output axes.
  out_dtype: A structure of dtypes matching the structure of the output of fn.
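The in_axes rule above amounts to a set difference between an output's axes and an input's axes. The sketch below assumes an axis specification is None, a single axis name, or a list of names; the forms TFP actually accepts may differ. The function gradient_psum_axes is a hypothetical helper for illustration, not part of the TFP API.

```python
# Hypothetical helper illustrating the in_axes rule; not part of TFP.

def _as_axis_set(axes):
    """Normalize an axis spec (None, a single name, or a list of names) to a set."""
    if axes is None:
        return set()
    if isinstance(axes, str):
        return {axes}
    return set(axes)


def gradient_psum_axes(input_axes, output_axes):
    """Return the axes present in the output but not in the input.

    Per the documentation, the gradient of that output with respect to that
    input is psum-ed over exactly these axes.
    """
    return sorted(_as_axis_set(output_axes) - _as_axis_set(input_axes))


print(gradient_psum_axes(None, "data"))       # ['data']
print(gradient_psum_axes("data", "data"))     # []
print(gradient_psum_axes(["x"], ["x", "y"]))  # ['y']
```

An input that already carries every axis of the output (the second call) needs no gradient psum, while a replicated input (the first call) is psum-ed over all of the output's axes.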
Returns:
  A new function that applies psums to the outputs of the original function and corrects the gradients with respect to its inputs.
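To illustrate how out_axes is matched against a structured output, the following sketch simulates a 2 x 2 device mesh with named axes 'x' and 'y' in plain Python. It is not TFP code. The helper psum_mesh is a hypothetical stand-in for a collective psum. The example fn returns a tuple whose first element is psum-ed over both axes and whose second is psum-ed over 'x' only.

```python
import itertools

# Illustrative simulation of a 2 x 2 device mesh with named axes "x" and "y".
# psum_mesh is a hypothetical stand-in for a collective psum; not TFP code.
AXIS_ORDER = ["x", "y"]
DEVICES = list(itertools.product(range(2), range(2)))  # (ix, iy) coordinates


def psum_mesh(values, axes):
    """Sum `values` (a dict mapping device -> float) over the named `axes`.

    Devices that agree on every axis that is NOT summed form one group, and
    each device receives the total of its group.
    """
    keep = [i for i, name in enumerate(AXIS_ORDER) if name not in axes]

    def group_key(dev):
        return tuple(dev[i] for i in keep)

    return {
        dev: sum(v for d, v in values.items() if group_key(d) == group_key(dev))
        for dev in values
    }


def fn(dev):
    """Per-device function returning a 2-tuple of outputs."""
    ix, iy = dev
    value = float(10 * ix + iy)  # a different value on every device
    return value, value


# One axis spec per output, mirroring the tuple structure returned by fn.
out_axes = (["x", "y"], ["x"])

local = {dev: fn(dev) for dev in DEVICES}
outputs = tuple(
    psum_mesh({dev: local[dev][k] for dev in DEVICES}, axes)
    for k, axes in enumerate(out_axes)
)

print(outputs[0])  # {(0, 0): 22.0, (0, 1): 22.0, (1, 0): 22.0, (1, 1): 22.0}
print(outputs[1])  # {(0, 0): 10.0, (0, 1): 12.0, (1, 0): 10.0, (1, 1): 12.0}
```

Summing over ['x', 'y'] gives every device the full total. Summing over 'x' alone leaves one partial sum per position along 'y'.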