Computes f(*args) and its gradients with respect to *args.
tfp.substrates.jax.math.value_and_gradient(
f,
*args,
output_gradients=None,
use_gradient_tape=False,
auto_unpack_single_arg=True,
has_aux=False,
name=None,
**kwargs
)