tfp.experimental.mcmc.particle_filter

Samples a series of particles representing filtered latent states.

The particle filter samples from the sequence of "filtering" distributions p(state[t] | observations[:t+1]) over latent states: at each point in time, this is the distribution conditioned on all observations up to and including that time. Because particles may be resampled, a particle at time t may be different from the particle with the same index at time t + 1. To reconstruct trajectories by tracing back through the resampling process, see tfp.experimental.mcmc.reconstruct_trajectories.
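The ancestry bookkeeping behind trajectory reconstruction can be sketched in plain Python. This is not the TFP implementation. It assumes that particles[t][i] holds particle i at step t and that parent_indices[t][i] holds the index, among the step t - 1 particles, of that particle's parent (the step-0 entry is unused). The function name reconstruct_trajectories_sketch is hypothetical.

```python
def reconstruct_trajectories_sketch(particles, parent_indices):
    """Trace each final particle back through the resampling ancestry.

    particles[t][i]: value of particle i at step t.
    parent_indices[t][i]: index (into step t - 1) of the parent of
        particle i at step t; the entry for t == 0 is ignored.
    Returns trajectories[t][i]: the ancestor at step t of final particle i.
    """
    num_steps = len(particles)
    num_particles = len(particles[0])
    indices = list(range(num_particles))  # start from the final particles
    trajectories = [None] * num_steps
    for t in reversed(range(num_steps)):
        # Gather the step-t ancestors of the final particles.
        trajectories[t] = [particles[t][j] for j in indices]
        if t > 0:
            # Step back one generation through the resampling indices.
            indices = [parent_indices[t][j] for j in indices]
    return trajectories


particles = [["a0", "a1", "a2"], ["b0", "b1", "b2"], ["c0", "c1", "c2"]]
parent_indices = [[0, 1, 2], [2, 2, 0], [1, 0, 0]]
print(reconstruct_trajectories_sketch(particles, parent_indices))
# [['a2', 'a2', 'a2'], ['b1', 'b0', 'b0'], ['c0', 'c1', 'c2']]
```

In this toy history all three final particles descend from a2, which is exactly the path degeneracy that makes index-by-index reading of the raw particles misleading.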

Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.

The transition_fn and proposal_fn arguments, if specified, have signature next_state_dist = fn(step, state), where step is an int Tensor index of the current time step (beginning at zero), and state represents the latent state at that step. The return value is a tfd.Distribution instance over the state at time step + 1.

Similarly, the observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at time step.
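To make the fn(step, state) convention concrete, here is a minimal bootstrap filter in plain Python. It is a sketch under stated assumptions, not how particle_filter is implemented: the Normal class is a hypothetical stand-in for a tfd.Distribution (its sample method takes a random.Random rather than a seed), the model is a scalar Gaussian random walk, proposals come from the dynamics as when proposal_fn is None, and particles are resampled multinomially at every step.

```python
import math
import random


class Normal:
    """Hypothetical scalar Gaussian standing in for a tfd.Distribution."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


def transition_fn(step, state):
    # Distribution over the state at step + 1: a Gaussian random walk.
    return Normal(loc=state, scale=1.0)


def observation_fn(step, state):
    # Distribution over the value observed at `step`, given the state.
    return Normal(loc=state, scale=0.5)


def bootstrap_filter(observations, prior, num_particles, seed=0):
    """Return the weighted filtering mean of the state at each step."""
    rng = random.Random(seed)
    particles = [prior.sample(rng) for _ in range(num_particles)]
    means = []
    for step, y in enumerate(observations):
        if step > 0:
            # State at `step` comes from transition_fn(step - 1, ...).
            particles = [transition_fn(step - 1, x).sample(rng) for x in particles]
        # Weight each particle by the likelihood of this step's observation.
        log_w = [observation_fn(step, x).log_prob(y) for x in particles]
        m = max(log_w)
        w = [math.exp(lw - m) for lw in log_w]
        total = sum(w)
        w = [wi / total for wi in w]
        means.append(sum(wi * x for wi, x in zip(w, particles)))
        # Resample at every step (multinomial, for simplicity).
        particles = rng.choices(particles, weights=w, k=num_particles)
    return means


print(bootstrap_filter([0.0, 1.0, 2.0, 3.0], Normal(0.0, 1.0), num_particles=2000))
```

For this linear-Gaussian model the exact (Kalman) filtering mean at the last step is about 2.79, so the printed estimates should land close to that value.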

Args

observations: a (structure of) Tensors, each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]), with optional batch dimensions b1, ..., bN.
initial_state_prior: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN].
transition_fn: callable returning a (joint) distribution over the next latent state.
observation_fn: callable returning a (joint) distribution over the current observation.
num_particles: int Tensor number of particles.
initial_state_proposal: a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None.
proposal_fn: callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None.
resample_fn: Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic.
resample_criterion_fn: optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing log (unnormalized) weights for all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default criterion resamples particles when the current effective sample size falls below half the total number of particles. Default value: tfp.experimental.mcmc.ess_below_threshold.
unbiased_gradients: if True, use the stop-gradient resampling trick of Scibior, Masrani, and Wood [1] to correct for the gradient bias introduced by the discrete resampling step. This will generally increase the variance of stochastic gradients. Default value: True.
rejuvenation_kernel_fn: optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] | y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None.
num_transitions_per_observation: scalar positive int Tensor giving the number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: None.
trace_fn: Python callable defining the values to be traced at each step, with signature traced_values = trace_fn(weighted_particles, results), in which the first argument is an instance of tfp.experimental.mcmc.WeightedParticles, the second is an instance of the SequentialMonteCarloResults tuple, and the return value is a structure of Tensors. Default value: lambda s, r: (s.particles, s.log_weights, r.parent_indices, r.incremental_log_marginal_likelihood).
trace_criterion_fn: optional Python callable with signature trace_this_step = trace_criterion_fn(weighted_particles, results), taking the same arguments as trace_fn and returning a boolean Tensor. If None, only values from the final step are returned. Default value: lambda *_: True (trace every step).
static_trace_allocation_size: optional Python int size of the trace to allocate statically. This should be an upper bound on the number of steps traced and is used only when the length cannot be statically inferred (for example, if a trace_criterion_fn is specified). It is primarily intended for contexts where static shapes are required, such as XLA-compiled code. Default value: None.
parallel_iterations: passed to the internal tf.while_loop. Default value: 1.
seed: PRNG seed; see tfp.random.sanitize_seed for details.
name: Python str name for ops created by this method. Default value: None (i.e., 'particle_filter').
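The default resampling behavior combines two pieces: a criterion that checks whether the effective sample size (ESS) of the normalized weights, 1 / sum(w_i^2), has dropped below half the number of particles, and a systematic resampler, which draws a single uniform offset and walks N evenly spaced points through the cumulative weights. The plain-Python sketch below illustrates both ideas. The function names and signatures are hypothetical and deliberately differ from tfp.experimental.mcmc.ess_below_threshold and tfp.experimental.mcmc.resample_systematic.

```python
import math
import random


def _normalize(log_weights):
    # Exponentiate stably and normalize so the weights sum to one.
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return [wi / total for wi in w]


def ess_below_half(log_weights):
    """Return True when ESS = 1 / sum(w_i^2) is below half the particle count."""
    w = _normalize(log_weights)
    ess = 1.0 / sum(wi * wi for wi in w)
    return ess < 0.5 * len(w)


def systematic_resample(log_weights, rng):
    """Return resampled parent indices using one shared uniform offset."""
    w = _normalize(log_weights)
    n = len(w)
    u0 = rng.random() / n  # single random offset in [0, 1/n)
    indices, cumulative, j = [], w[0], 0
    for i in range(n):
        u = u0 + i / n  # evenly spaced points u0, u0 + 1/n, ...
        # Advance to the first particle whose cumulative weight covers u.
        while u > cumulative and j < n - 1:
            j += 1
            cumulative += w[j]
        indices.append(j)
    return indices


rng = random.Random(0)
log_w = [0.0, 0.0, float("-inf"), float("-inf")]  # two live particles
if ess_below_half(log_w):
    print(systematic_resample(log_w, rng))
```

Because every point shares one offset, a particle with normalized weight w_i is copied either floor(N * w_i) or ceil(N * w_i) times, which typically gives lower resampling variance than drawing N independent indices.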
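With num_transitions_per_observation = k, observations fall on every k-th step. The sketch below maps each timestep to the index of the observation it sees. It assumes that num_timesteps = 1 + k * (num_observation_steps - 1), a relation chosen here to be consistent with steps 0 and num_timesteps - 1 always being observed; observation_schedule is a hypothetical helper, not a TFP function.

```python
def observation_schedule(num_observation_steps, num_transitions_per_observation):
    """Map each timestep to its observation index, or None if unobserved.

    Assumes num_timesteps = 1 + k * (num_observation_steps - 1).
    """
    k = num_transitions_per_observation
    num_timesteps = 1 + k * (num_observation_steps - 1)
    # A step is observed exactly when it is a multiple of k.
    return [step // k if step % k == 0 else None for step in range(num_timesteps)]


print(observation_schedule(3, 2))  # [0, None, 1, None, 2]
```

With k = 1 every step is observed, and the schedule reduces to [0, 1, ..., num_observation_steps - 1].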

Returns

traced_results: a structure of Tensors as returned by trace_fn. If trace_criterion_fn is None, this is computed from the final step only; otherwise, each Tensor has initial dimension num_steps_traced and stacks the traced results across all traced steps.
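The difference between the two return shapes can be illustrated with a hypothetical plain-Python loop. run_with_trace is not part of TFP, and its trace_fn and trace_criterion_fn take simplified arguments (a state, or a step and a state) rather than weighted_particles and results.

```python
def run_with_trace(num_steps, step_fn, init_state, trace_fn, trace_criterion_fn=None):
    """Hypothetical loop mimicking the documented tracing semantics."""
    state = init_state
    traced = []
    for step in range(num_steps):
        state = step_fn(step, state)
        if trace_criterion_fn is not None and trace_criterion_fn(step, state):
            traced.append(trace_fn(state))  # one entry per traced step
    if trace_criterion_fn is None:
        return trace_fn(state)  # final step only, no leading step dimension
    return traced  # length equals the number of steps traced


step_fn = lambda step, s: s + step
print(run_with_trace(5, step_fn, 0, lambda s: s))                           # 10
print(run_with_trace(5, step_fn, 0, lambda s: s, lambda t, s: True))        # [0, 1, 3, 6, 10]
print(run_with_trace(5, step_fn, 0, lambda s: s, lambda t, s: t % 2 == 0))  # [0, 3, 10]
```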

References

[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle Filtering without Modifying the Forward Pass. arXiv preprint arXiv:2106.10314, 2021. https://arxiv.org/abs/2106.10314