Samples from the joint distribution of a linear Gaussian state-space model.
```python
tfp.experimental.parallel_filter.sample_walk(
    transition_matrix,
    transition_mean,
    transition_scale_tril,
    observation_matrix,
    observation_mean,
    observation_scale_tril,
    initial_mean,
    initial_scale_tril,
    seed=None
)
```
This method draws samples from the joint prior distribution on latent and
observed variables in a linear Gaussian state-space model. The sampling is
parallelized over timesteps, so that sampling a sequence of length
`num_timesteps` requires only `O(log(num_timesteps))` sequential steps.

As with a naive sequential implementation, the total FLOP count scales
linearly in `num_timesteps` (as `O(T + T/2 + T/4 + ...) = O(T)`), so this
approach does not require extra resources in an asymptotic sense. However, it
likely has a somewhat larger constant factor, so a sequential sampler
may be preferred when throughput rather than latency is the highest priority.
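For example, here is a minimal end-to-end sketch of drawing a single sample
path. The dimensions, parameter values, and seed are illustrative assumptions,
not defaults of the API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative sizes (assumptions, not API defaults).
num_timesteps, latent_size, observation_size = 100, 2, 1

# Time-varying parameters, with `num_timesteps` as the leftmost axis.
transition_matrix = tf.eye(latent_size, batch_shape=[num_timesteps])
transition_mean = tf.zeros([num_timesteps, latent_size])
transition_scale_tril = 0.1 * tf.eye(latent_size,
                                     batch_shape=[num_timesteps])
observation_matrix = tf.ones([num_timesteps, observation_size, latent_size])
observation_mean = tf.zeros([num_timesteps, observation_size])
observation_scale_tril = 0.5 * tf.eye(observation_size,
                                      batch_shape=[num_timesteps])

x, y = tfp.experimental.parallel_filter.sample_walk(
    transition_matrix=transition_matrix,
    transition_mean=transition_mean,
    transition_scale_tril=transition_scale_tril,
    observation_matrix=observation_matrix,
    observation_mean=observation_mean,
    observation_scale_tril=observation_scale_tril,
    initial_mean=tf.zeros([latent_size]),
    initial_scale_tril=tf.eye(latent_size),
    seed=tfp.random.sanitize_seed(42))
# x has shape [num_timesteps, latent_size];
# y has shape [num_timesteps, observation_size].
```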
Args | |
---|---|
`transition_matrix` | float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
`transition_mean` | float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size]`.
`transition_scale_tril` | float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
`observation_matrix` | float `Tensor` of shape `[num_timesteps, B1, .., BN, observation_size, latent_size]`.
`observation_mean` | float `Tensor` of shape `[num_timesteps, B1, .., BN, observation_size]`.
`observation_scale_tril` | float `Tensor` of shape `[num_timesteps, B1, .., BN, observation_size, observation_size]`.
`initial_mean` | float `Tensor` of shape `[B1, .., BN, latent_size]`.
`initial_scale_tril` | float `Tensor` of shape `[B1, .., BN, latent_size, latent_size]`.
`seed` | PRNG seed; see `tfp.random.sanitize_seed` for details.
Returns | |
---|---|
`x` | float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size]`.
`y` | float `Tensor` of shape `[num_timesteps, B1, .., BN, observation_size]`.
Mathematical Details

The assumed model consists of latent state vectors
`x[:num_timesteps, :latent_size]` and corresponding observed values
`y[:num_timesteps, :observation_size]`, governed by the following dynamics:

```python
x[0] ~ MultivariateNormal(mean=initial_mean, scale_tril=initial_scale_tril)
for t in range(num_timesteps - 1):
  x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                            x[t]) + transition_mean[t],
                                scale_tril=transition_scale_tril[t])
# Observed values `y[:num_timesteps]` defined at all timesteps.
y ~ MultivariateNormal(mean=matmul(observation_matrix, x) + observation_mean,
                       scale_tril=observation_scale_tril)
```
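To make these semantics concrete, the same generative process can be written
as a naive `O(num_timesteps)` sequential sampler. This is a reference sketch
for exposition only, not the parallel algorithm this method actually uses;
the helper name `sample_walk_sequential` and the unbatched treatment are
assumptions:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def sample_walk_sequential(transition_matrix, transition_mean,
                           transition_scale_tril, observation_matrix,
                           observation_mean, observation_scale_tril,
                           initial_mean, initial_scale_tril, seed=None):
  """O(num_timesteps) sequential reference sampler, for exposition only."""
  num_timesteps = transition_matrix.shape[0]
  init_seed, obs_seed, *step_seeds = tfp.random.split_seed(
      seed, n=num_timesteps + 1)
  # Initial latent state.
  x = [tfd.MultivariateNormalTriL(
      loc=initial_mean, scale_tril=initial_scale_tril).sample(seed=init_seed)]
  # Propagate the latent state one timestep at a time.
  for t in range(num_timesteps - 1):
    x.append(tfd.MultivariateNormalTriL(
        loc=tf.linalg.matvec(transition_matrix[t], x[t]) + transition_mean[t],
        scale_tril=transition_scale_tril[t]).sample(seed=step_seeds[t]))
  x = tf.stack(x, axis=0)
  # Observations are conditionally independent given the latents, so all
  # `num_timesteps` of them can be drawn in one vectorized call.
  y = tfd.MultivariateNormalTriL(
      loc=tf.linalg.matvec(observation_matrix, x) + observation_mean,
      scale_tril=observation_scale_tril).sample(seed=obs_seed)
  return x, y
```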
Tensor layout

`Tensor` arguments are expected to have `num_timesteps` as their leftmost
axis, preceding any batch dimensions. This layout is used for internal
computations, so providing arguments in this form avoids the need for
potentially-spurious transposition. The returned `Tensor`s also follow this
layout, for the same reason. Note that this differs from the layout mandated
by the `tfd.Distribution` API (and exposed by
`tfd.LinearGaussianStateSpaceModel`), in which the time axis is to the right
of any batch dimensions; it is the caller's responsibility to perform any
needed transpositions.
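For instance, to compare samples from this method against a
`tfd.LinearGaussianStateSpaceModel` with a single batch dimension, one might
transpose the time axis past the batch axis as follows (the shapes are
illustrative assumptions):

```python
import tensorflow as tf

# `y` as returned here: [num_timesteps=100, B1=8, observation_size=3].
y = tf.zeros([100, 8, 3])
# `tfd.LinearGaussianStateSpaceModel` layout: [B1, num_timesteps, obs_size].
y_dist_layout = tf.transpose(y, perm=[1, 0, 2])  # shape [8, 100, 3]
```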
Note that this method takes `scale_tril` matrices specifying the Cholesky
factors of covariance matrices, in contrast to
`tfp.experimental.parallel_filter.kalman_filter`, which takes the covariance
matrices directly. This is to avoid redundant factorization, since the
sampling process uses Cholesky factors natively, while the filtering updates
we implement require covariance matrices. In addition, taking `scale_tril`
matrices directly ensures that sampling is well-defined even when one or more
components of the model are deterministic (`scale_tril=zeros([...])`).
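If a model is already parameterized by covariance matrices, e.g., for use
with `kalman_filter`, they can be factored with `tf.linalg.cholesky`; a short
sketch with illustrative values:

```python
import tensorflow as tf

# A covariance matrix as `kalman_filter` would take it (illustrative values).
transition_cov = 0.1 * tf.eye(2, batch_shape=[100])
# Factored into the `scale_tril` form that `sample_walk` expects.
transition_scale_tril = tf.linalg.cholesky(transition_cov)

# A deterministic transition is expressed directly as a zero scale factor;
# it has no strictly-positive-definite covariance counterpart to factor.
deterministic_scale_tril = tf.zeros([100, 2, 2])
```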
`Tensor` arguments may be specified with partial batch shape, i.e., with
shape prefix `[num_timesteps, Bk, ..., BN]` for `k > 1`. They will be
internally reshaped and broadcast to the full batch shape prefix
`[num_timesteps, B1, ..., BN]`, as in the sketch below.
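A sketch of this broadcasting behavior, assuming it applies as described
above (all dimensions are illustrative): the time-varying parameters carry no
batch dimensions and are shared across the batch, while the initial-state
parameters carry the full batch shape `[B]`.

```python
import tensorflow as tf
import tensorflow_probability as tfp

T, B, D, E = 100, 8, 2, 1  # timesteps, batch, latent_size, observation_size

x, y = tfp.experimental.parallel_filter.sample_walk(
    # No batch dimensions: shape [T, D, D], shared across the batch.
    transition_matrix=tf.eye(D, batch_shape=[T]),
    transition_mean=tf.zeros([T, D]),
    transition_scale_tril=0.1 * tf.eye(D, batch_shape=[T]),
    observation_matrix=tf.ones([T, E, D]),
    observation_mean=tf.zeros([T, E]),
    observation_scale_tril=0.5 * tf.eye(E, batch_shape=[T]),
    # Full batch shape [B] comes from the initial-state parameters.
    initial_mean=tf.zeros([B, D]),
    initial_scale_tril=tf.eye(D, batch_shape=[B]),
    seed=tfp.random.sanitize_seed(0))
# After broadcasting: x.shape == [T, B, D], y.shape == [T, B, E].
```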