Embeds a custom gradient into a Tensor.
tfp.substrates.numpy.math.custom_gradient(
fx, gx, x, fx_gx_manually_stopped=False, name=None
)
This function works by a careful application of stop_gradient. Observe
that:
h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
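is such that h(x) == stop_gradient(f(x)) and
grad[h(x), x] == stop_gradient(g(x)). The term x - stop_gradient(x) is zero in value, but its gradient with respect to x is one, so only stop_gradient(g(x)) survives differentiation.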
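The identity can be checked without any TensorFlow machinery. The sketch below is a toy and not TFP's implementation. It assumes a hypothetical forward-mode `Dual` number class (a value plus its derivative) and a hypothetical `stop_gradient` that keeps the value and zeroes the derivative. With f(x) = x**2 and an arbitrary custom gradient g(x) = 5x, h keeps the value of f but reports g as its derivative.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy forward-mode dual number: a value and its derivative w.r.t. one input."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        other = lift(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __radd__ = __add__
    __rmul__ = __mul__


def lift(v):
    """Treat plain numbers as constants (zero derivative)."""
    return v if isinstance(v, Dual) else Dual(float(v))


def stop_gradient(d):
    """Hypothetical toy stop_gradient: keep the value, zero the derivative."""
    return Dual(d.val, 0.0)


def custom_gradient_sketch(fx, gx, x):
    # h(x) = stop(f(x)) + stop(g(x)) * (x - stop(x))
    return stop_gradient(fx) + stop_gradient(gx) * (x - stop_gradient(x))


x = Dual(3.0, 1.0)   # seed dx/dx = 1
fx = x * x           # f(x) = x**2; true derivative at x = 3 is 6
gx = 5.0 * x         # hypothetical custom gradient g(x) = 5x
h = custom_gradient_sketch(fx, gx, x)
print(h.val, h.dot)  # 9.0 15.0
```

At x = 3 the true derivative of f is 6, but h reports 15, which is g(3), while its value stays f(3) = 9.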
In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions.
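For a vector-valued x and a scalar f, one natural way to keep h scalar is to sum the per-element correction terms stop_gradient(g_i) * (x_i - stop_gradient(x_i)). The sketch below assumes that reduction as an illustration only; it is not a statement of how TFP reduces internally. `Dual`, `stop_gradient`, `custom_gradient_sketch`, and `grad_of_h` are hypothetical helpers. Forward mode recovers one partial derivative per pass, so the helper seeds each coordinate in turn.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy forward-mode dual number: a value and its derivative along one seed direction."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        other = lift(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __radd__ = __add__
    __rmul__ = __mul__


def lift(v):
    """Treat plain numbers as constants (zero derivative)."""
    return v if isinstance(v, Dual) else Dual(float(v))


def stop_gradient(d):
    """Hypothetical toy stop_gradient: keep the value, zero the derivative."""
    return Dual(d.val, 0.0)


def custom_gradient_sketch(fx, gx, xs):
    # Assumed reduction: sum the per-element correction terms so h stays scalar.
    correction = Dual(0.0)
    for g_i, x_i in zip(gx, xs):
        correction = correction + stop_gradient(g_i) * (x_i - stop_gradient(x_i))
    return stop_gradient(fx) + correction


def grad_of_h(x_vals, f, g):
    """Return h's value and [dh/dx_0, dh/dx_1, ...], one forward pass per coordinate."""
    grads = []
    for i in range(len(x_vals)):
        # Seed coordinate i with derivative 1 and all others with 0.
        xs = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(x_vals)]
        h = custom_gradient_sketch(f(xs), g(xs), xs)
        grads.append(h.dot)
    return h.val, grads


def f(xs):
    # f(x) = ||x||^2, whose true gradient is 2x.
    return sum((x * x for x in xs), Dual(0.0))


def g(xs):
    # Hypothetical custom gradient 0.5x (deliberately different from 2x).
    return [0.5 * x for x in xs]


value, grads = grad_of_h([1.0, 2.0, 3.0], f, g)
print(value, grads)  # 14.0 [0.5, 1.0, 1.5]
```

The reported gradient is the custom 0.5x rather than the true 2x, and the value is still ||x||^2.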
Partial Custom Gradient:
Suppose h(x) = htilde(x, y). Note that dh/dx = stop_gradient(g(x)) but dh/dy =
None. This is because a Tensor cannot have only a portion of its gradient
stopped. To circumvent this issue, one must manually apply stop_gradient to the
relevant portions of f and g. For an example, see the unit test
test_works_correctly_fx_gx_manually_stopped.
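The difference between the default behaviour and manual stopping can be seen in the same toy forward-mode setting. In this hypothetical sketch, f(x, y) = x**2 * y with custom x-gradient g = 2xy. `h_default` stops all of fx and gx, so the influence of y is cut off (the toy reports 0.0 where TensorFlow would report None). `h_manual` stops only the occurrences of x inside f and g and does not stop the results again, which is presumably the situation the fx_gx_manually_stopped flag is meant to signal. All names here are illustrative and not part of TFP.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy forward-mode dual number: a value and its derivative along one seed direction."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        other = lift(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __radd__ = __add__
    __rmul__ = __mul__


def lift(v):
    """Treat plain numbers as constants (zero derivative)."""
    return v if isinstance(v, Dual) else Dual(float(v))


def stop_gradient(d):
    """Hypothetical toy stop_gradient: keep the value, zero the derivative."""
    return Dual(d.val, 0.0)


def h_default(x, y):
    # Default behaviour: fx and gx are stopped in full, so y's influence is cut off.
    fx = x * x * y
    gx = 2.0 * x * y
    return stop_gradient(fx) + stop_gradient(gx) * (x - stop_gradient(x))


def h_manual(x, y):
    # Manual stopping: stop only the occurrences of x inside f and g,
    # then combine WITHOUT stopping fx and gx again, so y stays differentiable.
    sx = stop_gradient(x)
    fx = sx * sx * y
    gx = 2.0 * sx * y
    return fx + gx * (x - sx)


def derivative(h, x_val, y_val, wrt):
    """Forward-mode partial derivative of h at (x_val, y_val) w.r.t. 'x' or 'y'."""
    x = Dual(x_val, 1.0 if wrt == "x" else 0.0)
    y = Dual(y_val, 1.0 if wrt == "y" else 0.0)
    return h(x, y).dot


for name, h in [("default", h_default), ("manual", h_manual)]:
    print(name, derivative(h, 3.0, 2.0, "x"), derivative(h, 3.0, 2.0, "y"))
# default 12.0 0.0
# manual 12.0 9.0
```

At (x, y) = (3, 2), both versions give dh/dx = 12, but only h_manual gives dh/dy = 9, that is, x**2.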
| Returns | |
|---|---|
| fx | Floating-type Tensor equal to f(x) but which has gradient stop_gradient(g(x)). |