tfp.substrates.jax.mcmc.effective_sample_size

Estimate a lower bound on effective sample size for each independent chain.

Roughly speaking, "effective sample size" (ESS) is the size of an iid sample whose sample mean has the same variance as the sample mean of states.

More precisely, given a stationary sequence of possibly correlated random variables X_1, X_2, ..., X_N, identically distributed, ESS is the number such that

Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.

If the sequence is uncorrelated, ESS = N. If the sequence is positively auto-correlated, ESS will be less than N. If there are negative correlations, then ESS can exceed N.

Some math shows that, with R_k the auto-correlation sequence, R_k := Covariance{X_1, X_{1+k} } / Variance{X_1}, we have

ESS(N) =  N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1}  ) ]
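As a quick check of the formula above, the following sketch evaluates ESS(N) directly from a known auto-correlation sequence. The helper name ess_from_autocorrelation and the geometric choice R_k = rho**k are illustrative assumptions, not part of this API.

```python
def ess_from_autocorrelation(r, n):
    """Evaluate ESS(N) = N / (1 + 2 * sum_k (N - k) / N * R_k).

    `r[k - 1]` holds R_k for k = 1, 2, ...; lags at or beyond `n` are ignored.
    """
    weighted_sum = sum(
        (n - k) / n * r_k for k, r_k in enumerate(r, start=1) if k < n
    )
    return n / (1.0 + 2.0 * weighted_sum)


n = 1000
# Hypothetical auto-correlation sequences, chosen only for illustration.
uncorrelated = ess_from_autocorrelation([0.0] * (n - 1), n)  # equals n
positive = ess_from_autocorrelation([0.5 ** k for k in range(1, n)], n)  # below n
negative = ess_from_autocorrelation([(-0.5) ** k for k in range(1, n)], n)  # above n
print(uncorrelated, positive, negative)
```

The three cases line up with the statements above: zero correlation gives ESS = N, positive correlation shrinks it, and alternating-sign correlation pushes it past N.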

This function estimates the above by first estimating the auto-correlation. Since R_k must be estimated using only N - k samples, it becomes progressively noisier for larger k. For this reason, the summation over R_k should be truncated at some number filter_beyond_lag < N. This function provides two methods to perform this truncation.

  • filter_threshold -- since many MCMC methods generate chains where R_k > 0, a reasonable criterion is to truncate at the first index where the estimated auto-correlation becomes negative. This method does not estimate the ESS of super-efficient chains (where ESS > N) correctly.

  • filter_beyond_positive_pairs -- reversible MCMC chains produce an auto-correlation sequence with the property that pairwise sums of the elements of that sequence are positive [Geyer][1], i.e. R_{2k} + R_{2k + 1} > 0 for k in {0, ..., N/2}. Deviations are only possible due to noise. This method truncates the auto-correlation sequence where the pairwise sums become non-positive.
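The two truncation rules can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the helper names are hypothetical, the auto-correlation uses the simple biased estimator (normalized by N), and the term that first falls below the threshold is itself dropped.

```python
import random


def autocorrelation(x):
    """Biased sample auto-correlation R_0, ..., R_{N-1} of a sequence."""
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    var = sum(c * c for c in centered) / n
    return [
        sum(centered[i] * centered[i + k] for i in range(n - k)) / (n * var)
        for k in range(n)
    ]


def truncate_at_threshold(r, threshold=0.0):
    """Keep R_1, R_2, ... up to (excluding) the first term below `threshold`."""
    kept = []
    for r_k in r[1:]:
        if r_k < threshold:
            break
        kept.append(r_k)
    return kept


def truncate_at_positive_pairs(r):
    """Keep lags while R_{2k} + R_{2k+1} > 0; returns the kept R_1, R_2, ..."""
    kept = []
    for k in range(0, len(r) - 1, 2):
        if r[k] + r[k + 1] <= 0.0:
            break
        kept.extend(r[k:k + 2])
    return kept[1:]  # drop R_0, which is not part of the sum


def ess_from_tail(tail, n):
    """Plug the kept terms R_1, R_2, ... into the weighted-sum formula."""
    weighted = sum((n - k) / n * r_k for k, r_k in enumerate(tail, start=1))
    return n / (1.0 + 2.0 * weighted)


# Hypothetical positively correlated chain: x_t = 0.9 * x_{t-1} + noise.
rng = random.Random(0)
n = 500
x = [0.0]
for _ in range(n - 1):
    x.append(0.9 * x[-1] + rng.gauss(0.0, 1.0))

r = autocorrelation(x)
ess_threshold = ess_from_tail(truncate_at_threshold(r), n)
ess_pairs = ess_from_tail(truncate_at_positive_pairs(r), n)
print(ess_threshold, ess_pairs)
```

For a strongly auto-correlated chain like this one, both estimates land well below N; they differ only in where the noisy tail is cut.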

The arguments filter_beyond_lag, filter_threshold and filter_beyond_positive_pairs are filters intended to remove noisy tail terms from R_k. You can combine filter_beyond_lag with filter_threshold or filter_beyond_positive_pairs. E.g., combining filter_beyond_lag and filter_beyond_positive_pairs means that terms are removed if they were to be filtered under the filter_beyond_lag OR filter_beyond_positive_pairs criteria.
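The OR semantics can be pictured as two removal masks over the lags that are merged element-wise. The sketch below is illustrative only: keep_mask is a hypothetical helper, and it assumes that filter_beyond_lag drops every lag strictly greater than the given value.

```python
def keep_mask(r, filter_beyond_lag=None, filter_beyond_positive_pairs=False):
    """Boolean mask over lags 0..len(r)-1; True means the term R_k is kept."""
    n = len(r)
    # Criterion 1: drop every lag beyond the cap (assumed to be inclusive).
    removed_by_lag = [
        filter_beyond_lag is not None and k > filter_beyond_lag
        for k in range(n)
    ]
    # Criterion 2: drop everything from the first non-positive pair onward.
    removed_by_pairs = [False] * n
    if filter_beyond_positive_pairs:
        cut = n
        for k in range(0, n - 1, 2):
            if r[k] + r[k + 1] <= 0.0:
                cut = k
                break
        removed_by_pairs = [k >= cut for k in range(n)]
    # A term is dropped if EITHER criterion would drop it.
    return [not (a or b) for a, b in zip(removed_by_lag, removed_by_pairs)]


# Made-up auto-correlation values R_0, ..., R_7.
r = [1.0, 0.6, 0.3, 0.1, -0.2, 0.05, 0.02, -0.01]
mask = keep_mask(r, filter_beyond_lag=2, filter_beyond_positive_pairs=True)
print(mask)
# → [True, True, True, False, False, False, False, False]
```

Here the pair rule alone would keep lags 0 through 3, but the lag cap removes lag 3 as well, so the combined filter is the stricter of the two.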

This function can also compute cross-chain ESS following [Vehtari et al. (2021)][2] by specifying the cross_chain_dims argument. Cross-chain ESS takes into account the cross-chain variance to reduce the ESS in cases where the chains are not mixing well. In general, this will be a smaller number than computing the ESS for individual chains and then summing them. In an extreme case where the chains have fallen into K non-mixing modes, this function will return ESS ~ K. Even when chains are mixing well it is still preferable to compute cross-chain ESS via this method because it will reduce the noise in the estimate of R_k, reducing the need for truncation.
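To illustrate why cross-chain variance lowers the estimate, here is a simplified pure-Python sketch in the spirit of the multi-chain estimator of [2]. It is an assumption-laden illustration, not this function's implementation: cross_chain_ess is a hypothetical helper, it handles scalar states only, and it truncates with the positive-pairs rule.

```python
import random


def cross_chain_ess(chains):
    """Simplified multi-chain ESS; `chains` is M equal-length lists of floats."""
    m, n = len(chains), len(chains[0])
    means = [sum(c) / n for c in chains]
    centered = [[v - mu for v in c] for c, mu in zip(chains, means)]
    # Within-chain variance W, averaged over chains.
    within = sum(sum(v * v for v in c) / (n - 1) for c in centered) / m
    # Variance of the chain means (B / N in the usual notation).
    grand = sum(means) / m
    between_over_n = sum((mu - grand) ** 2 for mu in means) / (m - 1)
    var_plus = (n - 1) / n * within + between_over_n

    def rho(t):
        # Lag-t autocovariance averaged over chains, rescaled by var_plus.
        acov = sum(
            sum(c[i] * c[i + t] for i in range(n - t)) / n for c in centered
        ) / m
        return 1.0 - (within - acov) / var_plus

    # tau = -1 + 2 * sum of pair sums, stopping at the first non-positive pair.
    tau = -1.0
    for t in range(0, n - 1, 2):
        pair = rho(t) + rho(t + 1)
        if pair <= 0.0:
            break
        tau += 2.0 * pair
    return m * n / tau


rng = random.Random(0)
n = 500
# Two chains exploring the same mode versus two chains stuck in separate modes.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(2)]
stuck = [[rng.gauss(loc, 1.0) for _ in range(n)] for loc in (-10.0, 10.0)]
ess_mixed = cross_chain_ess(mixed)
ess_stuck = cross_chain_ess(stuck)
print(ess_mixed, ess_stuck)
```

When the chains disagree, the between-chain term inflates var_plus, every rho(t) stays close to 1, and the estimate collapses far below the nominal M * N draws.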

Args

states Tensor or Python structure of Tensor objects. Dimension zero should index identically distributed states.
filter_threshold Tensor or Python structure of Tensor objects. Must broadcast with states. The sequence of auto-correlations is truncated after the first appearance of a term less than filter_threshold. Setting to None means we use no threshold filter. Since |R_k| <= 1, setting to any number less than -1 has the same effect. Ignored if filter_beyond_positive_pairs is True.
filter_beyond_lag Tensor or Python structure of Tensor objects. Must be int-like and scalar valued. The sequence of auto-correlations is truncated to this length. Setting to None means we do not filter based on the size of lags.
filter_beyond_positive_pairs Python boolean. If True, only consider the initial auto-correlation sequence where the pairwise sums are positive.
cross_chain_dims An integer Tensor or a structure of integer Tensors corresponding to each state component. The dimensions of states to treat as independent chains that ESS will be summed over. If a list of states is provided, then this argument should also be a list of the same length. If None, no summation is performed. Note this requires at least 2 chains.
validate_args Whether to add runtime checks of argument validity. If False, and arguments are incorrect, correct behavior is not guaranteed.
name String name to prepend to created ops.

Returns

ess Tensor structure parallel to states. The effective sample size of each component of states. If cross_chain_dims is None, the shape will be states.shape[1:]. Otherwise, the shape is tf.reduce_mean(states, cross_chain_dims).shape[1:].

Raises

ValueError If states and filter_threshold or states and filter_beyond_lag are both structures of different shapes.
ValueError If cross_chain_dims is not None and there are fewer than 2 chains.

Examples

We use ESS to estimate standard error.

import jax
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

# Get 1000 states from one chain.
# The JAX substrate has no global random state, so an explicit PRNG key is
# passed as `seed`.
states = tfp.mcmc.sample_chain(
    num_burnin_steps=200,
    num_results=1000,
    current_state=tf.constant([0., 0.]),
    trace_fn=None,
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target.log_prob,
      step_size=0.05,
      num_leapfrog_steps=20),
    seed=jax.random.PRNGKey(0))
print(states.shape)
# ==> (1000, 2)

ess = tfp.mcmc.effective_sample_size(states, filter_beyond_positive_pairs=True)
print(ess.shape)
# ==> (2,)

mean, variance = tf.nn.moments(states, axes=0)
standard_error = tf.sqrt(variance / ess)

References

[1]: Charles J. Geyer, Practical Markov chain Monte Carlo (with discussion). Statistical Science, 7:473-511, 1992.

[2]: Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter, Paul-Christian Bürkner. Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC. Bayesian Analysis, 16(2):667-718, 2021.