Computes the diagonal of the Jacobian matrix of `ys = fn(xs)` with respect to `xs`.
```python
tfp.math.diag_jacobian(
    xs, ys=None, sample_shape=None, fn=None, parallel_iterations=10, name=None
)
```
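If `ys` is a tensor or a list of tensors of the form `(ys_1, ..., ys_n)` and
`xs` is of the form `(xs_1, ..., xs_n)`, `diag_jacobian` computes the
diagonal of the Jacobian matrix, i.e., the partial derivatives
`(dys_1/dxs_1, ..., dys_n/dxs_n)`. Each result has the shape of the
corresponding `xs_i` and holds the derivative of each element of `ys_i` with
respect to the matching element of `xs_i`. For definition details, see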
https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant
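The definition can be checked numerically without TensorFlow. The sketch below uses a hypothetical helper, `fd_diag_jacobian`, written only for illustration (it is not part of TFP). It estimates each entry of `dys_i/dxs_i` by central finite differences, assuming `fn` returns one output per input and that each output has the same shape as its input.

```python
import numpy as np


def fd_diag_jacobian(fn, xs, eps=1e-6):
    """Hypothetical helper: central-difference estimate of d ys_i[k] / d xs_i[k].

    Assumes fn(*xs) returns a sequence (ys_1, ..., ys_n) and that ys_i has the
    same shape as xs_i.
    """
    xs = [np.asarray(x, dtype=np.float64) for x in xs]
    diags = []
    for i, x in enumerate(xs):
        diag = np.empty_like(x)
        for k in np.ndindex(x.shape):
            # Perturb only element k of input i, keeping all other inputs fixed.
            plus = [v.copy() for v in xs]
            minus = [v.copy() for v in xs]
            plus[i][k] += eps
            minus[i][k] -= eps
            # Read back only the matching element k of output i.
            y_plus = np.asarray(fn(*plus)[i])[k]
            y_minus = np.asarray(fn(*minus)[i])[k]
            diag[k] = (y_plus - y_minus) / (2 * eps)
        diags.append(diag)
    return diags


def square_and_scale(x, y):
    # ys = (x**2, 3*y), so the diagonal Jacobian is (2*x, 3).
    return x ** 2, 3.0 * y


x0 = np.array([1.0, -2.0, 0.5])
y0 = np.array([[4.0]])
dx, dy = fd_diag_jacobian(square_and_scale, [x0, y0])
print(dx, dy)
```

This needs two evaluations of `fn` per input element, so it is only practical for small inputs or for testing.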
Example
Diagonal Hessian of the log-density of a 3D Gaussian distribution
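In this example we compute the diagonal of the Hessian of the log-density of a
3-D Gaussian by applying `diag_jacobian` to the gradients returned by
`tfp.math.value_and_gradient`. The state is split into two tensors, `x` (the
first two coordinates) and `y` (the third coordinate), evaluated at a batch of
3 x 5 points.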
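The diagonal Hessian of a Gaussian log-density has a closed form, which is what this example should reproduce. With mean $\mu$ (`true_mean`) and symmetric covariance $\Sigma$ (`true_cov`):

$$
\log p(z) = -\tfrac{1}{2}(z-\mu)^\top \Sigma^{-1} (z-\mu) - \tfrac{1}{2}\log\det(2\pi\Sigma),
\qquad
\nabla_z \log p(z) = -\Sigma^{-1}(z-\mu),
\qquad
\nabla_z^2 \log p(z) = -\Sigma^{-1},
$$

so

$$
\frac{\partial^2 \log p(z)}{\partial z_k^2} = -\left(\Sigma^{-1}\right)_{kk}.
$$

The diagonal does not depend on $z$, so every point in the batch should receive the same three values.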
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dtype = np.float32

# `diag_jacobian` requires `fn` in eager mode, so build the example as a graph.
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
  true_mean = dtype([0, 0, 0])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target function is defined as follows:
  def target_fn(x, y):
    # Concatenate the state parts into a single 3-D point. `target` is
    # already centered at `true_mean`, so no shift is needed here.
    z = tf.concat([x, y], axis=-1)
    return target.log_prob(z)

  sample_shape = [3, 5]
  state = [tf.ones(sample_shape + [2], dtype=dtype),
           tf.ones(sample_shape + [1], dtype=dtype)]
  fn_val, grads = tfp.math.value_and_gradient(target_fn, state)

  # Differentiating the gradients once more gives the diagonal of the Hessian.
  # Passing the `sample_shape` of `state` (here `tf.shape(fn_val)`, i.e.
  # [3, 5]) is optional but affects the computational speed of `diag_jacobian`.
  _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
      xs=state, ys=grads, sample_shape=tf.shape(fn_val))
  _, diag_jacobian_shape_none = tfp.math.diag_jacobian(
      xs=state, ys=grads)

  diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
  diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

print('hessian computed through `diag_jacobian`, sample_shape passed: ',
      np.concatenate(diag_jacobian_shape_passed_, -1))
print('hessian computed through `diag_jacobian`, sample_shape skipped: ',
      np.concatenate(diag_jacobian_shape_none_, -1))
```
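As a sanity check, the closed-form diagonal Hessian of a Gaussian log-density, $-\operatorname{diag}(\Sigma^{-1})$, can be computed directly with NumPy. Assuming `diag_jacobian` returns exact second derivatives, both printed arrays should match this vector at every one of the 3 x 5 points, up to float32 rounding. This is a reference computation only; it does not call TFP.

```python
import numpy as np

# Same covariance as `true_cov` in the example.
true_cov = np.array([[1.0, 0.25, 0.25],
                     [0.25, 2.0, 0.25],
                     [0.25, 0.25, 3.0]])

# Closed-form diagonal Hessian of a Gaussian log-density: -diag(inv(cov)).
# It does not depend on the state, so it is the same for every batch point.
expected_diag = -np.diag(np.linalg.inv(true_cov))

# The example evaluates a [3, 5] batch and concatenates the `x` and `y` parts
# along the last axis, so the printed arrays have shape [3, 5, 3].
expected = np.broadcast_to(expected_diag, (3, 5, 3))
print(expected_diag)
```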
| Raises | |
|---|---|
| `ValueError` | if lists `xs` and `ys` have different lengths, or both `ys` and `fn` are `None`, or `fn` is `None` in eager execution mode. |
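The error conditions in the table above can be restated as plain-Python checks. `check_diag_jacobian_args` and its `eager` flag are hypothetical names used only for illustration; the sketch mirrors the documented conditions, not TFP's actual source.

```python
from typing import Callable, Optional, Sequence


def check_diag_jacobian_args(xs: Sequence, ys: Optional[Sequence],
                             fn: Optional[Callable], eager: bool) -> None:
    """Hypothetical restatement of the documented ValueError conditions."""
    # In eager mode there is no graph to differentiate, so `fn` is required.
    if eager and fn is None:
        raise ValueError("`fn` must be provided in eager execution mode")
    # Without `ys`, the outputs can only come from calling `fn(*xs)`.
    if ys is None and fn is None:
        raise ValueError("`ys` and `fn` cannot both be None")
    # Each `xs_i` needs a matching `ys_i`.
    if ys is not None and len(xs) != len(ys):
        raise ValueError("`xs` and `ys` must have the same length")
```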