

Samples from the joint distribution of a linear Gaussian state-space model.

This method draws samples from the joint prior distribution on latent and observed variables in a linear Gaussian state-space model. The sampling is parallelized over timesteps, so that sampling a sequence of length num_timesteps requires only O(log(num_timesteps)) sequential steps.

As with a naive sequential implementation, the total FLOP count scales linearly in num_timesteps: writing T = num_timesteps, the work is O(T + T/2 + T/4 + ...) = O(T), so this approach does not require extra resources in an asymptotic sense. However, it likely has a somewhat larger constant factor, so a sequential sampler may be preferred when throughput rather than latency is the highest priority.
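The text does not spell out the parallel algorithm. Below is a minimal sketch of one standard way to get logarithmic depth, assuming a scalar latent state (latent_size = 1) and no batch dimensions; it is not TFP's implementation, and every name in it is hypothetical. All noise terms are drawn independently up front, which turns the recursion into a chain of affine maps x -> a*x + b. Because composition of affine maps is associative, every prefix x[t] can be computed with an odd-even parallel prefix scan, and the counters make the depth and total work visible.

```python
import math
import random


def combine(first, second):
    """Compose affine maps (a, b), meaning x -> a*x + b: apply `first`, then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def parallel_prefix(elems, stats):
    """Inclusive prefix scan of `elems` under `combine` (odd-even recursion).

    Each list comprehension has independent iterations, so on parallel
    hardware it would count as a single sequential round.
    """
    n = len(elems)
    if n < 2:
        return list(elems)
    # One round: combine disjoint adjacent pairs (0, 1), (2, 3), ...
    pairs = [combine(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    stats["rounds"] += 1
    stats["combines"] += len(pairs)
    # Recurse on half as many elements; odd[j] is the prefix through index 2j + 1.
    odd = parallel_prefix(pairs, stats)
    # One more round: prefixes at even indices 2, 4, ... from their odd neighbours.
    even = [combine(odd[j - 1], elems[2 * j]) for j in range(1, (n + 1) // 2)]
    if even:
        stats["rounds"] += 1
        stats["combines"] += len(even)
    # Interleave: [elems[0], odd[0], even[0], odd[1], even[1], ...].
    result = [elems[0]]
    for j, o in enumerate(odd):
        result.append(o)
        if j < len(even):
            result.append(even[j])
    return result


def sample_walk_scalar(a, mu, s, m0, s0, rng):
    """Sample x[0] ~ N(m0, s0^2), x[t+1] ~ N(a[t] x[t] + mu[t], s[t]^2) via the scan."""
    num_timesteps = len(a) + 1
    # All noise draws are independent, so they can be made in parallel.
    z = [rng.gauss(0.0, 1.0) for _ in range(num_timesteps)]
    # Element 0 ignores its input (slope 0) and outputs x[0];
    # element t + 1 maps x[t] to x[t + 1].
    elems = [(0.0, m0 + s0 * z[0])]
    elems += [(a[t], mu[t] + s[t] * z[t + 1]) for t in range(num_timesteps - 1)]
    stats = {"rounds": 0, "combines": 0}
    prefixes = parallel_prefix(elems, stats)
    # Every prefix has slope 0, so its offset is the sampled x[t].
    return [b for _, b in prefixes], z, stats


def sample_walk_sequential(a, mu, s, m0, s0, z):
    """Naive reference: num_timesteps - 1 dependent steps using the same noise draws."""
    x = [m0 + s0 * z[0]]
    for t in range(len(a)):
        x.append(a[t] * x[t] + mu[t] + s[t] * z[t + 1])
    return x


T = 1024
a, mu, s = [0.9] * (T - 1), [0.1] * (T - 1), [0.5] * (T - 1)
x, z, stats = sample_walk_scalar(a, mu, s, 0.0, 1.0, random.Random(0))
x_ref = sample_walk_sequential(a, mu, s, 0.0, 1.0, z)
assert all(math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-9) for p, q in zip(x, x_ref))
print(stats)  # {'rounds': 19, 'combines': 2036}
```

For T = 1024 the scan needs 19 sequential rounds (at most 2 log2 T) instead of 1023 dependent steps, but performs 2036 combines, roughly twice the sequential work, which matches the constant-factor caveat above. In the matrix case each combine needs a matrix-matrix product where a sequential step needs only a matrix-vector product, which is one plausible further source of overhead.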

Args

transition_matrix: float Tensor of shape [num_timesteps, B1, ..., BN, latent_size, latent_size].
transition_mean: float Tensor of shape [num_timesteps, B1, ..., BN, latent_size].
transition_scale_tril: float Tensor of shape [num_timesteps, B1, ..., BN, latent_size, latent_size].
observation_matrix: float Tensor of shape [num_timesteps, B1, ..., BN, observation_size, latent_size].
observation_mean: float Tensor of shape [num_timesteps, B1, ..., BN, observation_size].
observation_scale_tril: float Tensor of shape [num_timesteps, B1, ..., BN, observation_size, observation_size].
initial_mean: float Tensor of shape [B1, ..., BN, latent_size].
initial_scale_tril: float Tensor of shape [B1, ..., BN, latent_size, latent_size].
seed: Python int seed for random ops.

Returns

x: float Tensor of shape [num_timesteps, B1, ..., BN, latent_size].
y: float Tensor of shape [num_timesteps, B1, ..., BN, observation_size].

Mathematical Details

The assumed model consists of latent state vectors x[:num_timesteps, :latent_size] and corresponding observed values y[:num_timesteps, :observation_size], governed by the following dynamics:

x[0] ~ MultivariateNormal(mean=initial_mean, scale_tril=initial_scale_tril)
for t in range(num_timesteps - 1):
  x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                            x[t]) + transition_mean[t],
                                scale_tril=transition_scale_tril[t])
# Observed values `y[:num_timesteps]` defined at all timesteps.
y ~ MultivariateNormal(mean=matmul(observation_matrix, x) + observation_mean,
                       scale_tril=observation_scale_tril)
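The pseudocode above translates directly into a naive sequential reference sampler, the kind of baseline the complexity discussion compares against. This is a minimal pure-Python sketch, assuming a single batch member (no B1, ..., BN dimensions) and nested lists in place of Tensors; sample_walk_sequential and its helpers are hypothetical names, not part of TFP. Setting every scale_tril to zeros makes the walk deterministic, which gives an easy check.

```python
import random


def matvec(m, v):
    """Matrix (list of rows) times vector (list)."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def add(u, v):
    """Elementwise sum of two vectors."""
    return [ui + vi for ui, vi in zip(u, v)]


def sample_mvn(mean, scale_tril, rng):
    """Draw mean + scale_tril @ z with z ~ N(0, I); an all-zero scale_tril returns the mean."""
    z = [rng.gauss(0.0, 1.0) for _ in mean]
    return add(mean, matvec(scale_tril, z))


def sample_walk_sequential(transition_matrix, transition_mean, transition_scale_tril,
                           observation_matrix, observation_mean, observation_scale_tril,
                           initial_mean, initial_scale_tril, rng):
    """Naive sequential sampler for one batch member, following the pseudocode."""
    num_timesteps = len(observation_matrix)
    x = [sample_mvn(initial_mean, initial_scale_tril, rng)]
    # Only transition entries 0 .. num_timesteps - 2 are read, as in range(num_timesteps - 1).
    for t in range(num_timesteps - 1):
        mean = add(matvec(transition_matrix[t], x[t]), transition_mean[t])
        x.append(sample_mvn(mean, transition_scale_tril[t], rng))
    # Observations exist at every timestep and depend only on x[t].
    y = [sample_mvn(add(matvec(observation_matrix[t], x[t]), observation_mean[t]),
                    observation_scale_tril[t], rng)
         for t in range(num_timesteps)]
    return x, y


# Deterministic constant-velocity example: state [position, velocity].
T = 4
A = [[1.0, 1.0], [0.0, 1.0]]
zeros2 = [[0.0, 0.0], [0.0, 0.0]]
x, y = sample_walk_sequential(
    transition_matrix=[A] * T,
    transition_mean=[[0.0, 0.0]] * T,
    transition_scale_tril=[zeros2] * T,
    observation_matrix=[[[1.0, 0.0]]] * T,  # observe position only
    observation_mean=[[0.0]] * T,
    observation_scale_tril=[[[0.0]]] * T,
    initial_mean=[0.0, 1.0],
    initial_scale_tril=zeros2,
    rng=random.Random(0))
print(x)  # [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]
print(y)  # [[0.0], [1.0], [2.0], [3.0]]
```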

Tensor layout

Tensor arguments are expected to have num_timesteps as their leftmost axis, preceding any batch dimensions. This layout is used for internal computations, so providing arguments in this form avoids potentially spurious transpositions. The returned Tensors follow the same layout for the same reason. Note that this differs from the layout mandated by the tfd.Distribution API (and exposed by tfd.LinearGaussianStateSpaceModel), in which the time axis is to the right of any batch dimensions; it is the caller's responsibility to perform any needed transpositions.
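Converting between the tfd.LinearGaussianStateSpaceModel layout [B1, ..., BN, num_timesteps, ...] and the time-major layout used here amounts to moving a single axis. The sketch below does this on nested Python lists with hypothetical helper names; with a real array library you would use a transpose instead (for example NumPy's np.moveaxis(x, N, 0) and np.moveaxis(x, 0, N), where N is the number of batch dimensions).

```python
def move_axis_to_front(x, axis):
    """Return nested list `x` with dimension `axis` moved to position 0."""
    if axis == 0:
        return x
    # Move the target axis to the front of each sub-list (making it axis 1
    # of the whole list), then swap axes 0 and 1 with zip.
    inner = [move_axis_to_front(xi, axis - 1) for xi in x]
    return [list(col) for col in zip(*inner)]


def move_front_axis_to(x, axis):
    """Inverse of move_axis_to_front: move dimension 0 to position `axis`."""
    if axis == 0:
        return x
    swapped = [list(col) for col in zip(*x)]  # swap axes 0 and 1
    return [move_front_axis_to(xi, axis - 1) for xi in swapped]


# Batch-major values with shape [B1=2, num_timesteps=3, latent_size=1].
batch_major = [[[10 * b + t] for t in range(3)] for b in range(2)]
# Time-major layout expected by this method: [num_timesteps, B1, latent_size].
time_major = move_axis_to_front(batch_major, axis=1)
print(time_major)  # [[[0], [10]], [[1], [11]], [[2], [12]]]
# Returned samples can be moved back to the tfd.Distribution layout.
assert move_front_axis_to(time_major, axis=1) == batch_major
```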

Note that this method takes scale_tril matrices specifying the Cholesky factors of covariance matrices, in contrast to tfp.experimental.parallel_filter.kalman_filter, which takes the covariance matrices directly. This is to avoid redundant factorization, since the sampling process uses Cholesky factors natively, while the filtering updates we implement require covariance matrices. In addition, taking scale_tril matrices directly ensures that sampling is well-defined even when one or more components of the model are deterministic (scale_tril=zeros([...])).
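To make the last point concrete: a textbook (unpivoted) Cholesky factorization divides by diagonal entries of the factor, so it breaks down on a zero covariance, while drawing mean + scale_tril @ z with scale_tril = 0 is well defined and simply returns the mean. This is a pure-Python sketch with hypothetical helper names; library Cholesky routines generally require a positive-definite input, although their exact failure mode varies.

```python
import math
import random


def cholesky(cov):
    """Cholesky-Banachiewicz: lower-triangular L with L @ L^T == cov (positive-definite cov)."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]  # fails when L[j][j] == 0
    return L


def sample_with_scale_tril(mean, scale_tril, rng):
    """Draw mean + scale_tril @ z with z ~ N(0, I); no factorization needed."""
    z = [rng.gauss(0.0, 1.0) for _ in mean]
    return [m + sum(l * zj for l, zj in zip(row, z)) for m, row in zip(mean, scale_tril)]


print(cholesky([[4.0, 2.0], [2.0, 3.0]]))  # [[2.0, 0.0], [1.0, 1.4142135623730951]]

# A deterministic component: zero covariance, equivalently zero scale_tril.
zero = [[0.0, 0.0], [0.0, 0.0]]
try:
    cholesky(zero)
    zero_cov_failed = False
except ZeroDivisionError:
    zero_cov_failed = True
print(zero_cov_failed)  # True
print(sample_with_scale_tril([1.0, 2.0], zero, random.Random(0)))  # [1.0, 2.0]
```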

Tensor arguments may be specified with partial batch shape, i.e., with shape prefix [num_timesteps, Bk, ..., BN] for k > 1. They will be internally reshaped and broadcast to the full batch shape prefix [num_timesteps, B1, ..., BN].
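One way the reshaping described above could work, sketched on shapes alone: insert size-1 axes between the time axis and the batch dimensions that were given, then broadcast. The helper names are hypothetical and this is not necessarily how TFP does it. The explicit step matters because NumPy-style broadcasting prepends 1s to the shorter shape and aligns from the right, which for a time-leading argument would pair num_timesteps with B1.

```python
def expand_time_major_shape(arg_shape, full_batch_rank, event_rank):
    """Map a partial shape [T, Bk, ..., BN, *event] to [T, 1, ..., 1, Bk, ..., BN, *event]."""
    num_timesteps = arg_shape[0]
    batch = arg_shape[1:len(arg_shape) - event_rank]
    event = arg_shape[len(arg_shape) - event_rank:]
    num_missing = full_batch_rank - len(batch)
    if num_missing < 0:
        raise ValueError("argument has more batch dims than the full batch shape")
    return [num_timesteps] + [1] * num_missing + list(batch) + list(event)


def broadcast_shapes(a, b):
    """NumPy-style broadcast of two shapes of equal rank."""
    if len(a) != len(b):
        raise ValueError("shapes must have equal rank")
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible dims {da} and {db}")
        out.append(db if da == 1 else da)
    return out


full = [5, 3, 4, 2]  # [num_timesteps, B1, B2, latent_size]
# transition_mean supplied with only B2: shape [num_timesteps, B2, latent_size].
expanded = expand_time_major_shape([5, 4, 2], full_batch_rank=2, event_rank=1)
print(expanded)                          # [5, 1, 4, 2]
print(broadcast_shapes(expanded, full))  # [5, 3, 4, 2]

# Right-aligned padding would instead give [1, 5, 4, 2], pairing 5 with B1 = 3.
try:
    broadcast_shapes([1, 5, 4, 2], full)
    naive_failed = False
except ValueError:
    naive_failed = True
print(naive_failed)  # True
```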