The SNAPER criterion from [1].
tfp.experimental.mcmc.snaper_criterion(
previous_state,
proposed_state,
accept_prob,
trajectory_length,
direction,
state_mean=None,
state_mean_weight=0.0,
validate_args=False,
experimental_shard_axis_names=None,
experimental_reduce_chain_axis_names=None
)
SNAPER stands for Squared Norm Along Principal component ESJD Rate:
SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 /
trajectory_length],
where x is the previous chain state, x' is the next chain state, and p
is a unit vector (the direction argument). Both expectations are with
respect to the chain's stationary distribution. In practice, the inner
expectation is replaced by the empirical mean across chains, so computing this
criterion requires that at least 2 chains are present unless state_mean and
state_mean_weight are set. The outer expectation is computed by the caller
(e.g. in the GradientBasedTrajectoryLengthAdaptation kernel).
This can be thought of as the standard expected squared jump distance (ESJD) criterion, except that the jump distance is computed in the space of squared projections onto a vector.
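As a rough illustration of the description above, the following pure-Python sketch computes the per-chain terms of the criterion, replacing the inner expectations with means across chains. It is not the library implementation: the helper names (`dot`, `mean_state`, `snaper_per_chain`) are hypothetical, each chain state is assumed to be a single flat vector, and the `accept_prob` weighting that the real function accepts is ignored.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def mean_state(states):
    """Empirical mean across chains; `states` is a list of per-chain vectors."""
    n = len(states)
    return [sum(col) / n for col in zip(*states)]


def snaper_per_chain(previous, proposed, direction, trajectory_length):
    """Per-chain terms; the caller averages them (the outer expectation)."""
    if len(previous) < 2:
        raise ValueError("Need at least 2 chains to estimate the means.")
    prev_mean = mean_state(previous)
    prop_mean = mean_state(proposed)
    terms = []
    for x, x_new in zip(previous, proposed):
        # Squared projections of the centered states onto `direction`.
        prev_sq = dot([a - m for a, m in zip(x, prev_mean)], direction) ** 2
        prop_sq = dot([a - m for a, m in zip(x_new, prop_mean)], direction) ** 2
        # Squared jump in squared-projection space, per unit trajectory length.
        terms.append((prop_sq - prev_sq) ** 2 / trajectory_length)
    return terms


# Two chains in two dimensions, projected onto the first axis.
previous = [[0.0, 0.0], [2.0, 0.0]]
proposed = [[-1.0, 0.0], [5.0, 0.0]]
terms = snaper_per_chain(
    previous, proposed, direction=[1.0, 0.0], trajectory_length=2.0)
snaper = sum(terms) / len(terms)  # outer expectation as a plain average
```

In this toy case both chains move from a squared projection of 1 to a squared projection of 9, so each term is (9 - 1)**2 / 2 = 32.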
The direction vector is typically chosen to be an approximation to the first
principal component of the state covariance matrix.
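One simple way to approximate such a direction from a batch of chain states is power iteration on the empirical covariance. The sketch below is only an illustration under that assumption; the function name `first_principal_component` is hypothetical, and it is not necessarily how any TFP kernel estimates the direction.

```python
import math


def first_principal_component(states, num_iters=100):
    """Approximate the leading eigenvector of the empirical covariance."""
    n = len(states)
    dim = len(states[0])
    mean = [sum(col) / n for col in zip(*states)]
    centered = [[a - m for a, m in zip(x, mean)] for x in states]
    # Arbitrary unit-norm start; it must not be orthogonal to the answer.
    p = [1.0 / math.sqrt(dim)] * dim
    for _ in range(num_iters):
        # Apply the covariance to p without forming the matrix:
        # C p = mean over chains of c * (c . p).
        projections = [sum(a * b for a, b in zip(c, p)) for c in centered]
        new_p = [
            sum(c[j] * s for c, s in zip(centered, projections)) / n
            for j in range(dim)
        ]
        norm = math.sqrt(sum(v * v for v in new_p))
        if norm == 0.0:
            break  # degenerate case: keep the previous estimate
        p = [v / norm for v in new_p]
    return p


# States spread mostly along the first axis.
states = [[-2.0, 0.0], [2.0, 0.0], [0.0, -1.0], [0.0, 1.0]]
direction = first_principal_component(states)
```

The result is a unit vector, which is what the direction argument expects; its sign is arbitrary, which does not matter here because the criterion only uses squared projections.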
state_mean and state_mean_weight can be used to supplement the empirical
means as follows:
E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
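The blending rule can be sketched in a few lines of plain Python. The helper name `blended_mean` is hypothetical and states are assumed to be flat vectors; the sketch only mirrors the formula above, not the library's internals.

```python
def blended_mean(states, state_mean, state_mean_weight):
    """Mix the cross-chain empirical mean with a supplied mean estimate."""
    n = len(states)
    empirical = [sum(col) / n for col in zip(*states)]
    w = state_mean_weight
    return [(1.0 - w) * e + w * m for e, m in zip(empirical, state_mean)]


states = [[0.0, 2.0], [2.0, 4.0]]  # empirical mean is [1.0, 3.0]
half = blended_mean(states, state_mean=[3.0, 3.0], state_mean_weight=0.5)
none = blended_mean(states, state_mean=[3.0, 3.0], state_mean_weight=0.0)
# With weight 1.0 the empirical mean is ignored, so one chain is enough.
full = blended_mean([[7.0, 7.0]], state_mean=[3.0, 3.0], state_mean_weight=1.0)
```

With a weight of 0.0 (the default) only the empirical mean is used, which is why at least 2 chains are needed in that case.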
| Returns | |
|---|---|
| snaper | The value of the SNAPER criterion. |
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>