tfp.substrates.jax.math.custom_gradient

Embeds a custom gradient into a Tensor.

This function works by a clever application of stop_gradient. That is, observe that:

h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))

is such that h(x) == stop_gradient(f(x)) and grad[h(x), x] == stop_gradient(g(x)): the term x - stop_gradient(x) evaluates to zero, yet its gradient with respect to x is one.
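The identity above can be checked numerically. The following is a minimal sketch in plain JAX, assuming a hypothetical helper `custom_gradient_sketch` built directly from `jax.lax.stop_gradient`; it illustrates the trick and is not the library's implementation. As an illustration it builds a straight-through estimator: the forward value is `round(x)`, while the reported gradient is 1.

```python
import jax
import jax.numpy as jnp


def custom_gradient_sketch(fx, gx, x):
    """Hypothetical helper: returns the value of fx with gradient gx w.r.t. x."""
    sg = jax.lax.stop_gradient
    # x - sg(x) is numerically zero, but its derivative w.r.t. x is one,
    # so the second term contributes nothing to the value and exactly
    # sg(gx) to the gradient.
    return sg(fx) + sg(gx) * (x - sg(x))


def rounded(x):
    # Straight-through estimator: the forward pass rounds, while the
    # backward pass behaves as if the function were the identity.
    return custom_gradient_sketch(jnp.round(x), jnp.ones_like(x), x)


x = jnp.float32(1.3)
value = rounded(x)            # same value as jnp.round(x)
slope = jax.grad(rounded)(x)  # the custom gradient, 1.0
```

Without the trick, `jnp.round` alone has zero gradient almost everywhere, which is why straight-through estimators are a common use of this pattern.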

In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions.
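For a tensor-domain/scalar-range function, x is an array, f(x) is a scalar, and g(x) has the same shape as x. One way to keep the output scalar, sketched below under that assumption, is to sum the elementwise surrogate term; the helper name `custom_gradient_tensor` and the clipped-gradient choice are hypothetical illustrations, not part of the documented API.

```python
import jax
import jax.numpy as jnp

sg = jax.lax.stop_gradient


def custom_gradient_tensor(fx, gx, x):
    """Hypothetical helper: scalar fx, gx with the same shape as x."""
    # Each element of (x - sg(x)) is zero in value, so the sum adds nothing
    # to fx; its gradient w.r.t. x is exactly sg(gx), elementwise.
    return sg(fx) + jnp.sum(sg(gx) * (x - sg(x)))


def sum_sq_clipped_grad(x):
    fx = jnp.sum(x ** 2)
    # Illustrative custom gradient: the true gradient 2x, clipped to [-1, 1].
    gx = jnp.clip(2.0 * x, -1.0, 1.0)
    return custom_gradient_tensor(fx, gx, x)


x = jnp.array([0.25, -3.0, 2.0])
value = sum_sq_clipped_grad(x)            # 0.0625 + 9 + 4 = 13.0625
grad_x = jax.grad(sum_sq_clipped_grad)(x)  # clip([0.5, -6, 4]) = [0.5, -1, 1]
```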

Partial Custom Gradient:

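Suppose h(x) = htilde(x, y), so that f and g also depend on a second input y. Then dh/dx = stop_gradient(g(x)) but dh/dy = None. This is because a Tensor cannot have only a portion of its gradient stopped: wrapping fx and gx in stop_gradient blocks the gradient with respect to y as well. To circumvent this issue, manually apply stop_gradient to the relevant portions of f and g, and set fx_gx_manually_stopped=True. For an example, see the unit test test_works_correctly_fx_gx_manually_stopped.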
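The difference between automatic and manual stopping can be sketched with h(x, y) = x**2 * y. This is a hypothetical plain-JAX re-implementation whose `fx_gx_manually_stopped` flag only mirrors the documented parameter; it is not the library source, and it is written in the spirit of the unit test named above rather than copied from it. Note that JAX reports a zero gradient where the TensorFlow description says None.

```python
import jax

sg = jax.lax.stop_gradient


def custom_gradient_sketch(fx, gx, x, fx_gx_manually_stopped=False):
    """Hypothetical helper; the flag mirrors the documented parameter."""
    if not fx_gx_manually_stopped:
        # Stopping fx and gx wholesale also blocks any gradient through y.
        fx, gx = sg(fx), sg(gx)
    return fx + gx * (x - sg(x))


def h_auto(x, y):
    # fx = x**2 * y and gx = dh/dx = 2*x*y, both stopped automatically.
    return custom_gradient_sketch(x ** 2 * y, 2 * x * y, x)


def h_manual(x, y):
    # Stop only the x-dependent pieces, so y still flows through fx and gx.
    fx = sg(x ** 2) * y
    gx = sg(2 * x) * y
    return custom_gradient_sketch(fx, gx, x, fx_gx_manually_stopped=True)


x, y = 1.5, 2.0
auto_grads = jax.grad(h_auto, argnums=(0, 1))(x, y)      # dh/dy lost: (6.0, 0.0)
manual_grads = jax.grad(h_manual, argnums=(0, 1))(x, y)  # (2*x*y, x**2) = (6.0, 2.25)
```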

Args:
  fx: Tensor. Output of the function evaluated at x.
  gx: Tensor or list of Tensors. Gradient of the function at (each) x.
  x: Tensor or list of Tensors. Arguments at which f is evaluated.
  fx_gx_manually_stopped: Python bool indicating that stop_gradient has already been applied manually to fx and gx.
  name: Python str name prefixed to Ops created by this function.

Returns:
  fx: Floating-type Tensor equal to f(x) but which has gradient stop_gradient(g(x)).
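When x is a list of Tensors, gx is a matching list holding one gradient per element of x. A minimal sketch of that case, again using a hypothetical helper built from `jax.lax.stop_gradient` rather than the library function itself: each (x_i, g_i) pair contributes its own zero-valued surrogate term, and the terms are summed. The doubled gradients below are an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp

sg = jax.lax.stop_gradient


def custom_gradient_list(fx, gx, x):
    """Hypothetical helper: gx[i] has x[i]'s shape and is d(fx)/d(x[i])."""
    # One zero-valued term per input; its gradient w.r.t. x[i] is sg(gx[i]).
    surrogate = sum(jnp.sum(sg(g) * (xi - sg(xi))) for g, xi in zip(gx, x))
    return sg(fx) + surrogate


def f(a, b):
    fx = jnp.sum(a * b)
    # True gradients are b (for a) and a (for b); report twice those instead.
    return custom_gradient_list(fx, [2 * b, 2 * a], [a, b])


a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
value = f(a, b)                             # 1*3 + 2*4 = 11.0
grad_a, grad_b = jax.grad(f, argnums=(0, 1))(a, b)  # 2*b and 2*a
```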