Use particle filtering to sample from the posterior over trajectories.
```python
tfp.experimental.mcmc.infer_trajectories(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    resample_fn=tfp.experimental.mcmc.resample_systematic,
    resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold,
    unbiased_gradients=True,
    rejuvenation_kernel_fn=None,
    num_transitions_per_observation=1,
    seed=None,
    name=None
)
```
Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.

The transition_fn and proposal_fn args, if specified, have signature next_state_dist = fn(step, state), where step is an int Tensor index of the current time step (beginning at zero) and state represents the latent state at that step. The return value is a tfd.Distribution instance over the state at time step + 1.

Similarly, the observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at time step.
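To make the (step, state) convention concrete, here is a standard-library-only sketch of a bootstrap particle filter for a scalar latent state. It is not the TFP implementation: the `Normal` class and `bootstrap_filter` function are hypothetical helpers, resampling here is multinomial at every step (rather than systematic and ESS-triggered, as in the infer_trajectories defaults), and there is no batching, custom proposal, or rejuvenation. It shows transition_fn being called on the state at step - 1 to propose the state at step, observation_fn weighting each particle against the observation at step, and ancestor reindexing producing whole trajectories.

```python
import math
import random


class Normal:
    """Hypothetical minimal scalar Gaussian standing in for a tfd.Distribution."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def bootstrap_filter(observations, initial_state_prior, transition_fn,
                     observation_fn, num_particles, seed=0):
    """Sketch of a bootstrap particle filter that also tracks trajectories."""
    rng = random.Random(seed)
    particles = [initial_state_prior.sample(rng) for _ in range(num_particles)]
    history = []          # history[t][i]: step-t state on particle i's lineage
    incremental_lps = []  # log of the estimate of p(y[t] | y[:t])
    for step, y in enumerate(observations):
        if step > 0:
            # transition_fn(step - 1, state) is a distribution over the state
            # at `step`, given the state at `step - 1`.
            particles = [transition_fn(step - 1, x).sample(rng) for x in particles]
        log_weights = [observation_fn(step, x).log_prob(y) for x in particles]
        # Particles are equally weighted after the previous resampling, so the
        # mean weight estimates p(y[t] | y[:t]).
        log_norm = logsumexp(log_weights)
        incremental_lps.append(log_norm - math.log(num_particles))
        # Multinomial resampling at every step (a simplification).
        probs = [math.exp(lw - log_norm) for lw in log_weights]
        idx = rng.choices(range(num_particles), weights=probs, k=num_particles)
        particles = [particles[i] for i in idx]
        # Reindex earlier steps so each column follows a surviving lineage.
        history = [[row[i] for i in idx] for row in history]
        history.append(particles)
    return history, incremental_lps
```

For example, bootstrap_filter(obs, Normal(0.0, 1.0), lambda step, x: Normal(x, 0.1), lambda step, x: Normal(x, 0.1), num_particles=200) returns one row of particle states per observation step plus one incremental log-likelihood term per step, mirroring the two return values of infer_trajectories.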
| Args | |
|---|---|
| observations | A (structure of) Tensors, each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]) with optional batch dimensions b1, ..., bN. |
| initial_state_prior | A (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. |
| transition_fn | Callable returning a (joint) distribution over the next latent state. |
| observation_fn | Callable returning a (joint) distribution over the current observation. |
| num_particles | int Tensor number of particles. |
| initial_state_proposal | A (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None. |
| proposal_fn | Callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None. |
| resample_fn | Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic. |
| resample_criterion_fn | Optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing log (unnormalized) weights for all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default behavior resamples particles when the current effective sample size falls below half the total number of particles. Default value: tfp.experimental.mcmc.ess_below_threshold. |
| unbiased_gradients | If True, use the stop-gradient resampling trick of Scibior, Masrani, and Wood [2] to correct for gradient bias introduced by the discrete resampling step. This will generally increase the variance of stochastic gradients. Default value: True. |
| rejuvenation_kernel_fn | Optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] \| y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None. |
| num_transitions_per_observation | Scalar positive int Tensor: the number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: 1. |
| seed | PRNG seed; see tfp.random.sanitize_seed for details. |
| name | Python str name for ops created by this method. Default value: None (i.e., 'infer_trajectories'). |
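The resample_fn and resample_criterion_fn arguments can be illustrated with a plain-Python sketch. The functions below are hypothetical stand-ins, not tfp.experimental.mcmc.resample_systematic or tfp.experimental.mcmc.ess_below_threshold (which operate on batched Tensors and take a seed): systematic_resample draws one uniform offset and places num_samples evenly spaced points on the cumulative weight distribution, and should_resample computes the effective sample size 1 / sum(w_i^2) of the normalized weights and compares it with half the particle count, matching the default behavior described in the table.

```python
import bisect
import itertools
import math
import random


def normalize_log_weights(log_weights):
    """Convert unnormalized log weights to probabilities (log-sum-exp trick)."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return [x / total for x in w]


def systematic_resample(log_weights, num_samples, rng):
    """Return ancestor indices using a single uniform offset."""
    cdf = list(itertools.accumulate(normalize_log_weights(log_weights)))
    u0 = rng.uniform(0.0, 1.0 / num_samples)
    indices = []
    for i in range(num_samples):
        u = u0 + i / num_samples
        # First index whose cumulative weight exceeds u; the clamp guards
        # against round-off pushing u past cdf[-1].
        indices.append(min(bisect.bisect_right(cdf, u), len(cdf) - 1))
    return indices


def effective_sample_size(log_weights):
    probs = normalize_log_weights(log_weights)
    return 1.0 / sum(p * p for p in probs)


def should_resample(log_weights, threshold=0.5):
    """Resample when the ESS falls below `threshold` times the particle count."""
    return effective_sample_size(log_weights) < threshold * len(log_weights)
```

With equal weights every particle is kept exactly once and no resampling is triggered; when one weight dominates, the ESS drops toward 1 and all resampled indices point at the dominant particle.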
| Returns | |
|---|---|
| trajectories | A (structure of) Tensor(s) matching the latent state, each of shape concat([[num_timesteps, num_particles, b1, ..., bN], event_shape]), representing unbiased samples from the posterior distribution p(latent_states \| observations). |
| incremental_log_marginal_likelihoods | float Tensor of shape [num_observation_steps, b1, ..., bN], giving the natural logarithm of an unbiased estimate of p(observations[t] \| observations[:t]) at each timestep t. Note that (by Jensen's inequality) this is smaller in expectation than the true log p(observations[t] \| observations[:t]). |
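The two return values have different leading dimensions when num_transitions_per_observation is greater than 1. Since steps 0 and num_timesteps - 1 are always observed and observations fall every num_transitions_per_observation steps, num_timesteps = 1 + num_transitions_per_observation * (num_observation_steps - 1). The plain-Python sketch below (hypothetical helper names) computes that relation and the observed step indices, and sums the incremental terms to get the log of the particle estimate of the joint marginal likelihood, following the chain rule p(observations) = prod_t p(observations[t] | observations[:t]).

```python
def num_timesteps(num_observation_steps, num_transitions_per_observation=1):
    """Number of latent steps implied by the observation schedule."""
    return 1 + num_transitions_per_observation * (num_observation_steps - 1)


def observed_steps(num_observation_steps, num_transitions_per_observation=1):
    """Latent-step indices at which an observation is available."""
    total = num_timesteps(num_observation_steps, num_transitions_per_observation)
    return list(range(0, total, num_transitions_per_observation))


def total_log_marginal_likelihood(incremental_log_marginal_likelihoods):
    """Chain rule: summing the per-step log estimates gives log p(observations)."""
    return sum(incremental_log_marginal_likelihoods)
```

For instance, 40 observations with num_transitions_per_observation=2 imply 79 latent steps, observed at steps 0, 2, ..., 78, so trajectories would have leading dimension 79 while incremental_log_marginal_likelihoods has leading dimension 40.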
Examples
Tracking unknown position and velocity: Let's consider tracking an object
moving in a one-dimensional space. We'll define a dynamical system
by specifying an initial_state_prior, a transition_fn,
and observation_fn.
The structure of the latent state space is determined by the prior distribution. Here, we'll define a state space that includes the object's current position and velocity:
```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
initial_state_prior = tfd.JointDistributionNamed({
    'position': tfd.Normal(loc=0., scale=1.),
    'velocity': tfd.Normal(loc=0., scale=0.1)})
```
The transition_fn specifies the evolution of the system. It should
return a distribution over latent states of the same structure as the prior.
Here, we'll assume that the position evolves according to the velocity,
with a small random drift, and the velocity also changes slowly, following
a random drift:
```python
def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
```
The observation_fn specifies the process by which the system is observed at each time step. Let's suppose we observe only a noisy version of the current position.

```python
def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
```
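Before running the filter, it can help to see what this model generates. The following sketch uses only the Python standard library (the simulate function is hypothetical, not part of TFP) and draws one latent trajectory and its observations by sampling from the same prior, transition, and observation distributions defined above, with the same locations and scales.

```python
import random


def simulate(num_steps, seed=0):
    """Sample (states, observations) from the position/velocity model."""
    rng = random.Random(seed)
    # Initial state prior: position ~ N(0, 1), velocity ~ N(0, 0.1).
    position = rng.gauss(0.0, 1.0)
    velocity = rng.gauss(0.0, 0.1)
    states, observations = [], []
    for step in range(num_steps):
        if step > 0:
            # Transition: position moves by the previous velocity plus
            # N(0, 0.1) noise; velocity takes an N(0, 0.01) random-walk step.
            position = rng.gauss(position + velocity, 0.1)
            velocity = rng.gauss(velocity, 0.01)
        states.append({'position': position, 'velocity': velocity})
        # Observation: a noisy reading of the current position only.
        observations.append(rng.gauss(position, 0.1))
    return states, observations
```

Because only positions are observed, the velocity has to be inferred from how consecutive observed positions change, which is what the particle filter does when it weights and resamples whole latent states.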
Now let's track our object. Suppose we've been given observations
corresponding to an initial position of 0.4 and constant velocity of 0.01:
```python
# Generate simulated observations: 40 noisy positions starting at 0.4
# and advancing by a constant velocity of 0.01 per step.
observed_positions = tfd.Normal(
    loc=0.4 + 0.01 * tf.range(40, dtype=tf.float32),
    scale=0.1).sample()

# Run particle filtering to sample plausible trajectories.
(trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
 lps) = tfp.experimental.mcmc.infer_trajectories(
    observations=observed_positions,
    initial_state_prior=initial_state_prior,
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    num_particles=1000)
```
For all i, trajectories['position'][:, i] is a sample from the
posterior over position sequences, given the observations:
p(state[0:T] | observations[0:T]). Often, the sampled trajectories
will be highly redundant in their earlier timesteps, because most
of the initial particles have been discarded through resampling
(this problem is known as 'particle degeneracy'; see section 3.5 of
Doucet and Johansen [1]).
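Particle degeneracy can be measured directly if the ancestor indices chosen at each resampling step are recorded. The Returns above include no ancestor indices, so the data layout here is an assumption: ancestors[t][i] is the index at step t - 1 of the parent of particle i at step t. The plain-Python sketch below (hypothetical function name) traces the final particles back through those indices and counts how many distinct particles lie on the surviving lineages at each step. Counts near 1 at early steps mean that the sampled trajectories share a common prefix.

```python
def distinct_ancestors_per_step(ancestors, num_particles):
    """Count distinct lineage members at each step, tracing back from the end.

    ancestors[t][i] is the step-(t - 1) index of the parent of particle i at
    step t, for t >= 1; ancestors[0] is unused.
    """
    num_steps = len(ancestors)
    lineage = list(range(num_particles))  # every particle at the final step
    counts = [0] * num_steps
    for t in range(num_steps - 1, -1, -1):
        counts[t] = len(set(lineage))
        if t > 0:
            # Replace each lineage member by its parent at step t - 1.
            lineage = [ancestors[t][i] for i in lineage]
    return counts
```

Without resampling (each particle is its own parent) the count stays at num_particles for every step; repeated resampling makes the early counts collapse.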
In such cases it may be useful to also consider the series of filtering
distributions p(state[t] | observations[:t+1]), in which each latent state
is inferred conditioned only on observations up to and including that point
in time; these may be computed using tfp.experimental.mcmc.particle_filter.
References
[1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle filtering and smoothing: Fifteen years later. Handbook of Nonlinear Filtering, 12(656-704), 2009. https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

[2] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle Filtering without Modifying the Forward Pass. arXiv preprint arXiv:2106.10314, 2021. https://arxiv.org/abs/2106.10314