tfp.experimental.mcmc.augment_with_state_history

Decorates a transition or proposal fn to track state history.

For example usage, see tfp.experimental.mcmc.augment_prior_with_state_history.

fn: Python callable to wrap, with signature new_state_dist = fn(step, state_with_history, **kwargs), where state_with_history is a StateWithHistory namedtuple.

augmented_fn: Python callable wrapping fn, with signature new_state_with_history_dist = augmented_fn(step, state_with_history, **kwargs). The return value is a tfd.JointDistributionNamed instance over tfp.experimental.mcmc.StateWithHistory namedtuples. In each such namedtuple, the state_history component is rotated: the previously oldest state, at the initial position, is discarded, and the new state is appended at the final position.
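The history rotation described above can be illustrated with a small pure-Python sketch. This is not the TFP implementation. It assumes a hypothetical stand-in namedtuple with fields state and state_history (state_history is named in the text above; the state field name is an assumption), and, as a deliberate simplification, fn returns a concrete next state instead of a distribution. The names StateWithHistorySketch and augment_with_state_history_sketch are hypothetical.

```python
import collections

# Hypothetical stand-in for StateWithHistory; field names are assumed.
StateWithHistorySketch = collections.namedtuple(
    'StateWithHistorySketch', ['state', 'state_history'])


def augment_with_state_history_sketch(fn):
    """Wraps fn so its output also updates a fixed-length state history.

    Simplification (assumption): fn returns a concrete state, not a
    distribution, so augmented_fn returns a concrete namedtuple rather
    than a tfd.JointDistributionNamed.
    """
    def augmented_fn(step, state_with_history, **kwargs):
        new_state = fn(step, state_with_history, **kwargs)
        history = tuple(state_with_history.state_history)
        # Discard the oldest entry (position 0), append the new state at the end.
        rotated = history[1:] + (new_state,)
        return StateWithHistorySketch(state=new_state, state_history=rotated)
    return augmented_fn


# Usage: a toy transition that increments the current state.
def increment(step, state_with_history):
    return state_with_history.state + 1


augmented = augment_with_state_history_sketch(increment)
s0 = StateWithHistorySketch(state=3, state_history=(1, 2, 3))
s1 = augmented(0, s0)
print(s1)  # StateWithHistorySketch(state=4, state_history=(2, 3, 4))
```

The history keeps a constant length: each call drops one state from the front and adds one at the back, so a wrapped fn can read the most recent states from state_history.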