tfp.substrates.jax.math.value_and_gradient

Computes f(*xs) and its gradients with respect to *xs.
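A minimal sketch of the same "value plus gradients for every positional argument" idea in plain JAX may help. It is not TFP's implementation. The helper name value_and_gradient_sketch is hypothetical. The sketch assumes the gradient is seeded with ones, so for a non-scalar output it returns the gradient of the sum of f(*xs). It uses only jax.vjp and jax.numpy.

```python
import jax
import jax.numpy as jnp


def value_and_gradient_sketch(f, *xs):
    """Hypothetical helper: returns f(*xs) and d(sum f)/dx for each x in xs."""
    # jax.vjp evaluates f once and returns a function that pulls
    # output cotangents back to each positional input.
    value, vjp_fn = jax.vjp(f, *xs)
    # Seeding with ones gives the gradient of sum(f(*xs)) (assumption).
    grads = vjp_fn(jnp.ones_like(value))
    return value, list(grads)


def f(x, y):
    # Scalar-valued example: sum(x**2) + 3*y.
    return jnp.sum(x ** 2) + 3.0 * y


x = jnp.array([1.0, 2.0])
y = jnp.array(4.0)
value, (dx, dy) = value_and_gradient_sketch(f, x, y)
# value == 17.0, dx == [2., 4.], dy == 3.0
```

Because the function is evaluated once inside jax.vjp, the value and all gradients come from a single forward pass. That is the main reason to prefer a combined value-and-gradient call over calling f and a gradient function separately.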