
tfp.experimental.parallel_filter.kalman_filter

Infers latent values using a parallel Kalman filter.

This method computes filtered marginal means and covariances of a linear Gaussian state-space model using a parallel message-passing algorithm, as described by Sarkka and Garcia-Fernandez [1]. The inference process is formulated as a prefix-sum problem that can be efficiently computed by tfp.math.scan_associative, so that inference for a time series of length num_timesteps requires only O(log(num_timesteps)) sequential steps.
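The prefix-sum idea can be illustrated with a toy associative operator. The sketch below is not the filtering operator from [1] and does not use tfp.math.scan_associative; as a stand-in it composes scalar affine maps x -> a*x + b, and it uses the simple doubling (Hillis-Steele) scan, which reaches logarithmic depth but does more total work than a work-efficient scan. The names inclusive_scan and compose are hypothetical helpers introduced for illustration.

```python
import math

def inclusive_scan(op, xs):
    """Doubling inclusive scan; returns (prefixes, number of sequential rounds)."""
    out = list(xs)
    n = len(out)
    shift, rounds = 1, 0
    while shift < n:
        # Each entry of this round reads only the previous round's values,
        # so all combinations within a round could run in parallel.
        out = [out[i] if i < shift else op(out[i - shift], out[i])
               for i in range(n)]
        shift *= 2
        rounds += 1
    return out, rounds

def compose(f, g):
    """Compose affine maps given as (a, b) pairs: apply f first, then g."""
    a1, b1 = f
    a2, b2 = g
    return (a2 * a1, a2 * b1 + b2)

# Eight identical toy "transition" maps x -> 0.5 * x + 1.
maps = [(0.5, 1.0)] * 8
prefixes, rounds = inclusive_scan(compose, maps)

# Sequential reference: fold the maps one at a time.
sequential = []
acc = None
for m in maps:
    acc = m if acc is None else compose(acc, m)
    sequential.append(acc)
```

For eight elements the scan finishes in three rounds (log2 of 8), whereas the sequential fold needs seven dependent steps; both produce the same prefixes because composition is associative.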

As with a naive sequential implementation, the total FLOP count scales linearly in num_timesteps (as O(T + T/2 + T/4 + ...) = O(T)), so this approach does not require extra resources in an asymptotic sense. However, it likely has a somewhat larger constant factor, so a sequential filter may be preferred when throughput rather than latency is the highest priority.
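The linear bound is a geometric series. As a short derivation, assume each level of the scan halves the number of elements to combine, and that one combination costs a constant c independent of T (c is a symbol introduced here for illustration, and T stands for num_timesteps):

```latex
\sum_{k=0}^{\lceil \log_2 T \rceil} c \, \frac{T}{2^{k}}
  \;<\; c \, T \sum_{k=0}^{\infty} 2^{-k}
  \;=\; 2 \, c \, T
  \;=\; O(T)
```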

Args

transition_matrix float Tensor of shape [num_timesteps, B1, .., BN, latent_size, latent_size].
transition_mean float Tensor of shape [num_timesteps, B1, .., BN, latent_size].
transition_cov float Tensor of shape [num_timesteps, B1, .., BN, latent_size, latent_size].
observation_matrix float Tensor of shape [num_timesteps, B1, .., BN, observation_size, latent_size].
observation_mean float Tensor of shape [num_timesteps, B1, .., BN, observation_size].
observation_cov float Tensor of shape [num_timesteps, B1, .., BN, observation_size, observation_size].
initial_mean float Tensor of shape [B1, .., BN, latent_size].
initial_cov float Tensor of shape [B1, .., BN, latent_size, latent_size].
y float Tensor of shape [num_timesteps, B1, .., BN, observation_size].
mask float Tensor of shape [num_timesteps, B1, .., BN].
return_all Python bool, whether to compute log-likelihoods and predictive and observation distributions. If False, only filtered_means and filtered_covs are computed, and None is returned for the remaining values.

Returns

log_likelihoods float Tensor of shape [num_timesteps, B1, .., BN], such that log_likelihoods[t] = log p(y[t] | y[:t]).
filtered_means float Tensor of shape [num_timesteps, B1, .., BN, latent_size], such that filtered_means[t] = E[x[t] | y[:t + 1]].
filtered_covs float Tensor of shape [num_timesteps, B1, .., BN, latent_size, latent_size].
predictive_means float Tensor of shape [num_timesteps, B1, .., BN, latent_size], such that predictive_means[t] = E[x[t + 1] | y[:t + 1]].
predictive_covs float Tensor of shape [num_timesteps, B1, .., BN, latent_size, latent_size].
observation_means float Tensor of shape [num_timesteps, B1, .., BN, observation_size], such that observation_means[t] = E[y[t] | y[:t]].
observation_covs float Tensor of shape [num_timesteps, B1, .., BN, observation_size, observation_size].

Mathematical Details

The assumed model consists of latent state vectors x[:num_timesteps, :latent_size] and corresponding observed values y[:num_timesteps, :observation_size], governed by the following dynamics:

x[0] ~ MultivariateNormal(mean=initial_mean, cov=initial_cov)
for t in range(num_timesteps - 1):
  x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                            x[t]) + transition_mean[t],
                                cov=transition_cov[t])
# Observed values `y[:num_timesteps]` are defined at all timesteps.
for t in range(num_timesteps):
  y[t] ~ MultivariateNormal(mean=matmul(observation_matrix[t],
                                        x[t]) + observation_mean[t],
                            cov=observation_cov[t])
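To make the returned quantities concrete, here is a minimal sequential reference filter in NumPy for a single unbatched series. It is a sketch for checking definitions, not the library's parallel implementation: it takes no mask argument, assumes every observation is present, and the name sequential_kalman_filter is hypothetical. The dictionary keys mirror the documented return values.

```python
import numpy as np

def sequential_kalman_filter(transition_matrix, transition_mean, transition_cov,
                             observation_matrix, observation_mean, observation_cov,
                             initial_mean, initial_cov, y):
    """Sequential filter for one unbatched series with time-major inputs."""
    names = ("log_likelihoods", "filtered_means", "filtered_covs",
             "predictive_means", "predictive_covs",
             "observation_means", "observation_covs")
    out = {name: [] for name in names}
    pred_mean, pred_cov = initial_mean, initial_cov  # p(x[0]) before any data
    for t in range(y.shape[0]):
        H = observation_matrix[t]
        # Observation distribution p(y[t] | y[:t]).
        obs_mean = H @ pred_mean + observation_mean[t]
        obs_cov = H @ pred_cov @ H.T + observation_cov[t]
        resid = y[t] - obs_mean
        _, logdet = np.linalg.slogdet(obs_cov)
        log_lik = -0.5 * (resid @ np.linalg.solve(obs_cov, resid)
                          + logdet + resid.size * np.log(2.0 * np.pi))
        # Measurement update: condition x[t] on y[t].
        gain = np.linalg.solve(obs_cov, H @ pred_cov).T  # P H^T S^{-1}
        filt_mean = pred_mean + gain @ resid
        filt_cov = pred_cov - gain @ obs_cov @ gain.T
        # Time update: distribution of x[t + 1] given y[:t + 1].
        F = transition_matrix[t]
        pred_mean = F @ filt_mean + transition_mean[t]
        pred_cov = F @ filt_cov @ F.T + transition_cov[t]
        values = (log_lik, filt_mean, filt_cov, pred_mean, pred_cov,
                  obs_mean, obs_cov)
        for name, value in zip(names, values):
            out[name].append(value)
    return {name: np.stack(vals) for name, vals in out.items()}

# Scalar random walk with unit noise, two observations.
T = 2
ones = np.ones((T, 1, 1))
zeros = np.zeros((T, 1))
result = sequential_kalman_filter(
    ones, zeros, ones, ones, zeros, ones,
    initial_mean=np.zeros(1), initial_cov=np.ones((1, 1)),
    y=np.array([[2.0], [1.0]]))
```

Note that transition_matrix[num_timesteps - 1] is used only to form the final predictive distribution over x[num_timesteps], which matches the definition predictive_means[t] = E[x[t + 1] | y[:t + 1]].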

Tensor layout

Tensor arguments are expected to have num_timesteps as their leftmost axis, preceding any batch dimensions. Internal computations use this layout, so providing arguments in this form avoids unnecessary transpositions; the returned Tensors follow the same layout for the same reason. Note that this differs from the layout mandated by the tfd.Distribution API (and exposed by tfd.LinearGaussianStateSpaceModel), in which the time axis is to the right of any batch dimensions; it is the caller's responsibility to perform any needed transpositions.
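The required axis moves can be sketched with NumPy arrays standing in for Tensors; the same permutations apply to Tensors with the equivalent transpose operation. The shapes and variable names below are made up for illustration.

```python
import numpy as np

num_timesteps, latent_size, observation_size = 5, 2, 4
batch_shape = (3, 2)

# Observations in the batch-major layout: [B1, B2, num_timesteps, observation_size].
y_batch_major = np.zeros(batch_shape + (num_timesteps, observation_size))

# Move the time axis (second from the right for vector-valued series) to the front.
y_time_major = np.moveaxis(y_batch_major, -2, 0)

# A time-major, matrix-valued result such as filtered_covs:
# [num_timesteps, B1, B2, latent_size, latent_size].
filtered_covs = np.zeros((num_timesteps,) + batch_shape + (latent_size, latent_size))

# Move the time axis back so it sits just left of the two matrix axes.
filtered_covs_batch_major = np.moveaxis(filtered_covs, 0, -3)
```

The destination index depends on how many trailing event axes a quantity has: -2 for vector-valued quantities such as means, and -3 for matrix-valued quantities such as covariances.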

Tensor arguments may be specified with partial batch shape, i.e., with shape prefix [num_timesteps, Bk, ..., BN] for k > 1. They will be internally reshaped and broadcast to the full batch shape prefix [num_timesteps, B1, ..., BN].
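Because the time axis is leftmost, ordinary right-aligned broadcasting alone would try to match num_timesteps against a missing batch axis. The sketch below shows one way the reshape-then-broadcast step could work, using NumPy for illustration. It is an assumption about the behavior described above, not the library's code, and broadcast_time_major is a hypothetical helper.

```python
import numpy as np

def broadcast_time_major(x, batch_shape, event_ndims):
    """Broadcast [T, Bk, ..., BN, event...] to [T, B1, ..., BN, event...]."""
    num_timesteps = x.shape[0]
    split = x.ndim - event_ndims
    partial_batch = x.shape[1:split]
    event_shape = x.shape[split:]
    # Insert size-1 axes for the missing leading batch dimensions,
    # directly after the time axis.
    missing = len(batch_shape) - len(partial_batch)
    padded = x.reshape((num_timesteps,) + (1,) * missing
                       + partial_batch + event_shape)
    full_shape = (num_timesteps,) + tuple(batch_shape) + event_shape
    return np.broadcast_to(padded, full_shape)

# transition_matrix given with partial batch shape [B2] = [2];
# the full batch shape is [B1, B2] = [3, 2].
partial = np.arange(5 * 2 * 4 * 4, dtype=float).reshape(5, 2, 4, 4)
full = broadcast_time_major(partial, batch_shape=(3, 2), event_ndims=2)
```

Here event_ndims is 2 for matrices, 1 for vectors such as transition_mean or y, and 0 for mask.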

References

[1] Simo Sarkka and Angel F. Garcia-Fernandez. Temporal Parallelization of Bayesian Smoothers. arXiv preprint arXiv:1905.13002, 2019. https://arxiv.org/abs/1905.13002