tfp.experimental.vi.build_asvi_surrogate_posterior

Builds a structured surrogate posterior inspired by conjugate updating.

ASVI, or Automatic Structured Variational Inference, was proposed by Ambrogioni et al. (2020) [1] as a method of automatically constructing a surrogate posterior with the same structure as the prior. It does this by reparameterizing the variational family of the surrogate posterior, structuring each parameter according to the equation

prior_weight * prior_parameter + (1 - prior_weight) * mean_field_parameter

In this equation, prior_parameter is a vector of prior parameters and mean_field_parameter is a vector of trainable parameters with the same domain as prior_parameter. prior_weight is a vector of trainable parameters with 0. <= prior_weight <= 1. When prior_weight = 0., the surrogate posterior is a mean-field surrogate, and when prior_weight = 1., the surrogate posterior is the prior. This convex combination, inspired by conjugacy in exponential families, allows the surrogate posterior to interpolate between the structure of the prior and the structure of a mean-field approximation.
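
As a concrete sketch of this parameterization for a single Normal location parameter (all names below are illustrative, not part of the TFP API):

import tensorflow as tf

prior_loc = tf.constant(0.)          # parameter of the prior
mean_field_loc = tf.Variable(1.5)    # trainable mean-field parameter
prior_weight = tf.Variable(0.5)      # trainable, constrained to [0., 1.]

# The surrogate's location parameter interpolates between the two:
surrogate_loc = (prior_weight * prior_loc
                 + (1. - prior_weight) * mean_field_loc)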

Args

prior: A tfd.JointDistribution instance of the prior.
mean_field: Optional Python boolean. If True, creates a degenerate surrogate distribution in which all variables are independent, ignoring the prior dependence structure. Default value: False.
initial_prior_weight: Optional float value (either static or tensor value) on the interval [0, 1]. A larger value creates an initial surrogate distribution with more dependence on the prior structure. Default value: 0.5.
name: Optional string. Default value: build_asvi_surrogate_posterior.

Returns

surrogate_posterior: A tfd.JointDistributionCoroutineAutoBatched instance whose samples have shape and structure matching that of prior.

Raises

TypeError: The prior argument cannot be a nested JointDistribution.
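
For instance, a joint distribution that contains another joint distribution as a component would be rejected. A hypothetical sketch (assuming tfd = tfp.distributions; the specific distributions are arbitrary):

inner = tfd.JointDistributionNamedAutoBatched(
    {'z': tfd.Normal(0., 1.)})
nested = tfd.JointDistributionNamedAutoBatched(
    {'inner': inner,
     'x': lambda inner: tfd.Normal(inner['z'], 1.)})
# tfp.experimental.vi.build_asvi_surrogate_posterior(nested)
# ==> raises TypeError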

Examples

Consider a Brownian motion model expressed as a JointDistribution:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

prior_loc = 0.
innovation_noise = .1

def model_fn():
  # Initial state of the Brownian motion.
  new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise)
  # Four subsequent states, each centered at the previous state.
  for _ in range(4):
    new = yield tfd.Normal(loc=new, scale=innovation_noise)

prior = tfd.JointDistributionCoroutineAutoBatched(model_fn)
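
To see the structure the surrogate will mirror, it can help to draw a sample from the prior (a quick check, not part of the original example):

# A single draw is a tuple-like structure with one entry per yield
# in model_fn: five scalar Tensors here.
x = prior.sample()
print(len(x))  # ==> 5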

Let's use variational inference to approximate the posterior. We'll build a surrogate posterior distribution by feeding in the prior distribution.

surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
    prior)

This creates a trainable joint distribution, defined by variables in surrogate_posterior.trainable_variables. We then use tfp.vi.fit_surrogate_posterior to fit this distribution by minimizing a divergence to the true posterior, given a function target_log_prob_fn that computes the unnormalized log density of the target posterior.
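
target_log_prob_fn is left undefined in this example. As a minimal runnable stand-in, one can fit the surrogate to the prior itself (a real problem would add a data likelihood to the prior log density):

# Stand-in target: the prior's own log density. Replace with the
# unnormalized posterior log density for actual inference.
target_log_prob_fn = prior.log_prob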

# Minimize the negative ELBO (the default variational loss).
losses = tfp.vi.fit_surrogate_posterior(
  target_log_prob_fn,
  surrogate_posterior=surrogate_posterior,
  num_steps=100,
  optimizer=tf.optimizers.Adam(learning_rate=0.1),
  sample_size=10)

# After optimization, samples from the surrogate will approximate
# samples from the true posterior.
samples = surrogate_posterior.sample(100)
posterior_mean = [tf.reduce_mean(x) for x in samples]
posterior_std = [tf.math.reduce_std(x) for x in samples]
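
Because the surrogate shares the prior's structure, surrogate samples can also be scored under the prior, which gives a quick structural sanity check (not part of the original example):

# Both distributions accept the same sample structure.
lp_surrogate = surrogate_posterior.log_prob(samples)  # shape [100]
lp_prior = prior.log_prob(samples)                    # shape [100]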

References

[1]: Luca Ambrogioni, Max Hinne, and Marcel van Gerven. Automatic structured variational inference. arXiv preprint arXiv:2002.00643, 2020. https://arxiv.org/abs/2002.00643