tfp.vi.fit_surrogate_posterior_stateless

Fit a surrogate posterior to a target (unnormalized) log density.

The default behavior constructs and minimizes the negative variational evidence lower bound (ELBO), given by

q_samples = surrogate_posterior.sample(num_draws)
elbo_loss = -tf.reduce_mean(
  target_log_prob_fn(q_samples) - surrogate_posterior.log_prob(q_samples))

This corresponds to minimizing the 'reverse' Kullback-Leibler divergence (KL[q||p]) between the variational distribution and the unnormalized target_log_prob_fn, and defines a lower bound on the marginal log likelihood, log p(x) >= -elbo_loss. [1]
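To make the bound concrete, here is a library-free sketch in plain Python (not the TFP implementation) that estimates the ELBO by Monte Carlo. It assumes a toy model z ~ N(0, 1), x ~ N(z, 1) with x = 5 observed; the helper names `normal_log_prob`, `target_log_prob`, and `elbo_estimate` are hypothetical. For this model log p(x) has a closed form, so the inequality can be checked directly.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    """Log density of Normal(loc, scale) at `value`."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def target_log_prob(z, x=5.0):
    """Unnormalized posterior log density: log p(z) + log p(x | z)."""
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)


def elbo_estimate(q_loc, q_scale, num_draws, rng):
    """Monte Carlo estimate of E_q[target_log_prob(z) - log q(z)]."""
    total = 0.0
    for _ in range(num_draws):
        z = rng.gauss(q_loc, q_scale)
        total += target_log_prob(z) - normal_log_prob(z, q_loc, q_scale)
    return total / num_draws


rng = random.Random(0)
# Marginally x ~ N(0, sqrt(2)), so log p(x = 5) is known exactly.
log_marginal = normal_log_prob(5.0, 0.0, math.sqrt(2.0))
# A mismatched surrogate gives a loose bound.
elbo_rough = elbo_estimate(q_loc=1.0, q_scale=1.0, num_draws=20000, rng=rng)
# The exact posterior N(5/2, 1/sqrt(2)) attains the bound.
elbo_exact = elbo_estimate(2.5, 1.0 / math.sqrt(2.0), 1000, rng)
print(log_marginal, elbo_rough, elbo_exact)
```

With the exact posterior as surrogate, every draw contributes exactly log p(x), so the estimate has no Monte Carlo noise; any other surrogate falls below log p(x) by KL[q||p] in expectation.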

More generally, this function supports fitting variational distributions that minimize any Csiszar f-divergence.

Args

target_log_prob_fn Python callable that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample = surrogate_posterior.sample(sample_size), this will be called as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size].
build_surrogate_posterior_fn Python callable that takes parameter values and returns an instance of tfd.Distribution.
initial_parameters List or tuple of initial parameter values (Tensors or structures of Tensors), passed as positional arguments to build_surrogate_posterior_fn.
optimizer Pure functional optimizer to use. This may be an optax.GradientTransformation instance (in JAX), or any similar object that implements methods optimizer_state = optimizer.init(parameters) and updates, optimizer_state = optimizer.update(grads, optimizer_state, parameters).
num_steps Python int number of steps to run the optimizer.
convergence_criterion Optional instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion representing a criterion for detecting convergence. If None, the optimization will run for num_steps steps; otherwise, it will run for at most num_steps steps, stopping early if the provided criterion detects convergence. Default value: None.
trace_fn Python callable with signature traced_values = trace_fn(traceable_quantities), where the argument is an instance of tfp.math.MinimizeTraceableQuantities and the returned traced_values may be a Tensor or nested structure of Tensors. The traced values are stacked across steps and returned. The default trace_fn simply returns the loss. In general, trace functions may also examine the gradients, the parameter values, and the state propagated by the specified convergence_criterion (this will be None if no convergence criterion is specified). Default value: lambda traceable_quantities: traceable_quantities.loss.
discrepancy_fn Python callable representing a Csiszar f function in log-space. See the docs for tfp.vi.monte_carlo_variational_loss for examples. Default value: tfp.vi.kl_reverse.
sample_size Python int number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.
importance_sample_size Python int number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case, posterior expectations should be approximated by importance sampling, as demonstrated in the example below. Default value: 1.
gradient_estimator Optional element from tfp.vi.GradientEstimators specifying the stochastic gradient estimator to associate with the variational loss. Default value: csiszar_divergence.GradientEstimators.REPARAMETERIZATION.
jit_compile If True, compiles the loss function and gradient update using XLA. XLA performs compiler optimizations, such as fusion, and attempts to emit more efficient code. This may drastically improve performance. See the docs for tf.function. (In JAX, this will apply jax.jit.) Default value: False.
seed PRNG seed; see tfp.random.sanitize_seed for details.
name Python str name prefixed to ops created by this function. Default value: 'fit_surrogate_posterior'.
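The calling convention described for target_log_prob_fn can be sketched in a few lines of plain Python. This only illustrates the documented rule (unpack lists and tuples positionally, dictionaries by keyword, pass anything else as a single argument); it is not TFP's internal helper, and the function name `call_target_log_prob_fn` and the toy targets are hypothetical.

```python
def call_target_log_prob_fn(target_log_prob_fn, q_sample):
    """Call the target with a sample, unpacking structured samples."""
    if isinstance(q_sample, (list, tuple)):
        return target_log_prob_fn(*q_sample)   # Positional arguments.
    if isinstance(q_sample, dict):
        return target_log_prob_fn(**q_sample)  # Keyword arguments.
    return target_log_prob_fn(q_sample)        # Single Tensor-like value.


# Toy targets; the arithmetic stands in for a real log density.
def two_latents(z, w):
    return -(z ** 2) - (w ** 2)


def one_latent(z):
    return -(z ** 2)


from_tuple = call_target_log_prob_fn(two_latents, (1.0, 2.0))
from_dict = call_target_log_prob_fn(two_latents, {'w': 2.0, 'z': 1.0})
from_scalar = call_target_log_prob_fn(one_latent, 3.0)
print(from_tuple, from_dict, from_scalar)
```

In practice this means the argument names (for dictionary-valued samples) or argument order (for list- or tuple-valued samples) of target_log_prob_fn should match the structure the surrogate posterior produces.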

Returns

optimized_parameters Tuple of optimized parameter values, with the same structure and Tensor shapes as initial_parameters.
results Tensor or nested structure of Tensors, according to the return type of trace_fn. Each Tensor has an added leading dimension of size num_steps, packing the trajectory of the result over the course of the optimization.
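The optimizer protocol and the two return values can be illustrated with a minimal stateless loop in plain Python. This is a sketch under several assumptions, not the library's implementation: `PlainSGD`, `Traceable`, and `fit_stateless_sketch` are hypothetical names, gradients come from a hand-derived closed form rather than automatic differentiation, and updates are added to the parameters (the convention used by optax). The loss is the exact negative ELBO for a Normal surrogate on the toy model z ~ N(0, 1), x ~ N(z, 1), x = 5.

```python
import collections
import math

# Hypothetical stand-in for tfp.math.MinimizeTraceableQuantities.
Traceable = collections.namedtuple('Traceable', ['loss', 'parameters'])


class PlainSGD:
    """Toy optimizer exposing the functional `init` / `update` protocol."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def init(self, parameters):
        return 0  # Optimizer state: number of updates applied so far.

    def update(self, grads, optimizer_state, parameters=None):
        updates = tuple(-self.learning_rate * g for g in grads)
        return updates, optimizer_state + 1


def negative_elbo_and_grads(loc, log_scale):
    """Closed-form negative ELBO for q = N(loc, exp(log_scale)).

    The posterior is N(5/2, 1/sqrt(2)), and the negative ELBO equals
    KL[q || posterior] - log p(x = 5).
    """
    log_marginal = -0.5 * math.log(4.0 * math.pi) - 6.25
    kl = (-0.5 * math.log(2.0) - log_scale
          + math.exp(2.0 * log_scale) + (loc - 2.5) ** 2 - 0.5)
    grads = (2.0 * (loc - 2.5), 2.0 * math.exp(2.0 * log_scale) - 1.0)
    return kl - log_marginal, grads


def fit_stateless_sketch(loss_and_grads_fn, initial_parameters, optimizer,
                         num_steps, trace_fn=lambda t: t.loss):
    parameters = tuple(initial_parameters)
    optimizer_state = optimizer.init(parameters)
    trace = []
    for _ in range(num_steps):
        loss, grads = loss_and_grads_fn(*parameters)
        trace.append(trace_fn(Traceable(loss, parameters)))
        updates, optimizer_state = optimizer.update(
            grads, optimizer_state, parameters)
        # Assumed convention (as in optax): updates are added to parameters.
        parameters = tuple(p + u for p, u in zip(parameters, updates))
    return parameters, trace


optimized_parameters, losses = fit_stateless_sketch(
    negative_elbo_and_grads,
    initial_parameters=(0.0, 0.0),
    optimizer=PlainSGD(learning_rate=0.1),
    num_steps=200)
loc, scale = optimized_parameters[0], math.exp(optimized_parameters[1])
print(loc, scale, len(losses), losses[-1])
```

The returned parameters keep the structure of initial_parameters, and the trace has one entry per step, mirroring the leading num_steps dimension described above.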

Examples

Normal-Normal model. We'll first consider a simple model z ~ N(0, 1), x ~ N(z, 1), where we suppose we are interested in the posterior p(z | x=5):

import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

def log_prob(z, x):
  return tfd.Normal(0., 1.).log_prob(z) + tfd.Normal(z, 1.).log_prob(x)
conditioned_log_prob = lambda z: log_prob(z, x=5.)

The posterior is itself normal by conjugacy, and can be computed analytically (it's N(loc=5/2., scale=1/sqrt(2))). But suppose we don't want to bother doing the math: we can use variational inference instead!

import optax  # Requires the JAX backend (tensorflow_probability.substrates.jax).
init_normal, build_normal = tfp.experimental.util.make_trainable_stateless(
  tfd.Normal, name='q_z')
optimized_parameters, losses = tfp.vi.fit_surrogate_posterior_stateless(
    conditioned_log_prob,
    build_surrogate_posterior_fn=build_normal,
    initial_parameters=init_normal(seed=(42, 42)),
    optimizer=optax.adam(learning_rate=0.1),
    num_steps=100,
    seed=(42, 42))
q_z = build_normal(*optimized_parameters)
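As a sanity check on the closed-form posterior quoted above, the conjugate Normal-Normal update can be written out directly: precisions add, and the posterior mean is the precision-weighted average of the prior mean and the observation. The helper name `normal_normal_posterior` is hypothetical, and the sketch uses only the Python standard library.

```python
import math


def normal_normal_posterior(prior_loc, prior_scale, likelihood_scale, x):
    """Posterior of z for z ~ N(prior_loc, prior_scale), x ~ N(z, likelihood_scale)."""
    prior_precision = 1.0 / prior_scale ** 2
    likelihood_precision = 1.0 / likelihood_scale ** 2
    posterior_precision = prior_precision + likelihood_precision
    posterior_loc = (prior_precision * prior_loc
                     + likelihood_precision * x) / posterior_precision
    return posterior_loc, 1.0 / math.sqrt(posterior_precision)


posterior_loc, posterior_scale = normal_normal_posterior(0.0, 1.0, 1.0, x=5.0)
print(posterior_loc, posterior_scale)  # loc = 2.5, scale is about 0.7071
```

Up to optimization and Monte Carlo noise, the fitted q_z.mean() and q_z.stddev() should land near these values.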

Custom loss function. Suppose we prefer to fit the same model using the forward KL divergence KL[p||q]. We can pass a custom discrepancy function:

optimized_parameters, losses = tfp.vi.fit_surrogate_posterior_stateless(
    conditioned_log_prob,
    build_surrogate_posterior_fn=build_normal,
    initial_parameters=init_normal(seed=(42, 42)),
    optimizer=optax.adam(learning_rate=0.1),
    num_steps=100,
    seed=(42, 42),
    discrepancy_fn=tfp.vi.kl_forward)
q_z = build_normal(*optimized_parameters)

Note that in practice this may have substantially higher-variance gradients than the reverse KL.
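To illustrate what a discrepancy function in log-space looks like, the sketch below writes the two KL variants as functions of logu = log(p(z) / q(z)) and estimates E_q[f(u)] by Monte Carlo for a pair of univariate normals. It assumes the non-self-normalized forms f(u) = -log(u) for reverse KL and f(u) = u log(u) for forward KL; the function names are hypothetical stand-ins written in plain Python, not the TFP functions themselves.

```python
import math
import random


def kl_reverse_log_space(logu):
    """f(u) = -log(u), evaluated at logu = log(u)."""
    return -logu


def kl_forward_log_space(logu):
    """f(u) = u * log(u), evaluated at logu = log(u)."""
    return math.exp(logu) * logu


def normal_log_prob(value, loc, scale):
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def monte_carlo_divergence(f_log_space, p, q, num_draws, rng):
    """Estimate E_q[f(p(z) / q(z))] with draws z ~ q."""
    (p_loc, p_scale), (q_loc, q_scale) = p, q
    total = 0.0
    for _ in range(num_draws):
        z = rng.gauss(q_loc, q_scale)
        logu = (normal_log_prob(z, p_loc, p_scale)
                - normal_log_prob(z, q_loc, q_scale))
        total += f_log_space(logu)
    return total / num_draws


def normal_kl(a, b):
    """Closed-form KL[a || b] between two univariate normals."""
    (a_loc, a_scale), (b_loc, b_scale) = a, b
    return (math.log(b_scale / a_scale)
            + (a_scale ** 2 + (a_loc - b_loc) ** 2) / (2.0 * b_scale ** 2)
            - 0.5)


rng = random.Random(0)
p, q = (0.0, 1.0), (0.5, 1.2)  # (loc, scale) of target and surrogate.
reverse_estimate = monte_carlo_divergence(
    kl_reverse_log_space, p, q, 100000, rng)
forward_estimate = monte_carlo_divergence(
    kl_forward_log_space, p, q, 100000, rng)
print(reverse_estimate, normal_kl(q, p))  # Both estimate KL[q || p].
print(forward_estimate, normal_kl(p, q))  # Both estimate KL[p || q].
```

The forward form multiplies each term by u = p(z) / q(z), which can become very large wherever the surrogate puts too little mass; this is one plausible source of the higher gradient variance mentioned above. The surrogate here is deliberately wider than the target so that u stays bounded.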

Importance weighting. A surrogate posterior may be corrected by interpreting it as a proposal for an importance sampler. That is, one can use weighted samples from the surrogate to estimate expectations under the true posterior:

zs, q_log_prob = surrogate_posterior.experimental_sample_and_log_prob(
  num_samples, seed=(42, 42))

# Naive expectation under the surrogate posterior.
expected_x = tf.reduce_mean(f(zs), axis=0)

# Importance-weighted estimate of the expectation under the true posterior.
self_normalized_log_weights = tf.nn.log_softmax(
  target_log_prob_fn(zs) - q_log_prob)
expected_x = tf.reduce_sum(
  tf.exp(self_normalized_log_weights) * f(zs),
  axis=0)
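The snippet above leaves surrogate_posterior, target_log_prob_fn, and f unspecified. A self-contained counterpart in plain Python is sketched below for the toy model z ~ N(0, 1), x ~ N(z, 1), x = 5, using a deliberately offset and overdispersed Normal proposal (an arbitrary choice made here for illustration). The helper names are hypothetical.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def target_log_prob(z, x=5.0):
    """Unnormalized posterior log density: log p(z) + log p(x | z)."""
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)


def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    shift = max(values)
    log_norm = shift + math.log(sum(math.exp(v - shift) for v in values))
    return [v - log_norm for v in values]


rng = random.Random(0)
q_loc, q_scale = 2.0, 1.5  # Proposal: offset from and wider than the posterior.
zs = [rng.gauss(q_loc, q_scale) for _ in range(50000)]
self_normalized_log_weights = log_softmax(
    [target_log_prob(z) - normal_log_prob(z, q_loc, q_scale) for z in zs])
weights = [math.exp(lw) for lw in self_normalized_log_weights]

# Naive estimate: reflects the proposal, not the posterior.
naive_mean = sum(zs) / len(zs)
# Importance-weighted estimates of the posterior mean and variance.
weighted_mean = sum(w * z for w, z in zip(weights, zs))
weighted_variance = sum(
    w * (z - weighted_mean) ** 2 for w, z in zip(weights, zs))
print(naive_mean, weighted_mean, weighted_variance)
```

The naive mean stays near the proposal's location of 2, while the weighted estimates should approach the true posterior mean 5/2 and variance 1/2.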

Any distribution may be used as a proposal, but it is often natural to consider surrogates that were themselves fit by optimizing an importance-weighted variational objective [2], which directly optimizes the surrogate's effectiveness as a proposal distribution. This may be specified by passing importance_sample_size > 1. The importance-weighted objective may favor different characteristics than the original objective. For example, effective proposals are generally overdispersed, whereas a surrogate optimizing reverse KL would otherwise tend to be underdispersed.

Although importance sampling is guaranteed to tighten the variational bound, some research has found that this does not necessarily improve the quality of deep generative models, because it also introduces gradient noise that can lead to a weaker training signal [3]. As always, evaluation is important to choose the approach that works best for a particular task.
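The tightening of the bound can be checked numerically. The sketch below, in plain Python with hypothetical helper names, averages the importance-weighted bound log((1/K) sum_k w_k) from [2] over many replicates for K = 1 (the ordinary ELBO) and K = 10, using the toy model z ~ N(0, 1), x ~ N(z, 1), x = 5 and a fixed, mismatched surrogate N(1, 1) chosen here only for illustration.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def target_log_prob(z, x=5.0):
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    shift = max(values)
    return shift + math.log(
        sum(math.exp(v - shift) for v in values) / len(values))


def importance_weighted_bound(q_loc, q_scale, importance_sample_size,
                              num_replicates, rng):
    """Average of log-mean-exp of K log importance weights."""
    total = 0.0
    for _ in range(num_replicates):
        log_weights = []
        for _ in range(importance_sample_size):
            z = rng.gauss(q_loc, q_scale)
            log_weights.append(
                target_log_prob(z) - normal_log_prob(z, q_loc, q_scale))
        total += log_mean_exp(log_weights)
    return total / num_replicates


rng = random.Random(0)
log_marginal = normal_log_prob(5.0, 0.0, math.sqrt(2.0))
bound_k1 = importance_weighted_bound(1.0, 1.0, 1, 4000, rng)
bound_k10 = importance_weighted_bound(1.0, 1.0, 10, 4000, rng)
print(bound_k1, bound_k10, log_marginal)
```

The K = 10 bound sits between the ELBO and log p(x). The sketch says nothing about gradient noise, which is the concern raised in [3].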

When using an importance-weighted loss to fit a surrogate, it is also recommended to apply importance sampling when computing expectations under that surrogate.

# Fit `q` with an importance-weighted variational loss.
optimized_parameters, losses = tfp.vi.fit_surrogate_posterior_stateless(
    conditioned_log_prob,
    build_surrogate_posterior_fn=build_normal,
    initial_parameters=init_normal(seed=(42, 42)),
    importance_sample_size=10,
    optimizer=optax.adam(0.1),
    num_steps=200,
    seed=(42, 42))
q_z = build_normal(*optimized_parameters)

# Estimate posterior statistics with importance sampling.
zs, q_log_prob = q_z.experimental_sample_and_log_prob(1000, seed=(42, 42))
self_normalized_log_weights = tf.nn.log_softmax(
  conditioned_log_prob(zs) - q_log_prob)
posterior_mean = tf.reduce_sum(
  tf.exp(self_normalized_log_weights) * zs,
  axis=0)
posterior_variance = tf.reduce_sum(
  tf.exp(self_normalized_log_weights) * (zs - posterior_mean)**2,
  axis=0)

Inhomogeneous Poisson Process. For a more interesting example, let's consider a model with multiple latent variables as well as trainable parameters in the model itself. Given observed counts y from spatial locations X, consider an inhomogeneous Poisson process model log_rates = GaussianProcess(index_points=X); y = Poisson(exp(log_rates)) in which the latent (log) rates are spatially correlated following a Gaussian process:

import numpy as np

# Toy 1D data.
index_points = np.array([-10., -7.2, -4., -0.1, 0.1, 4., 6.2, 9.]).reshape(
    [-1, 1]).astype(np.float32)
observed_counts = np.array(
    [100, 90, 60, 13, 18, 37, 55, 42]).astype(np.float32)

# Generative model.
def model_fn():
  kernel_amplitude = yield tfd.LogNormal(
      loc=0., scale=1., name='kernel_amplitude')
  kernel_lengthscale = yield tfd.LogNormal(
      loc=0., scale=1., name='kernel_lengthscale')
  observation_noise_scale = yield tfd.LogNormal(
      loc=0., scale=1., name='observation_noise_scale')
  kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
      amplitude=kernel_amplitude,
      length_scale=kernel_lengthscale)
  latent_log_rates = yield tfd.GaussianProcess(
      kernel,
      index_points=index_points,
      observation_noise_variance=observation_noise_scale**2,
      name='latent_log_rates')
  y = yield tfd.Independent(tfd.Poisson(log_rate=latent_log_rates),
                            reinterpreted_batch_ndims=1,
                            name='y')
model = tfd.JointDistributionCoroutineAutoBatched(model_fn)
pinned = model.experimental_pin(y=observed_counts)

Next we define a variational family. This is represented statelessly as a build_surrogate_posterior_fn from raw (unconstrained) parameters to a surrogate posterior distribution. Note that common variational families can be constructed automatically using the utilities in tfp.experimental.vi; here we demonstrate a manual approach.


initial_parameters = (0., 0., 0.,  # Raw kernel parameters.
                      tf.zeros_like(observed_counts),  # `logit_locs`
                      tf.zeros_like(observed_counts))  # `logit_raw_scales`

def build_surrogate_posterior_fn(
  raw_kernel_amplitude, raw_kernel_lengthscale, raw_observation_noise_scale,
  logit_locs, logit_raw_scales):

  def variational_model_fn():
    # Fit the kernel parameters as point masses.
    yield tfd.Deterministic(
        tf.nn.softplus(raw_kernel_amplitude), name='kernel_amplitude')
    yield tfd.Deterministic(
        tf.nn.softplus(raw_kernel_lengthscale), name='kernel_lengthscale')
    yield tfd.Deterministic(
        tf.nn.softplus(raw_observation_noise_scale),
        name='observation_noise_scale')
    # Factored normal posterior over the GP logits.
    yield tfd.Independent(
        tfd.Normal(loc=logit_locs,
                   scale=tf.nn.softplus(logit_raw_scales)),
        reinterpreted_batch_ndims=1,
        name='latent_log_rates')
  return tfd.JointDistributionCoroutineAutoBatched(variational_model_fn)
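The surrogate above keeps its parameters unconstrained and passes them through softplus to obtain positive amplitudes and scales. A small standard-library sketch of softplus and its inverse is given below; the inverse is handy if you prefer to choose an initial positive value and derive the raw parameter from it. The function names are written here for illustration and are not the TFP functions.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x)); maps reals to positive reals."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y):
    """Inverse of softplus for y > 0: returns x with softplus(x) == y."""
    return y + math.log(-math.expm1(-y))


# A raw value of 0, as in `initial_parameters` above, maps to log(2).
default_scale = softplus(0.0)
# Pick the raw value that corresponds to an initial scale of 0.1.
raw_scale = softplus_inverse(0.1)
print(default_scale, raw_scale, softplus(raw_scale))
```

With the zero initialization used in the example, every constrained kernel parameter and every surrogate scale therefore starts at log(2), roughly 0.693.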

Finally, we fit the variational posterior and model variables jointly. We'll use a custom trace_fn to see how the kernel amplitudes and a set of sampled latent rates with fixed seed evolve during the course of the optimization:

[
    optimized_parameters,
    (losses, amplitude_path, sample_path)
] = tfp.vi.fit_surrogate_posterior_stateless(
    target_log_prob_fn=pinned.unnormalized_log_prob,
    build_surrogate_posterior_fn=build_surrogate_posterior_fn,
    initial_parameters=initial_parameters,
    optimizer=optax.adam(learning_rate=0.1),
    sample_size=1,
    num_steps=500,
    trace_fn=lambda traceable_quantities: (
        traceable_quantities.loss,
        tf.nn.softplus(traceable_quantities.parameters[0]),
        build_surrogate_posterior_fn(
            *traceable_quantities.parameters).sample(
                5, seed=(42, 42))[-1]),
    seed=(42, 42))
surrogate_posterior = build_surrogate_posterior_fn(*optimized_parameters)

References

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[2] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016. https://arxiv.org/abs/1509.00519

[3] Tom Rainforth, Adam R. Kosiorek, Tuan Anh Le, Chris J. Maddison, Maximilian Igl, Frank Wood, and Yee Whye Teh. Tighter Variational Bounds are Not Necessarily Better. In International Conference on Machine Learning (ICML), 2018. https://arxiv.org/abs/1802.04537