

Use particle filtering to sample from the posterior over trajectories.

Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.

The transition_fn and (optional) proposal_fn args have signature next_state_dist = fn(step, state), where step is an int Tensor index of the current time step (beginning at zero) and state represents the latent state at that step. The return value is a tfd.Distribution instance over the state at step + 1.

Similarly, the observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at that step.

Args:
observations: a (structure of) Tensor(s), each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]) with optional batch dimensions b1, ..., bN.
initial_state_prior: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN].
transition_fn: callable returning a (joint) distribution over the next latent state.
observation_fn: callable returning a (joint) distribution over the current observation.
num_particles: int Tensor number of particles.
initial_state_proposal: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None.
proposal_fn: callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None.
resample_fn: Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic.
resample_criterion_fn: optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing log (unnormalized) weights for all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default behavior resamples particles when the current effective sample size falls below half the total number of particles. Default value: tfp.experimental.mcmc.ess_below_threshold.
unbiased_gradients: If True, use the stop-gradient resampling trick of Scibior, Masrani, and Wood [2] to correct for gradient bias introduced by the discrete resampling step. This will generally increase the variance of stochastic gradients. Default value: True.
rejuvenation_kernel_fn: optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] | y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None.
num_transitions_per_observation: scalar Tensor positive int number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: None.
seed: PRNG seed; see tfp.random.sanitize_seed for details.
name: Python str name for ops created by this method. Default value: None (i.e., 'infer_trajectories').
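
The default resampling behavior described above (resample only when the effective sample size falls below half the number of particles, then draw ancestor indices systematically) can be sketched in plain Python. This is an illustrative sketch under our own assumptions, not TFP's implementation: the helper names effective_sample_size, ess_below_half, and systematic_resample are hypothetical, and the weights are a plain list of log weights for a single batch member.

```python
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def effective_sample_size(log_weights):
    """ESS = (sum w)^2 / sum(w^2), computed in log space from log weights."""
    return math.exp(2.0 * logsumexp(log_weights)
                    - logsumexp([2.0 * lw for lw in log_weights]))


def ess_below_half(log_weights):
    """Hypothetical criterion: resample when ESS < num_particles / 2."""
    return effective_sample_size(log_weights) < 0.5 * len(log_weights)


def systematic_resample(log_weights, rng):
    """Return ancestor indices using one shared uniform offset.

    Pointers are placed at u0 + i / n for i = 0..n-1, with u0 ~ U[0, 1/n).
    """
    n = len(log_weights)
    total = logsumexp(log_weights)
    probs = [math.exp(lw - total) for lw in log_weights]  # normalized weights
    u0 = rng.random() / n
    indices, cumulative, j = [], probs[0], 0
    for i in range(n):
        u = u0 + i / n
        # Advance to the first particle whose cumulative weight covers u;
        # the j < n - 1 guard protects against floating-point round-off.
        while u > cumulative and j < n - 1:
            j += 1
            cumulative += probs[j]
        indices.append(j)
    return indices
```

Because systematic resampling uses a single shared uniform draw, a particle with normalized weight w is copied either floor(N * w) or ceil(N * w) times, which often gives lower variance than independent (multinomial) resampling.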

Returns:
trajectories: a (structure of) Tensor(s) matching the latent state, each of shape concat([[num_timesteps, num_particles, b1, ..., bN], event_shape]), representing unbiased samples from the posterior distribution p(latent_states | observations).
incremental_log_marginal_likelihoods: float Tensor of shape [num_observation_steps, b1, ..., bN], giving the natural logarithm of an unbiased estimate of p(observations[t] | observations[:t]) at each timestep t. Note that (by Jensen's inequality) this is smaller in expectation than the true log p(observations[t] | observations[:t]).
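
The two return values are indexed differently: trajectories has one entry per latent timestep, while incremental_log_marginal_likelihoods has one entry per observation. Assuming observations fall exactly on steps 0, k, 2k, ..., with k = num_transitions_per_observation and the final step observed (as the argument description states), the counts are related as sketched below. The sketch also shows, under the standard particle-filter construction (an assumption about the internals, not TFP code), that each incremental term is the log of a weighted average of the particles' incremental weights, and that summing the terms gives a log estimate of the full marginal likelihood p(observations). All function names are hypothetical.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def num_timesteps(num_observation_steps, num_transitions_per_observation=1):
    """Latent steps when steps 0, k, 2k, ..., (M - 1) * k are the observed ones."""
    return 1 + num_transitions_per_observation * (num_observation_steps - 1)


def incremental_log_ml(prev_log_weights, incremental_log_weights):
    """log sum_i W[i] * w[i], where W are the previous step's normalized
    weights and w are each particle's new (unnormalized) incremental weight."""
    log_norm = logsumexp(prev_log_weights)
    return logsumexp([lw - log_norm + inc
                      for lw, inc in zip(prev_log_weights, incremental_log_weights)])


def total_log_ml(incremental_log_mls):
    """log p_hat(observations) = sum over t of log p_hat(obs[t] | obs[:t])."""
    return sum(incremental_log_mls)
```

Right after a resampling step the previous weights are uniform, so incremental_log_ml reduces to the log of the mean incremental weight.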


Tracking unknown position and velocity: Let's consider tracking an object moving in a one-dimensional space. We'll define a dynamical system by specifying an initial_state_prior, a transition_fn, and an observation_fn.

The structure of the latent state space is determined by the prior distribution. Here, we'll define a state space that includes the object's current position and velocity:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
initial_state_prior = tfd.JointDistributionNamed({
    'position': tfd.Normal(loc=0., scale=1.),
    'velocity': tfd.Normal(loc=0., scale=0.1)})

The transition_fn specifies the evolution of the system. It should return a distribution over latent states with the same structure as the prior. Here, we'll assume that the position evolves according to the velocity plus a small random perturbation, and that the velocity itself changes slowly by random drift:

def transition_fn(_, previous_state):
  return tfd.JointDistributionNamed({
      'position': tfd.Normal(
          loc=previous_state['position'] + previous_state['velocity'],
          scale=0.1),
      'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})

The observation_fn specifies the process by which the system is observed at each time step. Let's suppose we observe only a noisy version of the current position.

def observation_fn(_, state):
  return tfd.Normal(loc=state['position'], scale=0.1)
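
Written as equations, the generative model defined by initial_state_prior, transition_fn, and observation_fn is as follows. Here x_t is the position, v_t the velocity, y_t the observation at step t, and σ_x the scale of the position transition; these symbols are our own notation, not part of the API.

```latex
\begin{aligned}
x_0 &\sim \mathcal{N}(0,\ 1^2), & v_0 &\sim \mathcal{N}(0,\ 0.1^2), \\
x_{t+1} \mid x_t, v_t &\sim \mathcal{N}(x_t + v_t,\ \sigma_x^2), & v_{t+1} \mid v_t &\sim \mathcal{N}(v_t,\ 0.01^2), \\
y_t \mid x_t &\sim \mathcal{N}(x_t,\ 0.1^2).
\end{aligned}
```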

Now let's track our object. Suppose we've been given observations corresponding to an initial position of 0.4 and constant velocity of 0.01:

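# Generate simulated observations.
observed_positions = tfd.Normal(
    loc=0.4 + 0.01 * tf.range(40, dtype=tf.float32), scale=0.1).sample()

# Run particle filtering to sample plausible trajectories.
(trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
 lps) = tfp.experimental.mcmc.infer_trajectories(
     observations=observed_positions,
     initial_state_prior=initial_state_prior,
     transition_fn=transition_fn,
     observation_fn=observation_fn,
     num_particles=1000)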

For all i, trajectories['position'][:, i] is a sample from the posterior over position sequences, given the observations: p(state[0:T] | observations[0:T]). Often, the sampled trajectories will be highly redundant in their earlier timesteps, because most of the initial particles have been discarded through resampling (this problem is known as 'particle degeneracy'; see section 3.5 of Doucet and Johansen [1]). In such cases it may be useful to also consider the series of filtering distributions p(state[t] | observations[:t+1]), in which each latent state is inferred conditioned only on observations up to and including that point in time; these may be computed using tfp.experimental.mcmc.particle_filter.
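
To see why early timesteps become redundant, here is a minimal plain-Python bootstrap particle filter for the tracking model above that carries whole paths through resampling. It is a sketch under stated assumptions, not TFP's algorithm: it resamples at every step, uses multinomial resampling (random.choices) instead of TFP's default systematic resampling, represents the state as a (position, velocity) tuple, and uses a hypothetical pos_scale parameter (default 0.1) for the position transition scale.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def bootstrap_smoother(observations, num_particles, rng, pos_scale=0.1):
    """Bootstrap particle filter that carries whole paths through resampling.

    Returns num_particles position trajectories, one per surviving path.
    """
    # Step 0: draw initial particles from the prior (proposal == prior).
    paths = [[(rng.gauss(0.0, 1.0), rng.gauss(0.0, 0.1))]
             for _ in range(num_particles)]
    for t, y in enumerate(observations):
        if t > 0:
            # Propose the next state from the transition model (bootstrap proposal).
            for path in paths:
                x, v = path[-1]
                path.append((rng.gauss(x + v, pos_scale), rng.gauss(v, 0.01)))
        # Weight each particle by the observation likelihood N(y; position, 0.1).
        log_w = [normal_logpdf(y, path[-1][0], 0.1) for path in paths]
        m = max(log_w)
        weights = [math.exp(lw - m) for lw in log_w]
        # Multinomial resampling of whole paths: copies share their history.
        ancestors = rng.choices(range(num_particles), weights=weights,
                                k=num_particles)
        paths = [list(paths[a]) for a in ancestors]
    return [[x for x, _ in path] for path in paths]


if __name__ == "__main__":
    # Simulated data: initial position 0.4, velocity 0.01, noise scale 0.1.
    data_rng = random.Random(0)
    obs = [0.4 + 0.01 * t + data_rng.gauss(0.0, 0.1) for t in range(40)]
    paths = bootstrap_smoother(obs, 200, random.Random(1))
    print("distinct starting positions:", len({p[0] for p in paths}))
    print("distinct final positions:", len({p[-1] for p in paths}))
```

Each resampling step can only merge lineages, never split them, so the number of distinct starting positions among the final paths can never exceed the number of distinct final positions, and after many steps it is typically far smaller than the number of particles.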


[1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle filtering and smoothing: Fifteen years later. Handbook of Nonlinear Filtering, 12(656-704), 2009.
[2] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle Filtering without Modifying the Forward Pass. arXiv preprint arXiv:2106.10314, 2021.