tfp.sts.forecast

Construct a predictive distribution over future observations.

Given samples from the posterior over parameters, return the predictive distribution over future observations for num_steps_forecast timesteps.

Args

model An instance of StructuralTimeSeries representing a time-series model. This represents a joint distribution over time-series and their parameters with batch shape [b1, ..., bN].
observed_time_series float Tensor of shape concat([sample_shape, model.batch_shape, [num_timesteps, 1]]) where sample_shape corresponds to i.i.d. observations, and the trailing [1] dimension may (optionally) be omitted if num_timesteps > 1. Any NaNs are interpreted as missing observations; missingness may also be specified explicitly by passing a tfp.sts.MaskedTimeSeries instance.
parameter_samples Python list of Tensors representing posterior samples of model parameters, with shapes [concat([[num_posterior_draws], param.prior.batch_shape, param.prior.event_shape]) for param in model.parameters]. This may optionally also be a map (Python dict) of parameter names to Tensor values.
num_steps_forecast scalar int Tensor giving the number of steps to forecast.
include_observation_noise Python bool indicating whether the forecast distribution should include uncertainty from observation noise. If True, the forecast is over future observations; if False, the forecast is over future values of the latent noise-free time series. Default value: True.
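
The observed_time_series entry above says that NaNs are treated as missing observations. As a rough illustration only, the following pure-Python sketch derives a boolean missingness mask from NaN entries. The helper name missing_mask and the sample values are hypothetical and are not part of the library.

```python
import math


def missing_mask(series):
    """Return a list of booleans, True where the value is NaN (missing)."""
    return [math.isnan(x) for x in series]


# Hypothetical observations with two gaps (illustrative values only).
series = [1.2, float('nan'), 0.9, 1.4, float('nan')]
is_missing = missing_mask(series)
print(is_missing)
```

A mask of this kind is the sort of information carried by the is_missing field of tfp.sts.MaskedTimeSeries, alongside the time_series values themselves.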

Returns

forecast_dist a tfd.MixtureSameFamily instance with event shape [num_steps_forecast, 1] and batch shape concat([sample_shape, model.batch_shape]), with num_posterior_draws mixture components.
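
Because the returned distribution is a mixture with one component per posterior draw, its mean and standard deviation combine the per-draw forecasts. The sketch below works through this for a single forecast step, assuming equally weighted components (an assumption made here for illustration; the text above does not state the weights). The function name mixture_moments and the numbers are hypothetical.

```python
import math


def mixture_moments(component_means, component_scales):
    """Mean and stddev of an equal-weight mixture at one forecast step."""
    n = len(component_means)
    mean = sum(component_means) / n
    # Law of total variance: average within-component variance
    # plus the variance of the component means around the mixture mean.
    within = sum(s ** 2 for s in component_scales) / n
    between = sum((m - mean) ** 2 for m in component_means) / n
    return mean, math.sqrt(within + between)


# Hypothetical per-draw forecasts for one timestep (3 posterior draws).
means = [1.0, 2.0, 3.0]
scales = [1.0, 1.0, 1.0]
mix_mean, mix_scale = mixture_moments(means, scales)
print(mix_mean, mix_scale)
```

The mixture stddev exceeds every component's stddev here because disagreement between posterior draws adds to the forecast uncertainty.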

Examples

Suppose we've built a model and fit it to data using HMC:

  import tensorflow as tf
  import tensorflow_probability as tfp

  day_of_week = tfp.sts.Seasonal(
      num_seasons=7,
      observed_time_series=observed_time_series,
      name='day_of_week')
  local_linear_trend = tfp.sts.LocalLinearTrend(
      observed_time_series=observed_time_series,
      name='local_linear_trend')
  model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                      observed_time_series=observed_time_series)

  samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)

Passing the posterior samples into forecast, we construct a forecast distribution:

  forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                   parameter_samples=samples,
                                   num_steps_forecast=50)

  forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
  forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
  forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]

If using variational inference instead of HMC, we'd construct a forecast using samples from the variational posterior:

  surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
    model=model)
  loss_curve = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200)
  samples = surrogate_posterior.sample(30)

  forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                   parameter_samples=samples,
                                   num_steps_forecast=50)

We can visualize the forecast by plotting:

  import numpy as np
  from matplotlib import pyplot as plt

  def plot_forecast(observed_time_series,
                    forecast_mean,
                    forecast_scale,
                    forecast_samples):
    plt.figure(figsize=(12, 6))

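    # Assumes observed_time_series spans the full series, including the
    # held-out steps covered by the forecast.
    num_steps = observed_time_series.shape[-1]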
    num_steps_forecast = forecast_mean.shape[-1]
    num_steps_train = num_steps - num_steps_forecast

    c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
    plt.plot(np.arange(num_steps), observed_time_series,
             lw=2, color=c1, label='ground truth')

    forecast_steps = np.arange(num_steps_train,
                               num_steps_train + num_steps_forecast)
    plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
    plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
             label='forecast')
    plt.fill_between(forecast_steps,
                     forecast_mean - 2 * forecast_scale,
                     forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

    plt.xlim([0, num_steps])
    plt.legend()

  plot_forecast(observed_time_series,
                forecast_mean=forecast_mean,
                forecast_scale=forecast_scale,
                forecast_samples=forecast_samples)
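
The shaded region drawn by plot_forecast spans forecast_mean plus or minus two forecast_scale at each step. As a minimal pure-Python sketch of that arithmetic, with made-up numbers and a hypothetical helper name forecast_band:

```python
def forecast_band(forecast_mean, forecast_scale, width=2.0):
    """Return (lower, upper) bounds of mean +/- width * scale per step."""
    lower = [m - width * s for m, s in zip(forecast_mean, forecast_scale)]
    upper = [m + width * s for m, s in zip(forecast_mean, forecast_scale)]
    return lower, upper


# Hypothetical forecast over three steps, with uncertainty growing over time.
mean = [10.0, 11.0, 12.0]
scale = [0.5, 1.0, 1.5]
lower, upper = forecast_band(mean, scale)
print(lower, upper)
```

If the forecast distribution is close to Gaussian at each step, a band of this width covers roughly 95% of the probability mass; for a mixture this is only an approximation.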