Monte-Carlo approximation of an f-Divergence variational loss.
Main aliases
tfp.experimental.substrates.jax.vi.monte_carlo_variational_loss
tfp.substrates.jax.vi.monte_carlo_variational_loss(
target_log_prob_fn,
surrogate_posterior,
sample_size=1,
importance_sample_size=1,
discrepancy_fn=tfp.substrates.jax.vi.kl_reverse,
use_reparameterization=None,
gradient_estimator=None,
stopped_surrogate_posterior=None,
seed=None,
name=None
)
Variational losses measure the divergence between an unnormalized target
distribution p (provided via target_log_prob_fn) and a surrogate
distribution q (provided as surrogate_posterior). When the target
distribution is an unnormalized posterior from conditioning a model on data,
minimizing the loss with respect to the parameters of surrogate_posterior
performs approximate posterior inference.
This function defines losses of the form E_q[discrepancy_fn(log(u))], where
u = p(z) / q(z) in the (default) case where importance_sample_size == 1, and
u = mean([p(z[k]) / q(z[k]) for k in range(importance_sample_size)]) more
generally. These losses are sometimes known as f-divergences [1, 2].
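The loss form above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not TFP's implementation: it assumes a univariate normal surrogate, and the helper names (normal_log_prob, log_mean_exp, mc_variational_loss) and the toy target are hypothetical. The importance-weighted mean is computed in log-space for numerical stability.

```python
import math
import random


def normal_log_prob(z, loc, scale):
    """Log-density of a univariate normal distribution."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def log_mean_exp(xs):
    """Numerically stable log(mean(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def kl_reverse(logu):
    """Hand-written stand-in for the default discrepancy, f(u) = -log(u)."""
    return -logu


def mc_variational_loss(target_log_prob_fn, q_loc, q_scale, discrepancy_fn,
                        sample_size, importance_sample_size, rng):
    """Averages discrepancy_fn(log(u)) over `sample_size` independent draws of u.

    Hypothetical helper: the surrogate q is assumed to be Normal(q_loc, q_scale).
    """
    total = 0.0
    for _sample in range(sample_size):
        # Each u is the mean of `importance_sample_size` weights p(z) / q(z).
        log_weights = []
        for _term in range(importance_sample_size):
            z = rng.gauss(q_loc, q_scale)
            log_weights.append(
                target_log_prob_fn(z) - normal_log_prob(z, q_loc, q_scale))
        total += discrepancy_fn(log_mean_exp(log_weights))
    return total / sample_size


def target_log_prob_fn(z):
    """Toy unnormalized target: 3 * Normal(1, 0.5), so its log-normalizer is log(3)."""
    return math.log(3.0) + normal_log_prob(z, 1.0, 0.5)


rng = random.Random(0)
loss_k1 = mc_variational_loss(
    target_log_prob_fn, 0.0, 1.0, kl_reverse,
    sample_size=2000, importance_sample_size=1, rng=rng)
loss_k10 = mc_variational_loss(
    target_log_prob_fn, 0.0, 1.0, kl_reverse,
    sample_size=2000, importance_sample_size=10, rng=rng)
print(loss_k1, loss_k10)
```

In this toy setup, loss_k1 estimates KL[q||p] minus the log-normalizer, and loss_k10 should fall between -log(3) and loss_k1 up to Monte Carlo noise.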
The default behavior (discrepancy_fn == tfp.vi.kl_reverse, where
tfp.vi.kl_reverse = lambda logu: -logu, and importance_sample_size == 1)
computes an unbiased estimate of the negative of the standard evidence lower
bound (ELBO) [3], so that minimizing the loss maximizes the ELBO. The bound
may be tightened by setting importance_sample_size > 1 [4], and the variance
of the estimate reduced by setting sample_size > 1. Other discrepancies of
interest available under tfp.vi include the forward KL[p||q], total variation
distance, Amari alpha-divergences, and more.
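The relationship between the default loss, the importance-weighted loss, and the normalizer of the target can be written out as a chain of bounds. The following is a worked restatement, assuming the default discrepancy f(u) = -log(u); the symbols \mathcal{Z} (the normalizing constant of p) and K (standing for importance_sample_size) are notation introduced here, and the middle inequality is the importance-weighted bound of [4].

```latex
\begin{aligned}
-\log \mathcal{Z}
  &= -\log \mathbb{E}_{q(z)}\!\left[\frac{p(z)}{q(z)}\right],
  \qquad \mathcal{Z} := \int p(z)\,dz \\
  &\le \mathbb{E}_{z_1,\dots,z_K \sim q}\!\left[
       -\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(z_k)}{q(z_k)}\right] \\
  &\le \mathbb{E}_{q(z)}\!\left[-\log \frac{p(z)}{q(z)}\right]
   = -\mathrm{ELBO}(q).
\end{aligned}
```

Both inequalities follow from Jensen's inequality applied to the convex function -log; the middle line coincides with the last line when K = 1.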
Args | |
---|---|
target_log_prob_fn | Python callable that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample = surrogate_posterior.sample(sample_size), this will be called as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size]. |
surrogate_posterior | A tfp.distributions.Distribution instance defining a variational posterior (could be a tfd.JointDistribution). If using tf.Variable parameters, the distribution's log_prob and (if reparameterizable) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp.util.TransformedVariable and/or tfp.util.DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation. |
sample_size | Integer scalar number of Monte Carlo samples used to approximate the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1. |
importance_sample_size | Python int number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case it often makes sense to use importance sampling to approximate posterior expectations (see tfp.vi.fit_surrogate_posterior for an example). Default value: 1. |
discrepancy_fn | Python callable representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: tfp.vi.kl_reverse. |
use_reparameterization | Deprecated; use gradient_estimator instead. |
gradient_estimator | Optional element from tfp.vi.GradientEstimators specifying the stochastic gradient estimator to associate with the variational loss. If None, a default estimator (either score-function or reparameterization) is chosen based on surrogate_posterior.reparameterization_type. Default value: None. |
stopped_surrogate_posterior | Optional copy of surrogate_posterior with stopped gradients to the parameters, e.g., tfd.Normal(loc=tf.stop_gradient(loc), scale=tf.stop_gradient(scale)). Required if and only if gradient_estimator == tfp.vi.GradientEstimators.DOUBLY_REPARAMETERIZED. Default value: None. |
seed | PRNG seed for surrogate_posterior.sample; see tfp.random.sanitize_seed for details. |
name | Python str name prefixed to Ops created by this function. |
Returns | |
---|---|
monte_carlo_variational_loss | float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence. |
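The calling convention described for target_log_prob_fn (positional unpacking for lists and tuples, keyword unpacking for dictionaries, a single argument otherwise) can be illustrated with a small dispatch helper. This is a sketch only: call_target_log_prob and toy_log_prob are hypothetical names, plain Python floats stand in for Tensors, and the code is not taken from the library.

```python
def call_target_log_prob(target_log_prob_fn, q_sample):
    """Hypothetical helper mirroring the documented calling convention."""
    if isinstance(q_sample, dict):
        # Dictionary samples are passed as keyword arguments.
        return target_log_prob_fn(**q_sample)
    if isinstance(q_sample, (list, tuple)):
        # List or tuple samples are passed as positional arguments.
        return target_log_prob_fn(*q_sample)
    # Anything else is passed through as a single argument.
    return target_log_prob_fn(q_sample)


def toy_log_prob(loc, scale):
    """Toy unnormalized log-density over two latent variables."""
    return -0.5 * loc ** 2 - 0.5 * (scale - 1.0) ** 2


as_tuple = call_target_log_prob(toy_log_prob, (0.5, 2.0))
as_dict = call_target_log_prob(toy_log_prob, {"loc": 0.5, "scale": 2.0})
as_single = call_target_log_prob(lambda z: -0.5 * z ** 2, 0.5)
print(as_tuple, as_dict, as_single)
```

With a dictionary sample, the argument names of target_log_prob_fn must match the dictionary keys, which is why toy_log_prob names its parameters loc and scale.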
Csiszar f-divergences
A Csiszar function f is a convex function from R^+ (the positive reals) to R.
The Csiszar f-Divergence is given by:
D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                           where x_j ~iid q(X)
For example, f = lambda u: -log(u) recovers KL[q||p], while
f = lambda u: u * log(u) recovers the forward KL[p||q]. These and other
functions are available in tfp.vi.
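The Monte Carlo form of the f-Divergence can be checked numerically against closed-form KL divergences. The sketch below assumes two univariate normals chosen purely for illustration; csiszar_f_divergence and normal_kl are hypothetical helpers written here, not functions from tfp.vi.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate normal distribution."""
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def csiszar_f_divergence(f, p_params, q_params, num_samples, rng):
    """Estimates E_q[f(p(X) / q(X))] from samples x_j drawn iid from q."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(*q_params)
        u = math.exp(normal_log_prob(x, *p_params)
                     - normal_log_prob(x, *q_params))
        total += f(u)
    return total / num_samples


def normal_kl(a, b):
    """Closed-form KL[a || b] for univariate normals given as (loc, scale)."""
    (loc_a, scale_a), (loc_b, scale_b) = a, b
    return (math.log(scale_b / scale_a)
            + (scale_a ** 2 + (loc_a - loc_b) ** 2) / (2.0 * scale_b ** 2)
            - 0.5)


p = (0.0, 1.0)  # (loc, scale) of the target, normalized in this toy case
q = (0.3, 1.2)  # (loc, scale) of the surrogate
rng = random.Random(0)

# f(u) = -log(u) targets KL[q||p]; f(u) = u * log(u) targets KL[p||q].
reverse_kl = csiszar_f_divergence(lambda u: -math.log(u), p, q, 20000, rng)
forward_kl = csiszar_f_divergence(lambda u: u * math.log(u), p, q, 20000, rng)
print(reverse_kl, normal_kl(q, p))
print(forward_kl, normal_kl(p, q))
```

Both estimates draw samples only from q, which is what makes the forward KL[p||q] expressible in this framework even though it is an expectation under p.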
Example Application:
The Csiszar f-Divergence is a useful framework for variational inference. To see why, observe that:
f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]
The inequality follows from the fact that the "perspective" of f, i.e.,
(s, t) |-> t f(s / t), is convex in (s, t) when s/t in domain(f) and t is a
positive real. Since the above framework includes the popular Evidence Lower
BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework
"Evidence Divergence Bound Optimization" (EDBO).
References:
[2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.
[3]: Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.
[4]: Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016. https://arxiv.org/abs/1509.00519