Infers latent values using a parallel Kalman filter.
tfp.experimental.parallel_filter.kalman_filter(
    transition_matrix,
    transition_mean,
    transition_cov,
    observation_matrix,
    observation_mean,
    observation_cov,
    initial_mean,
    initial_cov,
    y,
    mask,
    return_all=True
)
This method computes filtered marginal means and covariances of a linear
Gaussian state-space model using a parallel message-passing algorithm, as
described by Sarkka and Garcia-Fernandez [1]. The inference process is
formulated as a prefix-sum problem that can be efficiently computed by
tfp.math.scan_associative, so that inference for a time series of length
num_timesteps requires only O(log(num_timesteps)) sequential steps.
As with a naive sequential implementation, the total FLOP count scales
linearly in num_timesteps (as O(T + T/2 + T/4 + ...) = O(T)), so this
approach does not require extra resources in an asymptotic sense. However, it
likely has a somewhat larger constant factor, so a sequential filter may be
preferred when throughput rather than latency is the highest priority.
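The O(log(num_timesteps)) sequential-step structure can be illustrated with a generic all-prefix (Hillis-Steele-style) associative scan. This is a pure-Python sketch of the doubling pattern only, not TFP's actual implementation; tfp.math.scan_associative applies the same idea on-device, with the filter's message-combination operator in place of addition.

```python
def associative_scan(op, xs):
    """All-prefix reduction of `xs` under an associative `op`, using a
    doubling pattern: log2(n) sequential rounds, O(n) work per round."""
    xs = list(xs)
    n = len(xs)
    shift = 1
    rounds = 0
    while shift < n:
        # Combine each element with the partial result `shift` positions back.
        xs = [xs[i] if i < shift else op(xs[i - shift], xs[i])
              for i in range(n)]
        shift *= 2
        rounds += 1
    return xs, rounds

# With `op` as addition, the scan computes cumulative sums; for a length-8
# input only 3 sequential rounds are needed.
prefixes, num_rounds = associative_scan(lambda a, b: a + b, range(1, 9))
# prefixes == [1, 3, 6, 10, 15, 21, 28, 36]; num_rounds == 3
```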
Mathematical Details
The assumed model consists of latent state vectors
x[:num_timesteps, :latent_size] and corresponding observed values
y[:num_timesteps, :observation_size], governed by the following dynamics:
x[0] ~ MultivariateNormal(mean=initial_mean, cov=initial_cov)
for t in range(num_timesteps - 1):
  x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                            x[t]) + transition_mean[t],
                                cov=transition_cov[t])
# Observed values `y[:num_timesteps]` defined at all timesteps.
y ~ MultivariateNormal(mean=matmul(observation_matrix, x) + observation_mean,
cov=observation_cov)
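For reference, the filtered marginal means and covariances this function computes can be written out as the standard sequential Kalman predict/update recursions. The following NumPy sketch is an illustration of those recursions for a single batch element with zero transition and observation means, not TFP's parallel implementation:

```python
import numpy as np

def sequential_kalman_filter(transition_matrix, transition_cov,
                             observation_matrix, observation_cov,
                             initial_mean, initial_cov, y):
    """Sequential reference filter (time-invariant dynamics, zero
    transition/observation means). Time is the leading axis of `y` and of
    the returned means/covs, matching this module's layout convention."""
    mean, cov = initial_mean, initial_cov
    filtered_means, filtered_covs = [], []
    for t, y_t in enumerate(y):
        if t > 0:
            # Predict: push the previous posterior through the dynamics.
            mean = transition_matrix @ mean
            cov = (transition_matrix @ cov @ transition_matrix.T
                   + transition_cov)
        # Update: condition on the observation y[t].
        s = observation_matrix @ cov @ observation_matrix.T + observation_cov
        gain = cov @ observation_matrix.T @ np.linalg.inv(s)
        mean = mean + gain @ (y_t - observation_matrix @ mean)
        cov = cov - gain @ observation_matrix @ cov
        filtered_means.append(mean)
        filtered_covs.append(cov)
    return np.stack(filtered_means), np.stack(filtered_covs)

# A 1-d random-walk latent state, observed directly under unit noise.
means, covs = sequential_kalman_filter(
    transition_matrix=np.eye(1), transition_cov=0.1 * np.eye(1),
    observation_matrix=np.eye(1), observation_cov=np.eye(1),
    initial_mean=np.zeros(1), initial_cov=np.eye(1),
    y=np.ones((5, 1)))
# means climb monotonically from 0.5 toward the constant observation 1.0.
```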
Tensor layout
Tensor arguments are expected to have num_timesteps as their leftmost
axis, preceding any batch dimensions. This layout is used
for internal computations, so providing arguments in this form avoids the
need for potentially-spurious transposition. The returned Tensors also
follow this layout, for the same reason. Note that this differs from the
layout mandated by the tfd.Distribution
API (and exposed by tfd.LinearGaussianStateSpaceModel), in which the time
axis is to the right of any batch dimensions; it is the caller's
responsibility to perform any needed transpositions.
Tensor arguments may be specified with partial batch shape, i.e., with
shape prefix [num_timesteps, Bk, ..., BN] for k > 1. They will be
internally reshaped and broadcast to the full batch shape prefix
[num_timesteps, B1, ..., BN].
References
[1] Simo Sarkka and Angel F. Garcia-Fernandez. Temporal Parallelization of Bayesian Smoothers. arXiv preprint arXiv:1905.13002, 2019. https://arxiv.org/abs/1905.13002