Samples a series of particles representing filtered latent states.
tfp.experimental.mcmc.particle_filter(
observations,
initial_state_prior,
transition_fn,
observation_fn,
num_particles,
initial_state_proposal=None,
proposal_fn=None,
resample_fn=tfp.experimental.mcmc.resample_systematic,
resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold,
unbiased_gradients=True,
rejuvenation_kernel_fn=None,
num_transitions_per_observation=1,
trace_fn=_default_trace_fn,
trace_criterion_fn=_always_trace,
static_trace_allocation_size=None,
parallel_iterations=1,
seed=None,
name=None
)
The particle filter samples from the sequence of "filtering" distributions p(state[t] | observations[:t + 1]) over latent states: at each point in time, this is the distribution conditioned on all observations up to and including that time. Because particles may be resampled, a particle at time t may be different from the particle with the same index at time t + 1. To reconstruct trajectories by tracing back through the resampling process, see tfp.experimental.mcmc.reconstruct_trajectories.
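Tracing back through resampling can be sketched in plain Python. This is a hypothetical helper, trace_ancestry, not the TFP function: lists stand in for Tensors, and it assumes parent_indices[t][i] holds the index at step t - 1 of the parent of particle i at step t (with step 0 having no parent).

```python
from typing import Any, List, Sequence


def trace_ancestry(particles: Sequence[Sequence[Any]],
                   parent_indices: Sequence[Sequence[int]]) -> List[List[Any]]:
    """Hypothetical helper: trajectories[i][t] is the state at time t of the
    lineage that ends at particle i at the final step.

    particles[t][i] is particle i at time t; parent_indices[t][i] is assumed to
    be the index (at time t - 1) of its parent. parent_indices[0] is unused.
    """
    num_steps = len(particles)
    num_particles = len(particles[-1])
    trajectories: List[List[Any]] = []
    for i in range(num_particles):
        lineage: List[Any] = [None] * num_steps
        idx = i
        # Walk backward from the final step, following parent pointers.
        for t in range(num_steps - 1, -1, -1):
            lineage[t] = particles[t][idx]
            if t > 0:
                idx = parent_indices[t][idx]
        trajectories.append(lineage)
    return trajectories


particles = [["a0", "b0", "c0"], ["a1", "b1", "c1"], ["a2", "b2", "c2"]]
parents = [[0, 1, 2], [0, 0, 2], [1, 2, 2]]
trajectories = trace_ancestry(particles, parents)
# trajectories[0] == ["a0", "b1", "a2"]
```

Note that the particle with index 0 at the final step descends from index 1 at step 1, which is why its lineage differs from the row of particles sharing its index.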
Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.
The transition_fn and proposal_fn args, if specified, have signature next_state_dist = fn(step, state), where step is an int Tensor giving the index of the current time step (beginning at zero), and state is the latent state at that step. The return value is a tfd.Distribution instance over the state at time step + 1.
Similarly, observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at that same time step.
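How these signatures fit together in the filtering loop can be sketched with a plain-Python bootstrap particle filter. This is an illustration under stated assumptions, not the TFP implementation: Normal is a hypothetical stand-in for a scalar tfd.Normal, the proposal is the dynamics model (the proposal_fn=None case), and the sketch resamples multinomially at every step instead of using the default systematic resampling with an ESS criterion.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


class Normal:
    """Hypothetical stand-in for a scalar tfd.Normal (sample and log_prob only)."""

    def __init__(self, loc: float, scale: float):
        self.loc = loc
        self.scale = scale

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


def transition_fn(step: int, state: float) -> Normal:
    # Distribution over the state at time step + 1 (a Gaussian random walk).
    return Normal(state, 0.5)


def observation_fn(step: int, state: float) -> Normal:
    # Distribution over the value observed at time step.
    return Normal(state, 1.0)


def bootstrap_particle_filter(
        observations: Sequence[float],
        initial_state_prior: Normal,
        transition_fn: Callable[[int, float], Normal],
        observation_fn: Callable[[int, float], Normal],
        num_particles: int,
        seed: int = 0) -> Tuple[List[float], float]:
    """Returns (filtered means, estimated log marginal likelihood)."""
    rng = random.Random(seed)
    particles = [initial_state_prior.sample(rng) for _ in range(num_particles)]
    # Start from uniform, normalized log weights.
    log_weights = [-math.log(num_particles)] * num_particles
    filtered_means: List[float] = []
    log_marginal = 0.0
    for step, y in enumerate(observations):
        if step > 0:
            # transition_fn(step - 1, x) is a distribution over the state at step.
            particles = [transition_fn(step - 1, x).sample(rng) for x in particles]
        # With proposal == transition, the weight update is the likelihood of y.
        log_weights = [lw + observation_fn(step, x).log_prob(y)
                       for lw, x in zip(log_weights, particles)]
        # Incoming weights were normalized, so this log-sum-exp estimates the
        # incremental log marginal likelihood log p(y[step] | y[:step]).
        m = max(log_weights)
        log_norm = m + math.log(sum(math.exp(lw - m) for lw in log_weights))
        log_marginal += log_norm
        weights = [math.exp(lw - log_norm) for lw in log_weights]
        # The weighted particles approximate the filtering distribution at step.
        filtered_means.append(sum(w * x for w, x in zip(weights, particles)))
        # Multinomial resampling at every step, then reset to uniform weights.
        particles = rng.choices(particles, weights=weights, k=num_particles)
        log_weights = [-math.log(num_particles)] * num_particles
    return filtered_means, log_marginal


means, log_ml = bootstrap_particle_filter(
    [0.0, 0.5, 1.0, 1.5], Normal(0.0, 1.0), transition_fn, observation_fn,
    num_particles=500, seed=1)
```

Keeping the weights normalized between steps is what lets the log-sum-exp of the reweighted particles double as the per-step marginal likelihood increment.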
Args | Description |
---|---|
observations | a (structure of) Tensors, each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]), with optional batch dimensions b1, ..., bN. |
initial_state_prior | a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. |
transition_fn | callable returning a (joint) distribution over the next latent state. |
observation_fn | callable returning a (joint) distribution over the current observation. |
num_particles | int Tensor giving the number of particles. |
initial_state_proposal | a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None. |
proposal_fn | callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None. |
resample_fn | Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic. |
resample_criterion_fn | optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing log (unnormalized) weights for all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default behavior resamples particles when the current effective sample size falls below half the total number of particles. Default value: tfp.experimental.mcmc.ess_below_threshold. |
unbiased_gradients | If True, use the stop-gradient resampling trick of Scibior, Masrani, and Wood [1] to correct for gradient bias introduced by the discrete resampling step. This will generally increase the variance of stochastic gradients. Default value: True. |
rejuvenation_kernel_fn | optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] \| y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None. |
num_transitions_per_observation | scalar positive int Tensor giving the number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: 1. |
trace_fn | Python callable defining the values to be traced at each step, with signature traced_values = trace_fn(weighted_particles, results), in which the first argument is an instance of tfp.experimental.mcmc.WeightedParticles, the second is an instance of the SequentialMonteCarloResults tuple, and the return value is a structure of Tensors. Default value: lambda s, r: (s.particles, s.log_weights, r.parent_indices, r.incremental_log_marginal_likelihood). |
trace_criterion_fn | optional Python callable with signature trace_this_step = trace_criterion_fn(weighted_particles, results), taking the same arguments as trace_fn and returning a boolean Tensor. If None, only values from the final step are returned. Default value: lambda *_: True (trace every step). |
static_trace_allocation_size | Optional Python int size of trace to allocate statically. This should be an upper bound on the number of steps traced and is used only when the length cannot be statically inferred (for example, if a trace_criterion_fn is specified). It is primarily intended for contexts where static shapes are required, such as in XLA-compiled code. Default value: None. |
parallel_iterations | Passed to the internal tf.while_loop. Default value: 1. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. Default value: None. |
name | Python str name for ops created by this method. Default value: None (i.e., 'particle_filter'). |
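
The ideas behind the default resample_fn and resample_criterion_fn can be illustrated with plain-Python stand-ins. The names systematic_resample and ess_below_half are hypothetical; they take Python lists of log weights rather than batched Tensors, and systematic_resample takes a random.Random instead of the documented (log_probs, event_size, sample_shape, seed) signature. The sketch shows systematic resampling (one shared uniform offset and N evenly spaced positions) and resampling only when the effective sample size falls below half the number of particles.

```python
import math
import random
from typing import List, Sequence


def _logsumexp(xs: Sequence[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ess_below_half(log_weights: Sequence[float]) -> bool:
    """Hypothetical criterion: True when ESS < num_particles / 2.

    ESS = (sum w)^2 / sum(w^2) is invariant to rescaling the weights, so
    unnormalized log weights can be passed directly.
    """
    log_ess = 2.0 * _logsumexp(log_weights) - _logsumexp([2.0 * lw for lw in log_weights])
    return log_ess < math.log(0.5 * len(log_weights))


def systematic_resample(log_weights: Sequence[float], rng: random.Random) -> List[int]:
    """Hypothetical systematic resampler returning one parent index per particle."""
    n = len(log_weights)
    log_total = _logsumexp(log_weights)
    # Cumulative normalized weights.
    cumulative: List[float] = []
    acc = 0.0
    for lw in log_weights:
        acc += math.exp(lw - log_total)
        cumulative.append(acc)
    cumulative[-1] = 1.0  # Guard against floating-point round-off.
    # One shared uniform offset u in [0, 1/n), then positions u + k / n.
    u = rng.random() / n
    indices: List[int] = []
    j = 0
    for k in range(n):
        position = u + k / n
        # Advance to the first particle whose cumulative weight reaches position.
        while cumulative[j] < position:
            j += 1
        indices.append(j)
    return indices


rng = random.Random(0)
log_weights = [math.log(w) for w in [0.05, 0.05, 0.1, 0.8]]
do_resample = ess_below_half(log_weights)  # ESS is about 1.53 < 2
parent_indices = (systematic_resample(log_weights, rng) if do_resample
                  else list(range(len(log_weights))))
```

Systematic resampling selects each particle either floor(N * w) or ceil(N * w) times, which typically yields lower resampling noise than drawing the N indices independently.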
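
The step arithmetic implied by num_transitions_per_observation can be sketched as follows. The helper observation_schedule is hypothetical and assumes observations fall on every k-th step starting at step 0, which is consistent with the statement that steps 0 and num_timesteps - 1 are always observed.

```python
from typing import List, Tuple


def observation_schedule(num_observation_steps: int,
                         num_transitions_per_observation: int) -> Tuple[int, List[int]]:
    """Hypothetical helper: total timesteps and which steps carry an observation.

    Assumes observations land on every k-th step starting from step 0, so both
    step 0 and step num_timesteps - 1 are observed.
    """
    k = num_transitions_per_observation
    if k < 1:
        raise ValueError("num_transitions_per_observation must be positive")
    num_timesteps = 1 + k * (num_observation_steps - 1)
    observed_steps = list(range(0, num_timesteps, k))
    return num_timesteps, observed_steps


num_timesteps, observed_steps = observation_schedule(
    num_observation_steps=4, num_transitions_per_observation=3)
# num_timesteps == 10, observed_steps == [0, 3, 6, 9]
```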
References
[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle Filtering without Modifying the Forward Pass. arXiv preprint arXiv:2106.10314, 2021. https://arxiv.org/abs/2106.10314