Applies an Extended Kalman Filter to observed data.
tfp.experimental.sequential.extended_kalman_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    transition_jacobian_fn,
    observation_jacobian_fn,
    name=None
)
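The Extended Kalman Filter is a nonlinear version of the Kalman filter in
which the transition and observation functions are linearized by a
first-order Taylor expansion around the current mean of the state estimate:
transition_jacobian_fn is evaluated at the previous filtered mean, and
observation_jacobian_fn at the predicted mean.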
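The predict/update cycle described above can be sketched in NumPy. This is a minimal illustration, not TFP's implementation: it assumes additive Gaussian noise with covariances Q and R, and the function name ekf_step and its argument order are hypothetical. Both models are linearized, each around the mean it is applied to.

```python
import numpy as np


def ekf_step(m, P, y, f, F_jac, Q, h, H_jac, R):
    """One predict/update cycle of an extended Kalman filter (sketch).

    m, P: previous filtered mean [n] and covariance [n, n].
    y: current observation [k].
    f, h: mean functions of the transition and observation models.
    F_jac, H_jac: their Jacobians, of shapes [n, n] and [k, n].
    Q, R: additive transition and observation noise covariances (assumed).
    """
    # Predict: propagate the mean through f and the covariance through
    # f's linearization at the previous filtered mean.
    F = F_jac(m)
    m_pred = f(m)
    P_pred = F @ P @ F.T + Q

    # Update: linearize h at the predicted mean.
    H = H_jac(m_pred)
    y_pred = h(m_pred)
    S = H @ P_pred @ H.T + R              # innovation covariance [k, k]
    K = np.linalg.solve(S, H @ P_pred).T  # Kalman gain P_pred H^T S^{-1}
    resid = y - y_pred
    m_filt = m_pred + K @ resid
    P_filt = P_pred - K @ S @ K.T

    # Log density of y under the Gaussian predictive N(y_pred, S).
    _, logdet = np.linalg.slogdet(S)
    log_lik = -0.5 * (len(y) * np.log(2 * np.pi) + logdet
                      + resid @ np.linalg.solve(S, resid))
    return m_filt, P_filt, m_pred, P_pred, log_lik


# With linear f and h the step reduces to an ordinary Kalman filter step.
def identity(x):
    return x


def eye_jac(x):
    return np.eye(1)


m1, P1, m_pred1, P_pred1, ll1 = ekf_step(
    np.array([0.0]), np.array([[1.0]]), np.array([2.0]),
    identity, eye_jac, np.eye(1), identity, eye_jac, np.eye(1))
# m1 is approximately [4/3] and P1 approximately [[2/3]].
```

The linear case is a convenient sanity check: with constant Jacobians the extended filter and the ordinary Kalman filter coincide, so hand-computed Kalman updates should match.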
Args:
  observations: a (structure of) Tensors, each of shape
    concat([[num_timesteps, b1, ..., bN], [event_size]]), where event_size
    is the (scalar) size of each observation vector and b1, ..., bN are
    optional batch dimensions.
  initial_state_prior: a tfd.Distribution instance (typically
    MultivariateNormal) with event_shape equal to [state_size] and an
    optional batch_shape of [b1, ..., bN], representing the prior over the
    state.
  transition_fn: a Python callable that accepts a (batched) vector of length
    state_size and returns a tfd.Distribution instance, typically a
    MultivariateNormal, whose mean and covariance give the predicted next
    state and the transition noise covariance.
  observation_fn: a Python callable that accepts a (batched) vector of
    length state_size and returns a tfd.Distribution instance, typically a
    MultivariateNormal, whose mean and covariance give the predicted
    observation and the observation noise covariance.
  transition_jacobian_fn: a Python callable that accepts a (batched) vector
    of length state_size and returns a (batched) matrix of shape
    [state_size, state_size], representing the Jacobian of the mean of
    transition_fn with respect to the state.
  observation_jacobian_fn: a Python callable that accepts a (batched) vector
    of length state_size and returns a (batched) matrix of shape
    [event_size, state_size], representing the Jacobian of the mean of
    observation_fn with respect to the state.
  name: Python str name for ops created by this method.
    Default value: None (i.e., 'extended_kalman_filter').
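A common source of shape errors is the orientation of the Jacobians: rows index the outputs of the function and columns index the state, so a transition Jacobian is [state_size, state_size] and an observation Jacobian is [event_size, state_size]. A central-difference check like the NumPy sketch below can catch such mistakes before the filter is run. The helper numerical_jacobian and the NumPy mean functions (modeled on the system in the Examples section) are illustrative assumptions, not part of TFP.

```python
import numpy as np


def numerical_jacobian(fn, x, eps=1e-6):
    """Central-difference Jacobian: rows index fn's outputs, columns index x."""
    x = np.asarray(x, dtype=np.float64)
    out_size = np.asarray(fn(x)).size
    jac = np.zeros((out_size, x.size))
    for j in range(x.size):
        dx = np.zeros_like(x)
        dx[j] = eps
        jac[:, j] = (np.asarray(fn(x + dx)) - np.asarray(fn(x - dx))) / (2 * eps)
    return jac


# Hypothetical NumPy stand-ins for the means of the example model
# (state [y, w]; only y is observed).
def transition_loc(x):
    return np.array([x[0] - 0.1 * x[1] ** 3, x[1]])


def observation_loc(x):
    return x[:1]


def transition_jac(x):
    # Analytic Jacobian of transition_loc, shape [state_size, state_size].
    return np.array([[1.0, -0.3 * x[1] ** 2], [0.0, 1.0]])


def observation_jac(x):
    # Analytic Jacobian of observation_loc, shape [event_size, state_size].
    return np.array([[1.0, 0.0]])


x0 = np.array([0.5, 2.0])
for analytic, fn in [(transition_jac, transition_loc),
                     (observation_jac, observation_loc)]:
    numeric = numerical_jacobian(fn, x0)
    # Both the shape and the values of the analytic Jacobian must match.
    assert analytic(x0).shape == numeric.shape, (analytic(x0).shape, numeric.shape)
    assert np.allclose(analytic(x0), numeric, atol=1e-6)
```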
Returns:
  filtered_mean: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [state_size]]). The mean of the
    filtered state estimate.
  filtered_cov: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]).
    The covariance of the filtered state estimate.
  predicted_mean: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [state_size]]). The predicted
    (prior) means of the state, before conditioning on each observation.
  predicted_cov: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [state_size, state_size]]).
    The predicted (prior) covariances of the state estimate.
  observation_mean: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [event_size]]). The predicted
    (prior) mean of the observations.
  observation_cov: a (structure of) Tensor(s) of shape
    concat([[num_timesteps, b1, ..., bN], [event_size, event_size]]). The
    predicted (prior) covariance of the observations.
  log_marginal_likelihood: a (structure of) Tensor(s) of shape
    [num_timesteps, b1, ..., bN]. The log likelihood of each observation
    under its one-step-ahead predictive distribution, given the preceding
    observations and the linearized model.
  timestep: a (structure of) integer Tensor(s) of shape
    [num_timesteps, b1, ..., bN] containing the time indices.
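Assuming log_marginal_likelihood[t] is the log density of observation t under its one-step-ahead predictive distribution, summing it over the time axis (axis 0) gives the log likelihood of the whole sequence. For a linear-Gaussian model, where the extended filter reduces to the ordinary Kalman filter, that sum is exact. The NumPy sketch below checks the identity for a scalar random walk observed in noise; the model, its parameters, and the helper name random_walk_log_likelihoods are hypothetical, and no TFP code is called.

```python
import numpy as np


def random_walk_log_likelihoods(ys, p0, q, r):
    """Per-step log p(y_t | y_0, ..., y_{t-1}) from a scalar Kalman filter.

    Model (hypothetical): x_0 ~ N(0, p0), x_t = x_{t-1} + N(0, q) noise,
    y_t = x_t + N(0, r) noise.
    """
    m, P = 0.0, p0  # predictive mean and variance of x_0
    lls = []
    for t, y in enumerate(ys):
        if t > 0:
            P = P + q  # random-walk step: mean unchanged, variance grows by q
        S = P + r      # predictive variance of y_t
        lls.append(-0.5 * (np.log(2 * np.pi * S) + (y - m) ** 2 / S))
        K = P / S      # Kalman gain
        m = m + K * (y - m)
        P = (1 - K) * P
    return np.array(lls)


ys = np.array([0.3, -0.1, 0.8, 1.2, 0.9])
p0, q, r = 1.0, 0.5, 0.2
per_step = random_walk_log_likelihoods(ys, p0, q, r)

# Exact joint density: y ~ N(0, Sigma), Sigma_ij = p0 + q*min(i, j) + r*[i == j].
idx = np.arange(len(ys))
Sigma = p0 + q * np.minimum.outer(idx, idx) + r * np.eye(len(ys))
_, logdet = np.linalg.slogdet(Sigma)
joint = -0.5 * (len(ys) * np.log(2 * np.pi) + logdet
                + ys @ np.linalg.solve(Sigma, ys))

print(np.isclose(per_step.sum(), joint))  # True
```

For a nonlinear model the same sum is only an approximation, because each predictive distribution comes from a local linearization.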
Examples
Estimate a simple nonlinear system. Consider a system defined by the
transition equations $y_{t+1} = y_t - 0.1\,w_t^3$ and $w_{t+1} = w_t$, so
that the state can be expressed as $[y, w]$. The transition_fn and
transition_jacobian_fn can be expressed as:
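import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def transition_fn(x):
  # Mean [y - 0.1 * w**3, w] with independent Gaussian noise per component.
  return tfd.MultivariateNormalDiag(
      tf.stack([x[..., 0] - 0.1 * x[..., 1]**3, x[..., 1]], axis=-1),
      scale_diag=[0.7, 0.2])

def transition_jacobian_fn(x):
  # Jacobian of the mean: [[1, -0.3 * w**2], [0, 1]] for an unbatched state.
  return tf.reshape(
      tf.stack(
          [tf.ones(x.shape[:-1]), -0.3 * x[..., 1]**2,
           tf.zeros(x.shape[:-1]), tf.ones(x.shape[:-1])], axis=-1),
      [2, 2])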
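Differentiating the transition mean $f(y, w) = (y - 0.1\,w^3,\ w)$ entry by entry gives the Jacobian that transition_jacobian_fn should return. Only the top-right entry depends on the state:

```latex
J_f(y, w) =
\begin{bmatrix}
\partial f_1/\partial y & \partial f_1/\partial w \\
\partial f_2/\partial y & \partial f_2/\partial w
\end{bmatrix}
=
\begin{bmatrix}
1 & -0.3\,w^2 \\
0 & 1
\end{bmatrix}
```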
Assume we take noisy measurements of only the first element of the state.
observation_fn = lambda x: tfd.MultivariateNormalDiag(
    x[..., :1], scale_diag=[1.])
# Jacobian of the observation mean x[..., :1]: shape [event_size, state_size] = [1, 2].
observation_jacobian_fn = lambda x: [[1., 0.]]
We define a prior over the initial state and use it to synthesize data for
20 steps of the process.
initial_state_prior = tfd.MultivariateNormalDiag(0., scale_diag=[1., 0.3])
# Draw the initial state from the prior, then roll the transition forward.
x = [initial_state_prior.sample()]
for t in range(20):
  x.append(transition_fn(x[-1]).sample())
x = tf.stack(x)  # shape [21, 2]: the initial state plus 20 transitions
observations = observation_fn(x).sample()  # shape [21, 1]
Run the Extended Kalman filter on the synthesized observed data.
results = tfp.experimental.sequential.extended_kalman_filter(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    transition_jacobian_fn,
    observation_jacobian_fn)