Generates samples using SNAPER HMC [1] with step size adaptation.
```python
tfp.experimental.mcmc.sample_snaper_hmc(
    model,
    num_results,
    reducer=None,
    trace_fn=default_snaper_trace_fn,
    num_burnin_steps=1000,
    num_adaptation_steps=None,
    num_chains=None,
    discard_burnin_steps=True,
    num_steps_between_results=0,
    init_state=None,
    init_step_size=None,
    event_space_bijector=None,
    event_dtype=None,
    event_shape=None,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None,
    dual_averaging_kwargs=None,
    snaper_kwargs=None,
    seed=None,
    validate_args=False,
    name='snaper_hmc'
)
```
This utility function generates samples from a probabilistic model using the `SNAPERHamiltonianMonteCarlo` kernel combined with the `DualAveragingStepSizeAdaptation` kernel. The `model` argument can either be an instance of `tfp.distributions.Distribution` or a callable that computes the target log-density. In the latter case, it is also necessary to specify `event_space_bijector`, `event_dtype` and `event_shape` (these are inferred if `model` is a distribution instance), as in the sketch below.
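For example, here is a minimal sketch of sampling from a callable target log-density; `target_log_prob_fn` and its 4-dimensional Gaussian are illustrative stand-ins, not part of this API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Hypothetical standalone target: an isotropic Gaussian over 4 dimensions.
def target_log_prob_fn(x):
  return -0.5 * tf.reduce_sum(tf.square(x), axis=-1)

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=target_log_prob_fn,
    num_results=500,
    # For callable models these must be given explicitly:
    event_space_bijector=tfb.Identity(),
    event_dtype=tf.float32,
    event_shape=(4,),
    seed=tfp.random.sanitize_seed(0),
)
chain = results.trace[0]  # With the default trace_fn, the chain state.
```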
This function can accept a structure of `tfp.experimental.mcmc.Reducer`s, which allows computing streaming statistics with minimal memory usage. The reducers only incorporate samples after the burn-in period.
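A sketch of passing a structure of reducers (the `dict` keys and the particular choice of reducers here are illustrative assumptions):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# reduction_results in the returned object mirrors this structure.
reducers = {
    'rhat': tfp.experimental.mcmc.PotentialScaleReductionReducer(),
    'mean': tfp.experimental.mcmc.ExpectationsReducer(),
}
results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfd.MultivariateNormalDiag(loc=tf.zeros(4)),
    num_results=500,
    reducer=reducers,
    seed=tfp.random.sanitize_seed(1),
)
rhat = results.reduction_results['rhat']
posterior_mean = results.reduction_results['mean']
```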
By default, this function traces the following quantities:
- The chain state.
- A dict of auxiliary information, using keys from ArviZ [2].
  - `step_size`: Float scalar `Tensor`. HMC step size.
  - `n_steps`: Int `Tensor`. Number of HMC leapfrog steps.
  - `tune`: Bool `Tensor`. Whether this step is part of the burn-in.
  - `max_trajectory_length`: Float `Tensor`. Maximum HMC trajectory length.
  - `variance_scaling`: List of float `Tensor`s. The diagonal variance of the unconstrained state, used as the mass matrix.
  - `diverging`: Bool `Tensor`. Whether the sampler is diverging.
  - `accept_ratio`: Float `Tensor`. Probability of acceptance of the proposal for this step.
  - `is_accepted`: Bool `Tensor`. Whether this step is a result of an accepted proposal.
It is possible to trace nothing at all and rely on the reducers to compute the necessary statistics, as in the sketch below.
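For example, a minimal sketch of a no-op `trace_fn` (the name `empty_trace_fn` is illustrative; the signature follows the `trace_fn` contract documented below):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Trace nothing; all statistics come from the reducer instead.
def empty_trace_fn(state, is_burnin, kernel_results, reducer, reducer_state):
  return ()

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfd.MultivariateNormalDiag(loc=tf.zeros(4)),
    num_results=500,
    trace_fn=empty_trace_fn,
    reducer=tfp.experimental.mcmc.PotentialScaleReductionReducer(),
    seed=tfp.random.sanitize_seed(2),
)
rhat = results.reduction_results
```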
| Args | |
|---|---|
| `model` | Either an instance of `tfp.distributions.Distribution` or a callable that evaluates the target log-density at a batch of chain states. |
| `num_results` | Number of MCMC results to return after burn-in. |
| `reducer` | A structure of reducers. |
| `trace_fn` | A callable with signature `(state, is_burnin, kernel_results, reducer, reducer_state) -> structure` which defines what quantities to trace. |
| `num_burnin_steps` | Python `int`. Number of burn-in steps. |
| `num_adaptation_steps` | Python `int`. Number of adaptation steps. Default: `0.9 * num_burnin_steps`. |
| `num_chains` | Python `int`. Number of chains. This can be inferred from `init_state`. Otherwise, this is 64 by default. |
| `discard_burnin_steps` | Python `bool`. Whether to discard the burn-in steps when returning the trace. Burn-in steps are never used for the reducers. |
| `num_steps_between_results` | Python `int`. Number of steps to take between MCMC results. This acts as a multiplier on the total number of steps taken by the MCMC (burn-in included). The size of the output trace tensors is not affected, but each element is produced by this many sub-steps. |
| `init_state` | Structure of `Tensor`s. Initial state of the chain. Default: `num_chains` worth of zeros in unconstrained space. |
| `init_step_size` | Scalar float `Tensor`. Initial step size. Default: `1e-2 * total_num_dims ** -0.25`. |
| `event_space_bijector` | Bijector or a list of bijectors used to go from unconstrained to constrained space to improve MCMC mixing. Default: either inferred from `model` or an identity. |
| `event_dtype` | Structure of dtypes. The event dtype. Default: inferred from `model` or `init_state`. |
| `event_shape` | Structure of tuples. The event shape. Default: inferred from `model` or `init_state`. |
| `experimental_shard_axis_names` | A structure of string names indicating how members of the state are sharded. |
| `experimental_reduce_chain_axis_names` | A string or list of string names indicating which named axes to average cross-chain statistics over. |
| `dual_averaging_kwargs` | Keyword arguments passed into the `DualAveragingStepSizeAdaptation` kernel. Default: `{'target_accept_prob': 0.8}`. |
| `snaper_kwargs` | Keyword arguments passed into the `SNAPERHamiltonianMonteCarlo` kernel. Default: `{}`. |
| `seed` | PRNG seed; see `tfp.random.sanitize_seed` for details. |
| `validate_args` | Python `bool`. When `True`, kernel parameters are checked for validity. When `False`, invalid inputs may silently render incorrect outputs. |
| `name` | Python `str` name prefixed to Ops created by this class. |
| Returns | |
|---|---|
| `results` | `SampleSNAPERHamiltonianMonteCarloResults`. |
Tuning
The defaults for this function should work well for many models, but it does provide a number of arguments for verifying sampler behavior. If there's a question of efficiency, the first thing to do is to set `discard_burnin_steps=False` and examine the `step_size`, `max_trajectory_length` and `variance_scaling` traces. A well-functioning sampler will have these quantities converge before sampling begins. If they have not converged, consider increasing `num_burnin_steps`, or adjusting `snaper_kwargs` to tune SNAPER further.
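A sketch of this check, assuming the default `trace_fn` (so the auxiliary dict uses the ArviZ-style keys listed above) and an illustrative 4-dimensional Gaussian model:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=tfd.MultivariateNormalDiag(loc=tf.zeros(4)),
    num_results=500,
    discard_burnin_steps=False,
    seed=tfp.random.sanitize_seed(3),
)
_, aux = results.trace
# Each entry has a leading step dimension; both adaptation traces should
# level off well before the burn-in period ends.
step_size_trace = aux['step_size']
trajectory_length_trace = aux['max_trajectory_length']
```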
Examples
Here we sample from a simple model while performing a reduction.
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dtype = np.float32
num_dims = 8

# Random covariance with eigenvalues spanning 1 to e**3 (~20).
eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(dtype)

gaussian = tfd.MultivariateNormalTriL(
    loc=tf.zeros(num_dims, dtype),
    scale_tril=tf.linalg.cholesky(covariance),
)

@tf.function(jit_compile=True)
def run(seed):
  results = tfp.experimental.mcmc.sample_snaper_hmc(
      model=gaussian,
      num_results=500,
      reducer=tfp.experimental.mcmc.PotentialScaleReductionReducer(),
      seed=seed,
  )
  return results.trace, results.reduction_results

(chain, trace), potential_scale_reduction = run(tfp.random.sanitize_seed(0))

# Compute sampler diagnostics.

# Effective sample size should be high (at least 100-1000).
tfp.mcmc.effective_sample_size(chain, cross_chain_dims=1)
# Potential scale reduction should be close to 1.
potential_scale_reduction

# Compute downstream statistics.

# Should be close to np.diag(covariance).
tf.math.reduce_variance(chain, [0, 1])
```
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>
[2]: Kumar, R., Carroll, C., Hartikainen, A., & Martin, O. (2019). ArviZ a unified library for exploratory analysis of Bayesian models in Python. Journal of Open Source Software, 4(33), 1143.