tfp.substrates.jax.vi.monte_carlo_variational_loss

Monte-Carlo approximation of an f-Divergence variational loss.

Main aliases

tfp.experimental.substrates.jax.vi.monte_carlo_variational_loss

Variational losses measure the divergence between an unnormalized target distribution p (provided via target_log_prob_fn) and a surrogate distribution q (provided as surrogate_posterior). When the target distribution is an unnormalized posterior from conditioning a model on data, minimizing the loss with respect to the parameters of surrogate_posterior performs approximate posterior inference.

This function defines losses of the form E_q[discrepancy_fn(log(u))], where u = p(z) / q(z) in the (default) case where importance_sample_size == 1, and u = mean([p(z[k]) / q(z[k]) for k in range(importance_sample_size)]) more generally. These losses are sometimes known as f-divergences [1, 2].
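As a rough illustration of this loss form, the following pure-Python sketch estimates E_q[discrepancy_fn(log(u))] for a univariate Normal surrogate. It is not the library's implementation: the helper names (normal_log_prob, log_mean_exp, mc_variational_loss) and the toy target are assumptions made for this example.

```python
import math
import random


def normal_log_prob(z, loc, scale):
    """Log-density of a univariate Normal(loc, scale) at z."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def log_mean_exp(values):
    """Numerically stable log(mean(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def mc_variational_loss(target_log_prob_fn, q_loc, q_scale, sample_size=1,
                        importance_sample_size=1,
                        discrepancy_fn=lambda logu: -logu, rng=None):
    """Averages discrepancy_fn(log(u)) over sample_size independent draws of u."""
    rng = rng or random.Random(0)
    total = 0.0
    for _ in range(sample_size):
        # Each u is the mean of importance_sample_size ratios p(z[k]) / q(z[k]).
        log_ratios = []
        for _ in range(importance_sample_size):
            z = rng.gauss(q_loc, q_scale)
            log_ratios.append(
                target_log_prob_fn(z) - normal_log_prob(z, q_loc, q_scale))
        total += discrepancy_fn(log_mean_exp(log_ratios))
    return total / sample_size


# Toy unnormalized target: p(z) = 3 * Normal(z; 0, 1) (an assumption for this sketch).
def target(z):
    return math.log(3.0) + normal_log_prob(z, 0.0, 1.0)


loss = mc_variational_loss(target, q_loc=0.0, q_scale=1.0,
                           sample_size=4, importance_sample_size=8)
```

Because the target here is exactly 3 times the surrogate density, every ratio p(z) / q(z) equals 3, so the loss is -log(3) up to floating-point rounding, whatever the sample sizes.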

The default behavior (discrepancy_fn == tfp.vi.kl_reverse, where tfp.vi.kl_reverse = lambda logu: -logu, and importance_sample_size == 1) computes an unbiased estimate of the negative of the standard evidence lower bound (ELBO) [3], so minimizing the loss maximizes the ELBO. The bound may be tightened by setting importance_sample_size > 1 [4], and the variance of the estimate reduced by setting sample_size > 1. Other discrepancies of interest available under tfp.vi include the forward KL[p||q], total variation distance, Amari alpha-divergences, and more.
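The effect of importance_sample_size on the bound can be checked numerically. The sketch below is a toy setup chosen for this example, not library code: it takes q = Normal(0, 1) and a normalized target p = Normal(1, 1), so that log Z = 0 and KL[q||p] = 0.5, and compares the loss at two importance sample sizes. The names log_ratio and neg_bound are hypothetical.

```python
import math
import random

rng = random.Random(0)


def log_ratio(z):
    # log p(z) - log q(z) for p = Normal(1, 1) and q = Normal(0, 1).
    return z - 0.5


def neg_bound(importance_sample_size, sample_size=2000):
    """Monte Carlo estimate of E_q[-log(mean_k p(z[k]) / q(z[k]))]."""
    total = 0.0
    for _ in range(sample_size):
        ratios = [math.exp(log_ratio(rng.gauss(0.0, 1.0)))
                  for _ in range(importance_sample_size)]
        total -= math.log(sum(ratios) / importance_sample_size)
    return total / sample_size


elbo_loss = neg_bound(importance_sample_size=1)   # estimates KL[q||p] = 0.5
iwae_loss = neg_bound(importance_sample_size=50)  # much closer to -log Z = 0
```

In this setup the importance-weighted loss sits well below the single-sample loss, which reflects the tightening of the bound, while sample_size only controls how noisy each estimate is.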

Args

target_log_prob_fn Python callable that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample = surrogate_posterior.sample(sample_size), this will be called as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size].
surrogate_posterior A tfp.distributions.Distribution instance defining a variational posterior (could be a tfd.JointDistribution). If using tf.Variable parameters, the distribution's log_prob and (if reparameterizable) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp.util.TransformedVariable and/or tfp.util.DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.
sample_size Integer scalar number of Monte Carlo samples used to approximate the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.
importance_sample_size Python int number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case it often makes sense to use importance sampling to approximate posterior expectations (see tfp.vi.fit_surrogate_posterior for an example). Default value: 1.
discrepancy_fn Python callable representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: tfp.vi.kl_reverse.
use_reparameterization Deprecated; use gradient_estimator instead.
gradient_estimator Optional element from tfp.vi.GradientEstimators specifying the stochastic gradient estimator to associate with the variational loss. If None, a default estimator (either score-function or reparameterization) is chosen based on surrogate_posterior.reparameterization_type. Default value: None.
stopped_surrogate_posterior Optional copy of surrogate_posterior with stopped gradients to the parameters, e.g., tfd.Normal(loc=tf.stop_gradient(loc), scale=tf.stop_gradient(scale)). Required if and only if gradient_estimator == tfp.vi.GradientEstimators.DOUBLY_REPARAMETERIZED. Default value: None.
seed PRNG seed for surrogate_posterior.sample; see tfp.random.sanitize_seed for details.
name Python str name prefixed to Ops created by this function.
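The calling convention described for target_log_prob_fn can be sketched as a small dispatch helper. This only mirrors the documented behavior and is not the library's internal code: call_target_log_prob_fn and joint_log_prob are hypothetical names, and plain floats stand in for Tensors.

```python
def call_target_log_prob_fn(target_log_prob_fn, q_sample):
    """Hypothetical helper mirroring the documented calling convention."""
    if isinstance(q_sample, (list, tuple)):
        return target_log_prob_fn(*q_sample)    # latents passed positionally
    if isinstance(q_sample, dict):
        return target_log_prob_fn(**q_sample)   # latents passed by name
    return target_log_prob_fn(q_sample)         # a single latent


def joint_log_prob(loc, log_scale):
    # Toy unnormalized log-density over two latents (an assumption for this sketch).
    return -0.5 * loc ** 2 - 0.5 * log_scale ** 2


from_tuple = call_target_log_prob_fn(joint_log_prob, (1.0, 2.0))
from_dict = call_target_log_prob_fn(
    joint_log_prob, {"loc": 1.0, "log_scale": 2.0})
from_single = call_target_log_prob_fn(lambda z: -0.5 * z ** 2, 3.0)
```

The practical consequence is that the argument names of target_log_prob_fn need to match the keys of a dictionary-valued sample, and its argument order needs to match a list- or tuple-valued sample.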

Returns

monte_carlo_variational_loss float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence.

Raises

ValueError if surrogate_posterior is not a reparameterized distribution and use_reparameterization = True. A distribution is said to be "reparameterized" when its samples are generated by transforming the samples of another distribution that does not depend on the first distribution's parameters. This property ensures the gradient with respect to parameters is valid.
TypeError if target_log_prob_fn is not a Python callable.

Csiszar f-divergences

A Csiszar function f is a convex function from R^+ (the positive reals) to R. The Csiszar f-Divergence is given by:

D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                           where x_j ~iid q(X)

For example, f = lambda u: -log(u) recovers KL[q||p], while f = lambda u: u * log(u) recovers the forward KL[p||q]. These and other functions are available in tfp.vi.
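To make the two examples concrete, the sketch below writes both Csiszar functions in log-space, i.e., as functions of log(u), and checks them against the closed-form KL divergences for a pair of small discrete distributions. The function names and the distributions are assumptions for this example; they are not the tfp.vi implementations.

```python
import math


def kl_reverse_logspace(logu):
    """f(u) = -log(u), written as a function of log(u)."""
    return -logu


def kl_forward_logspace(logu):
    """f(u) = u * log(u), written as a function of log(u)."""
    return math.exp(logu) * logu


def csiszar_divergence(p, q, discrepancy_fn):
    """Exact D_f[p, q] = sum_x q(x) f(p(x) / q(x)) for discrete p and q."""
    return sum(qx * discrepancy_fn(math.log(px / qx)) for px, qx in zip(p, q))


p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]

# Closed-form KL divergences for comparison.
kl_q_p = sum(qx * math.log(qx / px) for px, qx in zip(p, q))
kl_p_q = sum(px * math.log(px / qx) for px, qx in zip(p, q))

reverse = csiszar_divergence(p, q, kl_reverse_logspace)  # matches KL[q||p]
forward = csiszar_divergence(p, q, kl_forward_logspace)  # matches KL[p||q]
```

Both divergences are expectations under q; only the function applied to the ratio p / q changes.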

Example Application:

The Csiszar f-Divergence is a useful framework for variational inference. Specifically, observe that

f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]

The inequality follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t in domain(f) and t is a positive real. Since the above framework includes the popular Evidence Lower BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework "Evidence Divergence Bound Optimization" (EDBO).
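The bound can be verified on a small discrete example with f(u) = -log(u). The joint probabilities and the surrogate q below are made-up numbers for this sketch, and the names (joint, evidence, divergence) are hypothetical.

```python
import math

# p(x, z) for z in {0, 1, 2} at one fixed observation x (made-up values).
joint = [0.10, 0.25, 0.05]
evidence = sum(joint)  # p(x), obtained by summing out z

# An arbitrary surrogate q(z | x) over the same three states.
q = [0.5, 0.3, 0.2]


def f(u):
    return -math.log(u)


def divergence(surrogate):
    """D_f[p(x, Z), q(Z | x)] = E_q[f(p(x, Z) / q(Z | x))]."""
    return sum(qz * f(pxz / qz) for pxz, qz in zip(joint, surrogate))


lhs = f(evidence)    # f(p(x)), i.e., the negative log evidence
rhs = divergence(q)  # the upper bound, i.e., the negative ELBO

# The bound is tight when the surrogate equals the exact posterior p(z | x).
posterior = [pxz / evidence for pxz in joint]
tight = divergence(posterior)
```

With the exact posterior as surrogate, every ratio p(x, z) / q(z | x) equals p(x), so the divergence collapses to f(p(x)) and the gap vanishes.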

References:

[2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

[3]: Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[4]: Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016. https://arxiv.org/abs/1509.00519