
tfp.substrates.jax.vi.GradientEstimators

Gradient estimators for variational losses.

Variational losses implemented by monte_carlo_variational_loss are defined in general as an expectation of some fn under the surrogate posterior,

loss = expectation(fn, surrogate_posterior)

where the expectation is estimated in practice from a finite number (sample_size) of samples:

zs = surrogate_posterior.sample(sample_size)
loss_estimate = 1 / sample_size * sum([fn(z) for z in zs])
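The sampling-and-averaging estimate above can be run end to end as the following minimal plain-Python sketch. It assumes a hypothetical Normal(loc, scale) surrogate, sampled with the standard library's random.gauss, and an illustrative fn(z) = z**2 whose exact expectation is loc**2 + scale**2. The helper names are made up for illustration; this is not TFP's monte_carlo_variational_loss.

```python
import random

# Hypothetical stand-ins, not TFP APIs: a Normal(loc, scale) surrogate
# represented only by its sampler, and an illustrative integrand fn.
def sample_surrogate(rng, loc, scale, sample_size):
    """Draws sample_size samples z ~ Normal(loc, scale)."""
    return [rng.gauss(loc, scale) for _ in range(sample_size)]

def fn(z):
    # Under Normal(loc, scale), E[z**2] = loc**2 + scale**2.
    return z ** 2

def monte_carlo_loss(rng, loc, scale, sample_size):
    zs = sample_surrogate(rng, loc, scale, sample_size)
    # Finite-sample estimate of expectation(fn, surrogate_posterior).
    return sum(fn(z) for z in zs) / sample_size

rng = random.Random(0)
loss_estimate = monte_carlo_loss(rng, loc=1.0, scale=2.0, sample_size=100_000)
print(loss_estimate)  # near the exact value 1.0**2 + 2.0**2 = 5.0
```

The estimate is unbiased for any sample_size; increasing sample_size only shrinks its variance.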

Gradient estimators define a stochastic estimate of the gradient of the above expectation with respect to the parameters of the surrogate posterior.

Members:

  • SCORE_FUNCTION: Also known as REINFORCE [1] or the log-derivative gradient estimator [2]. This estimator works with any surrogate posterior, but gradient estimates may be very noisy.
  • REPARAMETERIZATION: Reparameterization gradients as introduced by Kingma and Welling [3]. These require a continuous-valued surrogate that sets reparameterization_type=FULLY_REPARAMETERIZED (which must implement reparameterized sampling either directly or via implicit reparameterization [4]), and typically yield much lower-variance gradient estimates than the generic score function estimator.
  • DOUBLY_REPARAMETERIZED: The doubly-reparameterized estimator presented by Tucker et al. [5] for importance-weighted bounds. Note that this includes the sticking-the-landing estimator developed by Roeder et al. [6] as a special case when importance_sample_size=1. Compared to 'vanilla' reparameterization, this can provide even lower-variance gradient estimates, but requires a copy of the surrogate posterior with no gradient to its parameters (passed to the loss as stopped_surrogate_posterior), and incurs an additional evaluation of the surrogate density at each step.
  • VIMCO: An extension of the score-function estimator, introduced by Mnih and Rezende [7], with reduced variance when importance_sample_size > 1.
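To make the variance comparison concrete, here is a minimal plain-Python sketch (not the TFP implementation) that hand-derives three of the estimators above for a toy problem. It assumes a target p(z) = Normal(0, 1), a surrogate q(z) = Normal(mu, 1) with only mu trainable, and the loss KL(q || p) = E_q[log q(z) - log p(z)] = mu**2 / 2, whose exact gradient is mu. The sticking-the-landing function corresponds to the importance_sample_size=1 special case of DOUBLY_REPARAMETERIZED. All function names and the value of MU are hypothetical. Every estimator should average to MU; the printed variances show how they differ.

```python
import random

MU = 1.5  # current surrogate location (illustrative value)

# Toy densities; both drop the same -0.5*log(2*pi) constant, so it cancels.
def log_q(z, mu):
    return -0.5 * (z - mu) ** 2   # log Normal(z; mu, 1)

def log_p(z):
    return -0.5 * z ** 2          # log Normal(z; 0, 1)

# Hand-derived partial derivatives of the toy densities.
def dlogq_dz(z, mu):
    return -(z - mu)

def dlogq_dmu(z, mu):
    return z - mu

def dlogp_dz(z):
    return -z

def score_function(mu, eps):
    z = mu + eps  # sample from q, treated as a constant (no pathwise gradient)
    f = log_q(z, mu) - log_p(z)
    # REINFORCE term f * d/dmu log q, plus f's direct dependence on mu at fixed z.
    return f * dlogq_dmu(z, mu) + dlogq_dmu(z, mu)

def reparameterization(mu, eps):
    z = mu + eps  # reparameterized sample, so dz/dmu = 1
    # Pathwise term d f/dz * dz/dmu, plus log q's explicit dependence on mu.
    return (dlogq_dz(z, mu) - dlogp_dz(z)) + dlogq_dmu(z, mu)

def sticking_the_landing(mu, eps):
    z = mu + eps
    # Same pathwise term, but the density parameters are "stopped", so the
    # explicit (mean-zero) d/dmu log q term is dropped.
    return dlogq_dz(z, mu) - dlogp_dz(z)

def mean_and_variance(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
eps_draws = [rng.gauss(0.0, 1.0) for _ in range(100_000)]

results = {}
for name, estimator in [("score_function", score_function),
                        ("reparameterization", reparameterization),
                        ("sticking_the_landing", sticking_the_landing)]:
    grads = [estimator(MU, e) for e in eps_draws]
    results[name] = mean_and_variance(grads)
    print(f"{name}: mean={results[name][0]:.3f} variance={results[name][1]:.3f}")
```

In this toy case the surrogate's scale already matches the target's, so the sticking-the-landing estimate has zero variance. When the surrogate is far from the target it is not guaranteed to beat plain reparameterization.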

References

[1] R. J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4), 229–256, 1992.

[2] Shakir Mohamed. Machine Learning Trick of the Day: Log Derivative Trick. 2015. https://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/

[3] Diederik P. Kingma, and Max Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013. https://arxiv.org/abs/1312.6114

[4] Michael Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit reparameterization gradients. arXiv preprint arXiv:1805.08498, 2018. https://arxiv.org/abs/1805.08498

[5] George Tucker, Dieterich Lawson, Shixiang Gu, and Chris J. Maddison. Doubly reparameterized gradient estimators for Monte Carlo objectives. arXiv preprint arXiv:1810.04152, 2018. https://arxiv.org/abs/1810.04152

[6] Geoffrey Roeder, Yuhuai Wu, and David Duvenaud. Sticking the landing: Simple, lower-variance gradient estimators for variational inference. arXiv preprint arXiv:1703.09194, 2017. https://arxiv.org/abs/1703.09194

[7] Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo objectives. In International Conference on Machine Learning, 2016. https://arxiv.org/abs/1602.06725

Class Variables

DOUBLY_REPARAMETERIZED <GradientEstimators.DOUBLY_REPARAMETERIZED: 2>
REPARAMETERIZATION <GradientEstimators.REPARAMETERIZATION: 1>
SCORE_FUNCTION <GradientEstimators.SCORE_FUNCTION: 0>
VIMCO <GradientEstimators.VIMCO: 3>