The SNAPER criterion from Sountsov & Hoffman (2021) [1].
tfp.experimental.mcmc.snaper_criterion(
    previous_state,
    proposed_state,
    accept_prob,
    trajectory_length,
    direction,
    state_mean=None,
    state_mean_weight=0.0,
    validate_args=False,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None
)
SNAPER stands for Squared Norm Along Principal component ESJD Rate:
SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 / trajectory_length],
where x is the previous chain state,
x' is the next chain state, and
p is a unit vector (the
direction argument). Both expectations are with
respect to the chain's stationary distribution. In practice, the inner
expectation is replaced by the empirical mean across chains, so computing this
criterion requires at least 2 chains unless state_mean and
state_mean_weight are set. The outer expectation is computed by the caller
(e.g. in the kernel that adapts the trajectory length).
This can be thought of as the standard expected squared jump distance (ESJD) criterion, except that the jump distance is computed in the space of squared projections onto a vector.
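As a rough illustration of the formula above, the following plain-Python sketch computes one criterion value per chain, using the empirical mean across chains for the inner expectations. It is not the library's implementation: the helper names (dot, mean_across_chains, snaper_per_chain) are hypothetical, states are assumed to be lists of per-chain vectors, and accept_prob, state_mean, and the remaining arguments are ignored.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def mean_across_chains(states):
    """Component-wise mean of a list of per-chain state vectors."""
    n = len(states)
    return [sum(column) / n for column in zip(*states)]


def snaper_per_chain(previous_state, proposed_state, direction, trajectory_length):
    """Hypothetical sketch: one SNAPER value per chain.

    The outer expectation (the average over chains and steps) is left
    to the caller, as the text describes.
    """
    prev_mean = mean_across_chains(previous_state)
    prop_mean = mean_across_chains(proposed_state)
    values = []
    for x, x_new in zip(previous_state, proposed_state):
        # Project the centered states onto the direction vector.
        prev_proj = dot([a - m for a, m in zip(x, prev_mean)], direction)
        prop_proj = dot([a - m for a, m in zip(x_new, prop_mean)], direction)
        # Squared change in the squared projection, per unit trajectory length.
        values.append((prop_proj ** 2 - prev_proj ** 2) ** 2 / trajectory_length)
    return values


# Two chains with 2-D states; the direction is the first coordinate axis.
previous = [[0.0, 0.0], [2.0, 0.0]]
proposed = [[1.0, 0.0], [5.0, 0.0]]
values = snaper_per_chain(previous, proposed, [1.0, 0.0], trajectory_length=3.0)
print(values)
```

With a single chain the empirical mean equals the state itself, so both projections vanish and the value is always zero, which is why at least 2 chains are needed when no external mean is supplied.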
The direction vector is typically chosen to be an approximation to the first
principal component of the state covariance matrix.
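One generic way to obtain such an approximation is power iteration on the empirical covariance of the chain states. The sketch below assumes the states are given as a list of per-chain vectors; the function name first_principal_component and the num_iters parameter are hypothetical, and this is not necessarily how any particular kernel chooses direction.

```python
import math


def first_principal_component(states, num_iters=100):
    """Hypothetical sketch: leading eigenvector of the empirical covariance.

    Assumes the states have non-zero variance and that the starting
    vector is not orthogonal to the leading eigenvector.
    """
    n = len(states)
    dim = len(states[0])
    mean = [sum(column) / n for column in zip(*states)]
    centered = [[a - m for a, m in zip(x, mean)] for x in states]
    v = [1.0] * dim  # arbitrary starting vector
    for _ in range(num_iters):
        # Covariance-vector product without forming the matrix:
        # cov @ v = (1 / n) * sum_i (c_i . v) * c_i
        projections = [sum(a * b for a, b in zip(c, v)) for c in centered]
        w = [sum(p * c[j] for p, c in zip(projections, centered)) / n
             for j in range(dim)]
        norm = math.sqrt(sum(a * a for a in w))
        v = [a / norm for a in w]  # renormalize to a unit vector
    return v


# States that vary mostly along the first coordinate.
states = [[-2.0, 0.1], [2.0, -0.1], [-1.0, -0.1], [1.0, 0.1]]
direction = first_principal_component(states)
print(direction)
```

The result is a unit vector, which matches the requirement that the direction argument have unit norm.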
state_mean and state_mean_weight can be used to supplement the empirical
means as follows:
E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
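The blend above can be sketched in plain Python as follows. The helper name blended_state_mean is hypothetical and the states are assumed to be lists of per-chain vectors; the snippet only illustrates the arithmetic of the formula.

```python
def blended_state_mean(states, state_mean, state_mean_weight):
    """Hypothetical sketch: mix the across-chain mean with a supplied mean."""
    n = len(states)
    empirical = [sum(column) / n for column in zip(*states)]
    return [(1.0 - state_mean_weight) * e + state_mean_weight * m
            for e, m in zip(empirical, state_mean)]


# Two chains: the empirical mean is [1.0, 3.0].
blended = blended_state_mean([[0.0, 2.0], [2.0, 4.0]], [3.0, 7.0], 0.25)
print(blended)

# A weight of 1.0 relies entirely on the supplied mean, so the estimate
# no longer depends on having several chains.
single = blended_state_mean([[5.0, 5.0]], [3.0, 7.0], 1.0)
print(single)
```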
Returns: The value of the SNAPER criterion.
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>