tfp.experimental.distributions.marginal_fns.tfp_custom_gradient.custom_gradient

Decorates a function and adds custom derivatives.

TF only supports VJPs, so we decorate with tf.custom_gradient.

JAX supports either JVP or VJP. If a custom JVP is provided, then JAX can transpose to derive a VJP rule. Therefore we prefer jvp_fn if given, but fall back to the vjp functions otherwise.

vjp_fwd A function (*args) => (output, auxiliaries).
vjp_bwd A function (auxiliaries, output_gradient) => nondiff_args_gradients. None gradients will be inserted into the correct positions for nondiff_argnums.
jvp_fn A function (nondiff_args, primals, tangents) => (primal_out, tangent_out).
nondiff_argnums Tuple of argument indices which are not differentiable. These arguments must be integers or other non-Tensors. Tensors with no gradient should be indicated with a None in the result of vjp_bwd.
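
To illustrate how None gradients could be slotted into the positions named by nondiff_argnums, here is a minimal pure-Python sketch. The helper name insert_none_gradients is hypothetical and not part of the library; it only mirrors the behavior described for vjp_bwd above, under the assumption that vjp_bwd returns one gradient per differentiable argument, in order.

```python
def insert_none_gradients(diff_grads, nondiff_argnums):
    """Hypothetical helper: pad gradients with None at non-differentiable slots.

    diff_grads: gradients for the differentiable arguments only, in order.
    nondiff_argnums: indices (in the full argument list) that get a None.
    """
    result = tuple(diff_grads)
    # Insert in ascending index order so earlier insertions shift later
    # entries into their final positions.
    for i in sorted(nondiff_argnums):
        result = result[:i] + (None,) + result[i:]
    return result


# f(x, n, y) with n non-differentiable: gradients for x and y only.
padded = insert_none_gradients((1.5, 2.5), nondiff_argnums=(1,))
print(padded)
```

The padded tuple has one entry per original argument, which is the shape a TF gradient function is expected to return.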

Returns a decorator to be applied to a function f(*args) => output.