tfp.experimental.mcmc.chees_criterion

The ChEES criterion from [1].

ChEES stands for Change in the Estimator of the Expected Square.

ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],

where x is the previous chain state, x' is the next chain state, and ||.|| is the L2 norm. Both expectations are taken with respect to the chain's stationary distribution. In practice, the inner expectation E[x] is replaced by the empirical mean across chains, so computing this criterion requires at least 2 chains. The outer expectation is computed by the caller (e.g., the GradientBasedTrajectoryLengthAdaptation kernel).
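
A minimal NumPy sketch of the estimator described above, assuming a single unnested state of shape [num_chains, dim]. The helper name chees_per_chain is hypothetical and not part of TFP. For simplicity it centers both states with the empirical mean of the previous states; the library's own centering (and any weighting by accept_prob) may differ. It returns one term per chain and leaves the outer expectation to the caller.

```python
import numpy as np


def chees_per_chain(previous_state, proposed_state):
    """Hypothetical sketch: per-chain ChEES terms for [num_chains, dim] states.

    The empirical mean across chains stands in for E[x]; the outer
    expectation (e.g. averaging over chains) is left to the caller.
    """
    previous_state = np.asarray(previous_state, dtype=float)
    proposed_state = np.asarray(proposed_state, dtype=float)
    num_chains = previous_state.shape[0]
    if num_chains < 2:
        # With one chain, the "empirical mean" is just that chain's state,
        # which is not a usable estimate of E[x].
        raise ValueError(f"Need at least 2 chains, got {num_chains}.")
    # Empirical mean across chains (axis 0) replaces the inner E[x].
    mean = previous_state.mean(axis=0)
    # Centered squared L2 norms, one per chain.
    sq_prev = np.sum((previous_state - mean) ** 2, axis=-1)
    sq_prop = np.sum((proposed_state - mean) ** 2, axis=-1)
    return 0.25 * (sq_prop - sq_prev) ** 2


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))       # 4 chains, 3-dimensional state
x_new = rng.normal(size=(4, 3))
per_chain = chees_per_chain(x, x_new)  # shape (4,)
chees = per_chain.mean()          # outer expectation, done by the caller
```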

This can be thought of as the standard expected squared jump distance (ESJD) criterion, E[||x' - x||**2], except that the jump is measured in the space of centered squared L2 norms: ChEES is 1/4 times the ESJD of the scalar statistic ||x - E[x]||**2.
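
To make the ESJD connection concrete, define the scalar statistic f(x) = ||x - E[x]||**2; ChEES is then one quarter of the ESJD of f(x). The sketch below assumes NumPy, unnested [num_chains, dim] states, and the empirical chain mean in place of E[x]; esjd and centered_sq_norm are hypothetical helpers, not TFP functions.

```python
import numpy as np


def esjd(previous_values, proposed_values):
    """Expected squared jump distance, estimated by averaging over chains."""
    jumps = np.asarray(proposed_values) - np.asarray(previous_values)
    # Flatten any event axes, sum squared jumps per chain, average over chains.
    return np.mean(np.sum(jumps.reshape(jumps.shape[0], -1) ** 2, axis=1))


def centered_sq_norm(state, mean):
    """Map each chain's state to the scalar ||x - mean||**2."""
    return np.sum((state - mean) ** 2, axis=-1)


rng = np.random.default_rng(1)
x = rng.normal(size=(8, 5))         # 8 chains, 5-dimensional state
x_new = x + 0.5 * rng.normal(size=(8, 5))
mean = x.mean(axis=0)               # empirical stand-in for E[x]

# ESJD in the original space vs. ChEES as 1/4 ESJD of the centered squared norms.
plain_esjd = esjd(x, x_new)
chees = 0.25 * esjd(centered_sq_norm(x, mean), centered_sq_norm(x_new, mean))
```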

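Unlike ChEES, regular ESJD is maximized by perfectly anticorrelated proposals (x' - E[x] = -(x - E[x])). Such proposals can give excellent estimates of the mean but terrible estimates of the variance, and because they leave ||x - E[x]|| unchanged, they contribute nothing to ChEES. Maximizing ChEES should therefore give good estimates across a wider range of posterior expectations.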
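
A quick numerical check of the anticorrelation argument, assuming NumPy and synthetic standard-normal states. Reflecting every state through the empirical mean, x' = 2E[x] - x, is a perfectly anticorrelated proposal: it yields a large ESJD (four times the mean centered squared norm) while leaving each ||x - E[x]|| unchanged, so the ChEES value is zero up to floating-point rounding.

```python
import numpy as np


def sq_norm(v):
    """Squared L2 norm over the last axis, one value per chain."""
    return np.sum(v ** 2, axis=-1)


rng = np.random.default_rng(2)
x = rng.normal(size=(1000, 2))      # 1000 chains, 2-D standard normal states
mu = x.mean(axis=0)                 # empirical stand-in for E[x]

# Perfectly anticorrelated proposal: reflect each state through the mean.
x_reflected = 2.0 * mu - x

esjd = np.mean(sq_norm(x_reflected - x))
chees = 0.25 * np.mean((sq_norm(x_reflected - mu) - sq_norm(x - mu)) ** 2)
```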

Args

previous_state: (Possibly nested) floating point Tensor. The previous state of the HMC chain.
proposed_state: (Possibly nested) floating point Tensor. The proposed state of the HMC chain.
accept_prob: Floating Tensor. Probability of accepting the proposed state.
trajectory_length: Floating Tensor. Mean trajectory length (not used in this criterion).
validate_args: Whether to perform non-static argument validation.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names: A string or list of string names indicating which named chain axes to reduce over when computing the criterion.
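
previous_state and proposed_state may be nested structures. A reasonable reading, assumed in this sketch rather than taken from the library, is that ||.|| covers the whole structure, so the squared norm is the sum of the squared norms of all parts, each part having a leading chain axis. TFP itself works with tf.nest structures; here the hypothetical helpers flatten and nested_chees_per_chain handle plain dicts, lists, and tuples of NumPy arrays, and centering uses the empirical mean of the previous states.

```python
import numpy as np


def flatten(structure):
    """Hypothetical helper: list the array leaves of nested dicts/lists/tuples."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [np.asarray(structure, dtype=float)]


def nested_chees_per_chain(previous_state, proposed_state):
    """Hypothetical sketch: per-chain ChEES terms for a nested state."""
    prev_leaves = flatten(previous_state)
    prop_leaves = flatten(proposed_state)
    if len(prev_leaves) != len(prop_leaves):
        raise ValueError("States must have the same structure.")
    num_chains = prev_leaves[0].shape[0]
    if num_chains < 2:
        raise ValueError(f"Need at least 2 chains, got {num_chains}.")
    prev_sq = np.zeros(num_chains)
    prop_sq = np.zeros(num_chains)
    for x, x_new in zip(prev_leaves, prop_leaves):
        mean = x.mean(axis=0)  # empirical E[x] for this part
        # Sum squared deviations over all event axes of this part.
        prev_sq += np.sum(((x - mean) ** 2).reshape(num_chains, -1), axis=1)
        prop_sq += np.sum(((x_new - mean) ** 2).reshape(num_chains, -1), axis=1)
    return 0.25 * (prop_sq - prev_sq) ** 2


rng = np.random.default_rng(3)
state = {"loc": rng.normal(size=(4, 2)), "scale": rng.normal(size=(4, 3, 3))}
proposal = {"loc": rng.normal(size=(4, 2)), "scale": rng.normal(size=(4, 3, 3))}
per_chain = nested_chees_per_chain(state, proposal)  # shape (4,)
```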

Returns

chees: The value of the ChEES criterion.

Raises

ValueError: If accept_prob indicates that there are fewer than 2 chains.

References

[1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo. <https://proceedings.mlr.press/v130/hoffman21a>