tfp.experimental.vi.build_factored_surrogate_posterior_stateless

Builds a joint variational posterior that factors over model variables.

By default, this method creates an independent trainable Normal distribution for each variable, transformed using a bijector (if provided) to match the support of that variable. This makes extremely strong assumptions about the posterior: that it is approximately normal (or transformed normal), and that all model variables are independent.
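The independence assumption means that the surrogate's joint log density is a plain sum of per-variable terms. The following is a minimal pure-Python sketch of that factorization using untransformed Normal components. It is not TFP code, and the helper names (normal_log_prob, mean_field_log_prob, mean_field_sample) are hypothetical.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mean_field_log_prob(values, locs, scales):
    """Joint log density of independent Normal components.

    Independence means the joint density factors into a product, so the joint
    log density is a sum of per-variable terms with no cross terms.
    """
    return sum(normal_log_prob(x, m, s) for x, m, s in zip(values, locs, scales))


def mean_field_sample(locs, scales, rng):
    """Draw one joint sample by sampling every variable on its own."""
    return [rng.gauss(m, s) for m, s in zip(locs, scales)]


locs, scales = [0.0, 1.0], [1.0, 0.5]
z = mean_field_sample(locs, scales, random.Random(0))
joint_lp = mean_field_log_prob(z, locs, scales)
```

In a trainable version, locs and scales would be the variational parameters. The factorization keeps the parameter count linear in the number of variables, but it also means the surrogate cannot represent correlations between variables.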

Args

event_shape: Tensor shape, or nested structure of Tensor shapes, specifying the event shape(s) of the posterior variables.
bijector: Optional tfb.Bijector instance, or nested structure of such instances, defining the support(s) of the posterior variables. The structure must match that of event_shape and may contain None values. A posterior variable is modeled as tfd.TransformedDistribution(underlying_dist, bijector) if a corresponding constraining bijector is specified; otherwise it is modeled as supported on the unconstrained real line.
batch_shape: The batch_shape of the output distribution. Default value: ().
base_distribution_cls: Subclass of tfd.Distribution that is instantiated and optionally transformed by the bijector to define the component distributions. May optionally be a structure of such subclasses matching event_shape. Default value: tfd.Normal.
initial_parameters: Optional dictionary mapping parameter names (str) to Tensor values, specifying initial values for some or all of the base distribution's trainable parameters, or a Python callable with signature value = parameter_init_fn(parameter_name, shape, dtype, seed, constraining_bijector), passed to tfp.experimental.util.make_trainable. May optionally be a structure of such dictionaries and/or callables matching event_shape. Dictionary entries that do not correspond to parameter names are ignored. Default value: {'scale': 1e-2} (ignored when base_distribution_cls does not have a scale parameter).
dtype: Optional float dtype for trainable parameters. May optionally be a structure of such dtypes matching event_shape. Default value: tf.float32.
validate_args: Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name: Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_factored_surrogate_posterior').
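When a constraining bijector is given, the component tfd.TransformedDistribution(underlying_dist, bijector) has a density given by the change-of-variables formula log p_Y(y) = log p_X(x) - log|dy/dx|, with x = bijector.inverse(y). The sketch below works this out by hand for a Softplus bijector applied to a Normal, in plain Python. The helper names are hypothetical and illustrate the formula rather than TFP's internal code.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y):
    # Inverse of softplus for y > 0: log(exp(y) - 1), rewritten for stability.
    return y + math.log(-math.expm1(-y))


def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def softplus_normal_log_prob(y, loc, scale):
    """Log density of Y = softplus(X) with X ~ Normal(loc, scale), for y > 0."""
    x = softplus_inverse(y)
    # dy/dx = sigmoid(x), and log(sigmoid(x)) = -softplus(-x).
    log_abs_det_jacobian = -softplus(-x)
    return normal_log_prob(x, loc, scale) - log_abs_det_jacobian


# Sanity check: the density should integrate to about 1 over (0, inf).
# A midpoint rule on (0, 10) captures essentially all of the mass here.
step = 1e-4
total = sum(math.exp(softplus_normal_log_prob((i + 0.5) * step, 0.0, 1.0)) * step
            for i in range(100_000))
```

Omitting the Jacobian term would give a function that does not integrate to 1, which is why the bijector's log-det-Jacobian is part of every transformed log_prob.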

Returns

init_fn: Python callable with signature initial_parameters = init_fn(seed).
apply_fn: Python callable with signature instance = apply_fn(*parameters).
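The stateless builder splits construction into two pure functions: one that creates initial parameter values from a seed, and one that turns explicit parameter values into a distribution instance. The toy sketch below mimics that init_fn/apply_fn contract in plain Python for a single Normal. ToyNormal and build_toy_normal_stateless are hypothetical, and the random initialization they use is an assumption rather than TFP's actual initializer.

```python
import math
import random


class ToyNormal:
    """Minimal stand-in for a distribution object (hypothetical)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


def build_toy_normal_stateless(initial_scale=1e-2):
    """Hypothetical builder returning an (init_fn, apply_fn) pair.

    The builder keeps no trainable state: parameters live entirely in the
    tuple returned by init_fn, and apply_fn turns a parameter tuple into a
    distribution instance.
    """

    def init_fn(seed):
        rng = random.Random(seed)
        # Hypothetical initialization: seeded random loc, fixed scale.
        return (rng.uniform(-2.0, 2.0), initial_scale)

    def apply_fn(*parameters):
        loc, scale = parameters
        return ToyNormal(loc, scale)

    return init_fn, apply_fn


init_fn, apply_fn = build_toy_normal_stateless()
params = init_fn(seed=42)
dist = apply_fn(*params)

# An optimizer would produce a new parameter tuple; the caller rebuilds the
# instance from it instead of mutating variables in place.
new_params = (params[0] + 0.1, params[1])
new_dist = apply_fn(*new_params)
```

Because the same seed always yields the same parameters and apply_fn has no side effects, this style fits functional optimizers that pass parameters explicitly.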

Examples

Consider a Gamma model with unknown parameters, expressed as a joint distribution:

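import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
Root = tfd.JointDistributionCoroutine.Root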
def model_fn():
  concentration = yield Root(tfd.Exponential(1.))
  rate = yield Root(tfd.Exponential(1.))
  y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                       sample_shape=4)
model = tfd.JointDistributionCoroutine(model_fn)

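Let's use variational inference to approximate the posterior over the data-generating parameters for some observed y. We'll build a surrogate posterior distribution by specifying the shapes of the latent concentration and rate parameters, and that both are constrained to be positive. (These examples use the stateful variant, tfp.experimental.vi.build_factored_surrogate_posterior, which returns a trainable distribution directly.)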

surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
  event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
  bijector=[tfb.Softplus(),   # Concentration is positive.
            tfb.Softplus()])  # Rate is positive.

This creates a trainable joint distribution, defined by the variables in surrogate_posterior.trainable_variables. We use tfp.vi.fit_surrogate_posterior to fit this distribution by minimizing a divergence (by default, the reverse KL divergence) from the surrogate to the true posterior.

y = [0.2, 0.5, 0.3, 0.7]
losses = tfp.vi.fit_surrogate_posterior(
  lambda concentration, rate: model.log_prob([concentration, rate, y]),
  surrogate_posterior=surrogate_posterior,
  num_steps=100,
  optimizer=tf.optimizers.Adam(0.1),
  sample_size=10)

# After optimization, samples from the surrogate will approximate
# samples from the true posterior.
samples = surrogate_posterior.sample(100)
posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
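With the default reverse-KL discrepancy, the loss being minimized is a Monte Carlo estimate of E_q[log q(z) - log p(z, y)], which is the negative evidence lower bound (ELBO). The plain-Python sketch below computes that estimate by hand for this Gamma model with a softplus-transformed Normal surrogate. It illustrates the objective only; it is not TFP's implementation, and the helper names are hypothetical.

```python
import math
import random

Y_OBS = [0.2, 0.5, 0.3, 0.7]


def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y):
    return y + math.log(-math.expm1(-y))


def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def model_log_prob(concentration, rate, y):
    # Exponential(1) priors: log p(x) = -x for x > 0.
    lp = -concentration - rate
    # Gamma(concentration, rate) log likelihood, summed over observations.
    for yi in y:
        lp += (concentration * math.log(rate) + (concentration - 1.0) * math.log(yi)
               - rate * yi - math.lgamma(concentration))
    return lp


def negative_elbo(locs, scales, y, num_samples, rng):
    """Monte Carlo estimate of E_q[log q(z) - log p(z, y)]."""
    total = 0.0
    for _ in range(num_samples):
        # Reparameterized draw: unconstrained Normal sample, then softplus.
        xs = [rng.gauss(m, s) for m, s in zip(locs, scales)]
        zs = [softplus(x) for x in xs]
        # Surrogate log density of zs via change of variables:
        # log q(z) = log N(x) - log sigmoid(x) = log N(x) + softplus(-x).
        log_q = sum(normal_log_prob(x, m, s) + softplus(-x)
                    for x, m, s in zip(xs, locs, scales))
        total += log_q - model_log_prob(zs[0], zs[1], y)
    return total / num_samples


# A surrogate centered near the posterior mean reported above versus one
# centered far away (concentration ~ 20, rate ~ 0.05).
loss_near = negative_elbo([softplus_inverse(1.1), softplus_inverse(2.1)],
                          [0.3, 0.3], Y_OBS, 500, random.Random(0))
loss_far = negative_elbo([softplus_inverse(20.0), softplus_inverse(0.05)],
                         [0.1, 0.1], Y_OBS, 500, random.Random(0))
```

The optimizer's job is to move the unconstrained locs and scales toward values that make this quantity small; the far-away surrogate assigns mass where the model's joint density is tiny and so has a much larger loss.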

If we want to initialize the optimization at a specific location, we can specify initial parameters when we build the surrogate posterior. Note that these parameterize the distribution(s) over unconstrained values, so we need to transform our desired constrained locations using the inverse of the constraining bijector(s).

initial_loc = {'concentration': 0.4, 'rate': 0.2}
surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
  event_shape=tf.nest.map_structure(tf.shape, initial_loc),
  bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
            'rate': tfb.Softplus()},           # Rate is positive.
  initial_parameters={
    'concentration': {'loc': tfb.Softplus().inverse(initial_loc['concentration']), 'scale': 1e-2},
    'rate': {'loc': tfb.Softplus().inverse(initial_loc['rate']), 'scale': 1e-2}})
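The arithmetic behind the inverse-bijector step can be checked in plain Python: softplus^{-1}(y) = log(exp(y) - 1), and mapping the result back through softplus recovers the desired constrained value. This sketch uses hypothetical helpers, not TFP. It also checks that, because softplus is monotone, softplus(loc) is the median of the transformed surrogate.

```python
import math
import random


def softplus(x):
    # Numerically stable log(1 + exp(x)); the forward map of a Softplus bijector.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y):
    # log(exp(y) - 1) for y > 0, written to avoid overflow for large y.
    return y + math.log(-math.expm1(-y))


initial_loc = {'concentration': 0.4, 'rate': 0.2}

# The 'loc' values passed to the builder live in unconstrained space.
unconstrained_loc = {k: softplus_inverse(v) for k, v in initial_loc.items()}

# Mapping back through softplus recovers the desired constrained locations.
recovered = {k: softplus(v) for k, v in unconstrained_loc.items()}

# Empirical median of softplus(Normal(loc, 1e-2)) for the concentration.
rng = random.Random(0)
draws = sorted(softplus(rng.gauss(unconstrained_loc['concentration'], 1e-2))
               for _ in range(10_001))
empirical_median = draws[5_000]
```

Both unconstrained locations come out negative, because the target values 0.4 and 0.2 lie below softplus(0) = log 2. With scale = 1e-2, the surrogate is tightly concentrated, so its mean is also close to the requested location.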