tfp.sts.build_factored_surrogate_posterior_stateless

Returns stateless functions for building a variational posterior.

The surrogate posterior consists of independent Normal distributions, one per parameter, each with trainable loc and scale and each transformed by that parameter's bijector onto the parameter's support.
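The construction described above can be sketched without TFP. The following is a minimal, standard-library-only illustration and not the library's implementation: the parameter names, the softplus constraint on the scale, and the choice of bijectors are all assumptions made for the example.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)); maps the reals to (0, inf)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_factored_surrogate(params, bijectors, rng):
    """Draws one value per parameter from independent transformed Normals.

    params: dict name -> (loc, raw_scale), the trainable values.
    bijectors: dict name -> function mapping an unconstrained draw onto
      the parameter's support.
    """
    sample = {}
    for name, (loc, raw_scale) in params.items():
        scale = softplus(raw_scale)        # keep the standard deviation positive
        z = rng.gauss(loc, scale)          # unconstrained Normal draw
        sample[name] = bijectors[name](z)  # push forward onto the support
    return sample


# Hypothetical parameters: one positive-valued scale, one unconstrained value.
params = {'drift_scale': (0.0, -1.0), 'level_offset': (2.0, -1.0)}
bijectors = {'drift_scale': softplus, 'level_offset': lambda z: z}

rng = random.Random(0)
draws = [sample_factored_surrogate(params, bijectors, rng) for _ in range(1000)]
```

Because each parameter is drawn independently, the sketch cannot represent posterior correlations between parameters; that is the trade-off a factored (mean-field) surrogate makes for simplicity.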

Args

model An instance of StructuralTimeSeries representing a time-series model. This represents a joint distribution over time-series and their parameters with batch shape [b1, ..., bN].
batch_shape Batch shape (Python tuple, list, or int) of initial states to optimize in parallel. Default value: () (i.e., just run a single optimization).
name Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_factored_surrogate_posterior').

Returns

init_fn A function that takes a stateless random seed and returns the initial parameters of the variational posterior.
build_surrogate_posterior_fn A function that takes those parameters and returns a surrogate posterior distribution.
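The contract of the two returned functions can be illustrated with a toy stand-in. This is a sketch under stated assumptions, not TFP code: build_toy_surrogate_stateless, the parameter names, and the use of plain integer seeds are hypothetical (TFP itself uses stateless seed values such as JAX PRNG keys). The point it shows is that neither function holds mutable state, so identical inputs always give identical outputs.

```python
import math
import random


def build_toy_surrogate_stateless(param_names):
    """Returns an (init_fn, build_fn) pair that holds no mutable state."""

    def init_fn(seed):
        # All randomness is derived from the seed argument.
        rng = random.Random(seed)
        return {name: {'loc': rng.uniform(-0.1, 0.1), 'log_scale': -2.0}
                for name in param_names}

    def build_surrogate_posterior_fn(parameters):
        # The returned sampler depends only on the parameters passed in.
        def sample(seed):
            rng = random.Random(seed)
            return {name: rng.gauss(p['loc'], math.exp(p['log_scale']))
                    for name, p in parameters.items()}
        return sample

    return init_fn, build_surrogate_posterior_fn


# Hypothetical parameter names, chosen only for illustration.
init_fn, build_fn = build_toy_surrogate_stateless(['level_scale', 'slope_scale'])
initial_parameters = init_fn(seed=0)
surrogate_sample = build_fn(initial_parameters)
```

An optimizer can then treat the parameters as an ordinary value: it calls the build function on the current parameters at every step and returns updated parameters, rather than mutating variables in place.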

Examples

Assume we've built a structural time-series model:

  day_of_week = tfp.sts.Seasonal(
      num_seasons=7,
      observed_time_series=observed_time_series,
      name='day_of_week')
  local_linear_trend = tfp.sts.LocalLinearTrend(
      observed_time_series=observed_time_series,
      name='local_linear_trend')
  model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                      observed_time_series=observed_time_series)

To (statelessly) fit the model to data, we construct init_fn and build_surrogate_fn. init_fn constructs an initial set of parameters, and build_surrogate_fn is passed to tfp.vi.fit_surrogate_posterior_stateless, which optimizes a variational bound.

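  # This example only works in the JAX backend because it uses
  # `optax` for stateless optimizers. `tfp`, including the model built
  # above, must therefore come from the JAX substrate.
  import jax
  import optax
  from tensorflow_probability.substrates import jax as tfp
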
  seed = tfp.random.sanitize_seed(jax.random.PRNGKey(0), salt='fit_stateless')
  init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3)
  init_fn, build_surrogate_fn = (
      tfp.sts.build_factored_surrogate_posterior_stateless(model=model))
  initial_parameters = init_fn(init_seed)
  jd = model.joint_distribution(observed_time_series)
  final_parameters, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
      target_log_prob_fn=jd.log_prob,
      initial_parameters=initial_parameters,
      build_surrogate_posterior_fn=build_surrogate_fn,
      optimizer=optax.adam(1e-4),
      num_steps=200,
      seed=fit_seed)
  surrogate_posterior = build_surrogate_fn(final_parameters)
  posterior_samples = surrogate_posterior.sample(50, seed=sample_seed)