tfp.experimental.mcmc.sample_snaper_hmc

Generates samples using SNAPER HMC [1] with step size adaptation.

This utility function generates samples from a probabilistic model using the SNAPERHamiltonianMonteCarlo kernel combined with the DualAveragingStepSizeAdaptation kernel. The model argument can either be an instance of tfp.distributions.Distribution or a callable that computes the target log-density. In the latter case, it is also necessary to specify event_space_bijector, event_dtype and event_shape (these are inferred if model is a distribution instance).
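For example, here is a minimal sketch of the callable form; the standard-normal target, the event shape, and the identity bijector are assumptions chosen purely for illustration:

import tensorflow as tf
import tensorflow_probability as tfp

def target_log_prob(x):
  # Maps a batch of chain states to a batch of log-densities
  # (a standard normal, for illustration only).
  return -0.5 * tf.reduce_sum(x**2, axis=-1)

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=target_log_prob,
    num_results=500,
    # The callable form requires these three arguments.
    event_space_bijector=tfp.bijectors.Identity(),
    event_dtype=tf.float32,
    event_shape=(4,),
    seed=tfp.random.sanitize_seed(0),
)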

This function can accept a structure of tfp.experimental.mcmc.Reducers, which allow computing streaming statistics with minimal memory usage. The reducers only incorporate samples after the burnin period.
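For instance, here is a sketch of passing a structure (here, a dict) of reducers; the particular choice of reducers is an assumption, and gaussian refers to the target defined in the Examples section below:

reducers = {
    'rhat': tfp.experimental.mcmc.PotentialScaleReductionReducer(),
    'mean': tfp.experimental.mcmc.ExpectationsReducer(),
}
results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=gaussian,  # E.g. the target from the Examples section below.
    num_results=500,
    reducer=reducers,
    seed=tfp.random.sanitize_seed(0),
)
# reduction_results mirrors the structure of the reducer argument.
streaming_rhat = results.reduction_results['rhat']
streaming_mean = results.reduction_results['mean']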

By default, this function traces the following quantities:

  • The chain state.
  • A dict of auxiliary information, using keys from ArviZ [2].
    • step_size: Float scalar Tensor. HMC step size.
    • n_steps: Int Tensor. Number of HMC leapfrog steps.
    • tune: Bool Tensor. Whether this step is part of the burnin.
    • max_trajectory_length: Float Tensor. Maximum HMC trajectory length.
    • variance_scaling: List of float Tensors. The diagonal variance of the unconstrained state, used as the mass matrix.
    • diverging: Bool Tensor. Whether the sampler is diverging.
    • accept_ratio: Float Tensor. Probability of acceptance of the proposal for this step.
    • is_accepted: Bool Tensor. Whether this step is a result of an accepted proposal.

It is possible to trace nothing at all and rely on the reducers to compute the necessary statistics.
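A minimal sketch of that pattern, assuming a streaming variance estimate is all that is needed (the no-op trace_fn below is one way to disable tracing):

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=gaussian,  # E.g. the target from the Examples section below.
    num_results=500,
    reducer=tfp.experimental.mcmc.VarianceReducer(),
    trace_fn=lambda *args: (),  # Trace nothing.
    seed=tfp.random.sanitize_seed(0),
)
variance_estimate = results.reduction_results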

Args

model: Either an instance of tfp.distributions.Distribution or a callable that evaluates the target log-density at a batch of chain states.
num_results: Number of MCMC results to return after burnin.
reducer: A structure of reducers.
trace_fn: A callable with signature (state, is_burnin, kernel_results, reducer, reducer_state) -> structure, which defines which quantities to trace (see the sketch after this list).
num_burnin_steps: Python int. Number of burnin steps.
num_adaptation_steps: Python int. Number of adaptation steps. Default: 0.9 * num_burnin_steps.
num_chains: Python int. Number of chains. This can be inferred from init_state; otherwise, it defaults to 64.
discard_burnin_steps: Python bool. Whether to discard the burnin steps when returning the trace. Burnin steps are never used for the reducers.
num_steps_between_results: Python int. Number of steps to take between MCMC results. This acts as a multiplier on the total number of steps taken by the MCMC (burnin included). The size of the output trace tensors is not affected, but each element is produced by this many sub-steps.
init_state: Structure of Tensors. Initial state of the chain. Default: num_chains worth of zeros in unconstrained space.
init_step_size: Scalar float Tensor. Initial step size. Default: 1e-2 * total_num_dims ** -0.25.
event_space_bijector: Bijector or a list of bijectors used to go from unconstrained to constrained space, to improve MCMC mixing. Default: Either inferred from model or an identity.
event_dtype: Structure of dtypes. The event dtype. Default: Inferred from model or init_state.
event_shape: Structure of tuples. The event shape. Default: Inferred from model or init_state.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names: A string or list of string names indicating which named axes to average cross-chain statistics over.
dual_averaging_kwargs: Keyword arguments passed into the DualAveragingStepSizeAdaptation kernel. Default: {'target_accept_prob': 0.8}.
snaper_kwargs: Keyword arguments passed into the SNAPERHamiltonianMonteCarlo kernel. Default: {}.
seed: PRNG seed; see tfp.random.sanitize_seed for details.
validate_args: Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.
name: Python str name prefixed to Ops created by this function.
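To make the trace_fn signature concrete, here is a minimal sketch of a custom trace function; keeping only the chain state (and ignoring all other arguments) is an assumption chosen for brevity:

def state_only_trace_fn(state, is_burnin, kernel_results, reducer,
                        reducer_state):
  # Record only the chain state; drop all auxiliary information.
  return state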

Returns

results: SampleSNAPERHamiltonianMonteCarloResults.

Tuning

The defaults of this function should work well for many models, but it provides a number of arguments for diagnosing sampler behavior. If there is a question of efficiency, the first thing to do is to set discard_burnin_steps=False and examine the step_size, max_trajectory_length and variance_scaling traces. A well-functioning sampler will have these quantities converge before sampling begins. If they have not converged, consider increasing num_burnin_steps, or adjusting snaper_kwargs to tune SNAPER further.
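A sketch of that diagnostic workflow, assuming the default trace_fn and the gaussian target from the Examples section below:

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=gaussian,
    num_results=500,
    discard_burnin_steps=False,
    seed=tfp.random.sanitize_seed(0),
)
_, aux = results.trace
# Each entry has a leading "steps" dimension covering both burnin and
# sampling; adaptation should flatten out before aux['tune'] turns False.
step_size_trace = aux['step_size']
trajectory_length_trace = aux['max_trajectory_length']
variance_scaling_trace = aux['variance_scaling']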

Examples

Here we sample from a simple model while performing a reduction.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dtype = np.float32
num_dims = 8

eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(dtype)

gaussian = tfd.MultivariateNormalTriL(
    loc=tf.zeros(num_dims, dtype),
    scale_tril=tf.linalg.cholesky(covariance),
)

@tf.function(jit_compile=True)
def run(seed):
  results = tfp.experimental.mcmc.sample_snaper_hmc(
      model=gaussian,
      num_results=500,
      reducer=tfp.experimental.mcmc.PotentialScaleReductionReducer(),
      seed=seed,
  )

  return results.trace, results.reduction_results

(chain, trace), potential_scale_reduction = run(tfp.random.sanitize_seed(0))

# Compute sampler diagnostics.

# Should be high (at least 100-1000).
tfp.mcmc.effective_sample_size(chain, cross_chain_dims=1)
# Should be close to 1.
potential_scale_reduction

# Compute downstream statistics.

# Should be close to np.diag(covariance)
tf.math.reduce_variance(chain, [0, 1])

References

[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>

[2]: Kumar, R., Carroll, C., Hartikainen, A., & Martin, O. (2019). ArviZ a unified library for exploratory analysis of Bayesian models in Python. Journal of Open Source Software, 4(33), 1143.