Construct predictive distribution over future observations.
```python
tfp.substrates.numpy.sts.forecast(
    model,
    observed_time_series,
    parameter_samples,
    num_steps_forecast,
    include_observation_noise=True
)
```
Given samples from the posterior over parameters, return the predictive distribution over future observations for num_steps_forecast timesteps.
| Args | |
|---|---|
| `model` | An instance of `StructuralTimeSeries` representing a time-series model. This represents a joint distribution over time series and their parameters with batch shape `[b1, ..., bN]`. |
| `observed_time_series` | `float` `Tensor` of shape `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])`, where `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]` dimension may (optionally) be omitted if `num_timesteps > 1`. Any `NaN`s are interpreted as missing observations; missingness may also be specified explicitly by passing a `tfp.sts.MaskedTimeSeries` instance. |
| `parameter_samples` | Python `list` of `Tensors` representing posterior samples of model parameters, with shapes `[concat([[num_posterior_draws], param.prior.batch_shape, param.prior.event_shape]) for param in model.parameters]`. This may optionally also be a map (Python `dict`) of parameter names to `Tensor` values. |
| `num_steps_forecast` | Scalar `int` `Tensor`: the number of steps to forecast. |
| `include_observation_noise` | Python `bool` indicating whether the forecast distribution should include uncertainty from observation noise. If `True`, the forecast is over future observations; if `False`, the forecast is over future values of the latent, noise-free time series. Default value: `True`. |
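The shape conventions for `observed_time_series` and `parameter_samples` can be sketched in plain Python. The helpers below (`canonical_series_shape`, `expected_series_shape`, `param_sample_shape`, `missing_mask`) are hypothetical illustrations of the documented rules, not TFP functions; in particular, treating a trailing dimension of size 1 as the already-present `[1]` is an assumption about how the optional dimension is detected.

```python
import math

# Hypothetical helpers illustrating the documented shape rules; not TFP APIs.

def canonical_series_shape(shape):
    """Add the trailing [1] dimension of a series shape if it was omitted.

    Assumption: a trailing dimension of size 1 is taken to be that [1].
    This is why a series with num_timesteps == 1 must keep its trailing
    [1]: otherwise the size-1 time axis itself would be mistaken for it.
    """
    shape = tuple(shape)
    if shape and shape[-1] == 1:
        return shape
    return shape + (1,)


def expected_series_shape(sample_shape, batch_shape, num_timesteps):
    """concat([sample_shape, model.batch_shape, [num_timesteps, 1]])."""
    return tuple(sample_shape) + tuple(batch_shape) + (num_timesteps, 1)


def param_sample_shape(num_posterior_draws, prior_batch_shape, prior_event_shape):
    """concat([[num_posterior_draws], prior.batch_shape, prior.event_shape])."""
    return ((num_posterior_draws,) + tuple(prior_batch_shape)
            + tuple(prior_event_shape))


def missing_mask(values):
    """True where a value is NaN, i.e. where an observation counts as missing."""
    return [math.isnan(v) for v in values]


print(canonical_series_shape((100,)))          # (100, 1)
print(expected_series_shape((2,), (3,), 100))  # (2, 3, 100, 1)
print(param_sample_shape(30, (), ()))          # (30,)
print(missing_mask([1.0, float('nan'), 3.0]))  # [False, True, False]
```

For a single series of 100 steps passed without the trailing dimension, the sketch restores the `[num_timesteps, 1]` layout; 30 posterior draws of a scalar parameter with scalar prior batch and event shapes give one `Tensor` of shape `[30]`.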
| Returns | |
|---|---|
| `forecast_dist` | A `tfd.MixtureSameFamily` instance with event shape `[num_steps_forecast, 1]` and batch shape `concat([sample_shape, model.batch_shape])`, with `num_posterior_draws` mixture components. |
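Because `forecast_dist` is a mixture with one component per posterior draw, its per-step mean and standard deviation follow from the law of total variance. The sketch below assumes equal mixture weights and a Gaussian marginal for each component at a given step, and models `include_observation_noise=False` simply as dropping the observation-noise variance term; the function `mixture_moments` and its inputs are hypothetical.

```python
import math

def mixture_moments(component_means, latent_variances, obs_noise_scales,
                    include_observation_noise=True):
    """Mean and stddev of an equal-weight mixture at a single forecast step.

    Each component k stands for one posterior draw: a Gaussian with mean
    component_means[k] and variance latent_variances[k], plus
    obs_noise_scales[k]**2 when observation noise is included.
    """
    k = len(component_means)
    variances = [
        v + (s ** 2 if include_observation_noise else 0.0)
        for v, s in zip(latent_variances, obs_noise_scales)
    ]
    mean = sum(component_means) / k
    # Law of total variance: E[Var] + Var[E] = E[v_k + m_k^2] - mean^2.
    second_moment = sum(v + m ** 2
                        for m, v in zip(component_means, variances)) / k
    return mean, math.sqrt(max(second_moment - mean ** 2, 0.0))


# Two posterior draws that disagree about the level of the series.
print(mixture_moments([0.0, 2.0], [1.0, 1.0], [1.0, 1.0]))
# (1.0, 1.7320508075688772)
print(mixture_moments([0.0, 2.0], [1.0, 1.0], [1.0, 1.0],
                      include_observation_noise=False))
# (1.0, 1.4142135623730951)
```

The spread between draws (the `Var[E]` term) is why the mixture's standard deviation exceeds that of any single component, even when observation noise is excluded.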
Examples
Suppose we've built a model and fit it to data using HMC:
```python
import tensorflow_probability as tfp

# `observed_time_series` holds the training data.
day_of_week = tfp.sts.Seasonal(
    num_seasons=7,
    observed_time_series=observed_time_series,
    name='day_of_week')
local_linear_trend = tfp.sts.LocalLinearTrend(
    observed_time_series=observed_time_series,
    name='local_linear_trend')
model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                    observed_time_series=observed_time_series)
samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
```
Passing the posterior samples into `forecast`, we construct a forecast distribution:
```python
forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                 parameter_samples=samples,
                                 num_steps_forecast=50)

forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]
```
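The shape comments above hold for a model with scalar batch shape and no `sample_shape`. The hypothetical helper below spells out the general bookkeeping implied by the Returns table: `mean()` and `stddev()` have shape `batch_shape + [num_steps_forecast, 1]`, `sample(n)` prepends `[n]`, and `[..., 0]` drops the trailing size-1 event dimension.

```python
def forecast_shapes(num_steps_forecast, batch_shape=(), sample_shape=(),
                    num_draws=10):
    """Shapes of mean()[..., 0] and sample(num_draws)[..., 0].

    Hypothetical helper: forecast_dist.batch_shape is
    concat([sample_shape, model.batch_shape]) and its event shape is
    [num_steps_forecast, 1].
    """
    batch = tuple(sample_shape) + tuple(batch_shape)
    event = (num_steps_forecast, 1)
    mean_shape = batch + event                    # also the stddev() shape
    samples_shape = (num_draws,) + batch + event  # sample() prepends [num_draws]
    # `[..., 0]` indexes away the trailing size-1 event dimension.
    return mean_shape[:-1], samples_shape[:-1]


print(forecast_shapes(50))                    # ((50,), (10, 50))
print(forecast_shapes(50, batch_shape=(3,)))  # ((3, 50), (10, 3, 50))
```

Note that `sample_shape` here is the i.i.d. observation dimension of `observed_time_series`, distinct from the `10` draws requested from `sample()`.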
If using variational inference instead of HMC, we'd construct a forecast using samples from the variational posterior:
```python
import tensorflow as tf

surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
    model=model)
loss_curve = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200)
samples = surrogate_posterior.sample(30)  # 30 draws from the fitted posterior

forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                 parameter_samples=samples,
                                 num_steps_forecast=50)
```
We can visualize the forecast by plotting:
```python
import numpy as np
from matplotlib import pylab as plt

def plot_forecast(observed_time_series,
                  forecast_mean,
                  forecast_scale,
                  forecast_samples):
    # Assumes `observed_time_series` is the full 1-D series whose last
    # `num_steps_forecast` points are the held-out period being forecast.
    plt.figure(figsize=(12, 6))

    num_steps = observed_time_series.shape[-1]
    num_steps_forecast = forecast_mean.shape[-1]
    num_steps_train = num_steps - num_steps_forecast

    c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
    plt.plot(np.arange(num_steps), observed_time_series,
             lw=2, color=c1, label='ground truth')

    forecast_steps = np.arange(num_steps_train,
                               num_steps_train + num_steps_forecast)
    # Convert to an array so `.T` is available for any tensor type.
    forecast_samples = np.asarray(forecast_samples)
    # Individual sample paths, the forecast mean, and a mean ± 2 stddev band.
    plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
    plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
             label='forecast')
    plt.fill_between(forecast_steps,
                     forecast_mean - 2 * forecast_scale,
                     forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

    plt.xlim([0, num_steps])
    plt.legend()
```
```python
plot_forecast(observed_time_series,
              forecast_mean=forecast_mean,
              forecast_scale=forecast_scale,
              forecast_samples=forecast_samples)
```
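The mean ± 2 stddev band in `plot_forecast` is a Gaussian-style summary, while the forecast itself is a mixture over posterior draws that may be skewed or multimodal. One alternative, sketched below under that framing, is a pointwise empirical quantile band computed directly from forecast samples. The functions `empirical_band` and `_quantile` are hypothetical, written in plain Python, and take samples laid out as `[num_samples][num_steps]`.

```python
import math

def _quantile(sorted_values, q):
    """Quantile of pre-sorted values, interpolating linearly between order
    statistics (the same rule as NumPy's default 'linear' method)."""
    pos = q * (len(sorted_values) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    return sorted_values[lo] + (pos - lo) * (sorted_values[hi] - sorted_values[lo])


def empirical_band(forecast_samples, lower_q=0.025, upper_q=0.975):
    """Pointwise quantile band from samples laid out as [num_samples][num_steps]."""
    num_steps = len(forecast_samples[0])
    lower, upper = [], []
    for t in range(num_steps):
        column = sorted(row[t] for row in forecast_samples)
        lower.append(_quantile(column, lower_q))
        upper.append(_quantile(column, upper_q))
    return lower, upper


toy_samples = [[0.0, 10.0], [1.0, 20.0], [2.0, 30.0], [3.0, 40.0], [4.0, 50.0]]
print(empirical_band(toy_samples, lower_q=0.25, upper_q=0.75))
# ([1.0, 20.0], [3.0, 40.0])
```

With the example's `forecast_samples`, one could call `empirical_band(np.asarray(forecast_samples).tolist())` and pass the two lists to `plt.fill_between` in place of the ±2 stddev band. With only 10 samples the tails are coarse, so drawing more (for example `forecast_dist.sample(1000)`) would give a steadier band.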