Generates samples using SNAPER HMC [1] with step size adaptation.
```python
tfp.experimental.mcmc.sample_snaper_hmc(
    model,
    num_results,
    reducer=None,
    trace_fn=default_snaper_trace_fn,
    num_burnin_steps=1000,
    num_adaptation_steps=None,
    num_chains=None,
    discard_burnin_steps=True,
    num_steps_between_results=0,
    init_state=None,
    init_step_size=None,
    event_space_bijector=None,
    event_dtype=None,
    event_shape=None,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None,
    dual_averaging_kwargs=None,
    snaper_kwargs=None,
    seed=None,
    validate_args=False,
    name='snaper_hmc'
)
```
This utility function generates samples from a probabilistic model using the
`SNAPERHamiltonianMonteCarlo` kernel combined with the
`DualAveragingStepSizeAdaptation` kernel. The `model` argument can either be
an instance of `tfp.distributions.Distribution` or a callable that computes
the target log-density. In the latter case, it is also necessary to specify
`event_space_bijector`, `event_dtype`, and `event_shape` (these are inferred if
`model` is a distribution instance).
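For example, a call with a callable `model` might look like the following. This is a minimal sketch, not taken from the library docs: the standard-normal log-density and the event specification are illustrative.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative target: a callable that evaluates the log-density at a batch
# of chain states (here, an 8-dimensional standard normal).
def target_log_prob(x):
  return -0.5 * tf.reduce_sum(tf.square(x), axis=-1)

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=target_log_prob,
    num_results=500,
    # These three must be given explicitly when `model` is a callable
    # rather than a distribution instance:
    event_space_bijector=tfp.bijectors.Identity(),
    event_dtype=tf.float32,
    event_shape=(8,),
)
```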
This function can accept a structure of `tfp.experimental.mcmc.Reducer`s,
which allow computing streaming statistics with minimal memory usage. The
reducers only incorporate samples after the burnin period.
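As a hedged sketch, a `reducer` structure might pair a streaming expectation with a potential-scale-reduction estimate; the pairing via a plain `dict` is illustrative, not prescribed by the API.

```python
import tensorflow_probability as tfp

# A dict-structured reducer: streaming expectations plus R-hat.
reducers = {
    'mean': tfp.experimental.mcmc.ExpectationsReducer(),
    'rhat': tfp.experimental.mcmc.PotentialScaleReductionReducer(),
}

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfp.distributions.Normal(0., 1.),  # Any distribution or callable.
    num_results=500,
    reducer=reducers,
)
# `results.reduction_results` mirrors the structure of `reducers`.
```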
By default, this function traces the following quantities:
- The chain state.
- A dict of auxiliary information, using keys from ArviZ [2]:
  - `step_size`: Float scalar `Tensor`. HMC step size.
  - `n_steps`: Int `Tensor`. Number of HMC leapfrog steps.
  - `tune`: Bool `Tensor`. Whether this step is part of the burnin.
  - `max_trajectory_length`: Float `Tensor`. Maximum HMC trajectory length.
  - `variance_scaling`: List of float `Tensor`s. The diagonal variance of the unconstrained state, used as the mass matrix.
  - `diverging`: Bool `Tensor`. Whether the sampler is diverging.
  - `accept_ratio`: Float `Tensor`. Probability of acceptance of the proposal for this step.
  - `is_accepted`: Bool `Tensor`. Whether this step is the result of an accepted proposal.
It is possible to trace nothing at all and rely on the reducers to compute the necessary statistics.
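For instance, a `trace_fn` that records nothing could be sketched as below; the function only has to match the signature described under `trace_fn` in the table that follows.

```python
import tensorflow_probability as tfp

# Trace nothing; all statistics come from the reducers instead.
def empty_trace_fn(state, is_burnin, kernel_results, reducer, reducer_state):
  return ()

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfp.distributions.Normal(0., 1.),
    num_results=500,
    reducer=tfp.experimental.mcmc.PotentialScaleReductionReducer(),
    trace_fn=empty_trace_fn,
)
# Rely on `results.reduction_results` for the statistics of interest.
```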
| Args | |
|---|---|
| `model` | Either an instance of `tfp.distributions.Distribution` or a callable that evaluates the target log-density at a batch of chain states. |
| `num_results` | Number of MCMC results to return after burnin. |
| `reducer` | A structure of reducers. |
| `trace_fn` | A callable with signature `(state, is_burnin, kernel_results, reducer, reducer_state) -> structure` which defines what quantities to trace. |
| `num_burnin_steps` | Python `int`. Number of burnin steps. |
| `num_adaptation_steps` | Python `int`. Number of adaptation steps. Default: `0.9 * num_burnin_steps`. |
| `num_chains` | Python `int`. Number of chains. This can be inferred from `init_state`. Otherwise, this is 64 by default. |
| `discard_burnin_steps` | Python `bool`. Whether to discard the burnin steps when returning the trace. Burnin steps are never used for the reducers. |
| `num_steps_between_results` | Python `int`. Number of steps to take between MCMC results. This acts as a multiplier on the total number of steps taken by the MCMC (burnin included). The size of the output trace tensors is not affected, but each element is produced by this many sub-steps. |
| `init_state` | Structure of `Tensor`s. Initial state of the chain. Default: `num_chains` worth of zeros in unconstrained space. |
| `init_step_size` | Scalar float `Tensor`. Initial step size. Default: `1e-2 * total_num_dims ** -0.25`. |
| `event_space_bijector` | Bijector or a list of bijectors used to go from unconstrained to constrained space to improve MCMC mixing. Default: either inferred from `model` or an identity. |
| `event_dtype` | Structure of dtypes. The event dtype. Default: inferred from `model` or `init_state`. |
| `event_shape` | Structure of tuples. The event shape. Default: inferred from `model` or `init_state`. |
| `experimental_shard_axis_names` | A structure of string names indicating how members of the state are sharded. |
| `experimental_reduce_chain_axis_names` | A string or list of string names indicating which named axes to average cross-chain statistics over. |
| `dual_averaging_kwargs` | Keyword arguments passed into the `DualAveragingStepSizeAdaptation` kernel. Default: `{'target_accept_prob': 0.8}`. |
| `snaper_kwargs` | Keyword arguments passed into the `SNAPERHamiltonianMonteCarlo` kernel. Default: `{}`. |
| `seed` | PRNG seed; see `tfp.random.sanitize_seed` for details. |
| `validate_args` | Python `bool`. When `True`, kernel parameters are checked for validity. When `False`, invalid inputs may silently render incorrect outputs. |
| `name` | Python `str` name prefixed to Ops created by this class. |
| Returns | |
|---|---|
| `results` | `SampleSNAPERHamiltonianMonteCarloResults`. |
Tuning
The defaults for this function should work well for many models, but it does
provide a number of arguments for verifying sampler behavior. If there is a
question of efficiency, the first thing to do is to set
`discard_burnin_steps=False` and examine the `step_size`,
`max_trajectory_length`, and `variance_scaling` traces. A well-functioning
sampler will have these quantities converge before sampling begins. If they
have not converged, consider increasing `num_burnin_steps` or adjusting
`snaper_kwargs` to tune SNAPER further.
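A sketch of that workflow, assuming the default `trace_fn` (so the trace is the chain state paired with the auxiliary dict keyed as listed above); the toy normal model and the variable names are illustrative:

```python
import tensorflow_probability as tfp

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfp.distributions.Normal(0., 1.),
    num_results=500,
    discard_burnin_steps=False,  # Keep burnin steps for inspection.
    seed=tfp.random.sanitize_seed(1),
)
_, aux = results.trace

# Plot these against the step index; they should level off well before
# `aux['tune']` switches to False.
step_size_trace = aux['step_size']
max_trajectory_length_trace = aux['max_trajectory_length']
variance_scaling_trace = aux['variance_scaling']
```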
Examples
Here we sample from a simple model while performing a reduction.
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dtype = np.float32
num_dims = 8

# Construct a Gaussian target with an ill-conditioned covariance.
eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(dtype)

gaussian = tfd.MultivariateNormalTriL(
    loc=tf.zeros(num_dims, dtype),
    scale_tril=tf.linalg.cholesky(covariance),
)

@tf.function(jit_compile=True)
def run(seed):
  results = tfp.experimental.mcmc.sample_snaper_hmc(
      model=gaussian,
      num_results=500,
      reducer=tfp.experimental.mcmc.PotentialScaleReductionReducer(),
      seed=seed,
  )
  return results.trace, results.reduction_results

(chain, trace), potential_scale_reduction = run(tfp.random.sanitize_seed(0))

# Compute sampler diagnostics.

# Effective sample size; should be high (at least 100-1000).
tfp.mcmc.effective_sample_size(chain, cross_chain_dims=1)
# Should be close to 1.
potential_scale_reduction

# Compute downstream statistics.

# Should be close to np.diag(covariance).
tf.math.reduce_variance(chain, [0, 1])
```
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>
[2]: Kumar, R., Carroll, C., Hartikainen, A., & Martin, O. (2019). ArviZ a unified library for exploratory analysis of Bayesian models in Python. Journal of Open Source Software, 4(33), 1143.