SNAPER-HMC without step size adaptation.
Inherits From: TransitionKernel
    tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
        target_log_prob_fn,
        step_size,
        num_adaptation_steps,
        num_mala_steps=100,
        max_leapfrog_steps=1000,
        trajectory_length_adaptation_rate=0.05,
        principal_component_ema_factor=8,
        state_ema_factor=8,
        experimental_shard_axis_names=None,
        experimental_reduce_chain_axis_names=None,
        preconditioned_hamiltonian_monte_carlo_kwargs=None,
        gradient_based_trajectory_length_adaptation_kwargs=None,
        validate_args=False,
        name=None
    )
This implements the SNAPER-HMC algorithm from [1], without the step size adaptation. This kernel learns a diagonal mass matrix and the trajectory length parameters of the Hamiltonian Monte Carlo (HMC) sampler using the Adaptive MCMC framework [2]. As with all adaptive MCMC algorithms, this kernel does not produce samples from the target distribution while adaptation is engaged, so be sure to set the num_adaptation_steps parameter to a value smaller than the number of burn-in steps.
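For instance (an illustrative rule of thumb, not an API requirement), adaptation is often run for roughly the first 80% of burn-in:

    num_burnin_steps = 1000
    # Stop adapting well before burn-in ends so the chain can settle with the
    # final hyperparameters before samples are collected.
    num_adaptation_steps = int(num_burnin_steps * 0.8)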
This kernel uses the SNAPER criterion (see tfp.experimental.mcmc.snaper_criterion for details), which has a principal-component parameter. This kernel learns it using a batched Oja's algorithm with a learning rate of principal_component_ema_factor / step, where step is the iteration number.
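The following is a minimal NumPy sketch of one batched Oja-style update with that learning-rate schedule; the function name and shapes are illustrative only and do not reflect the kernel's internal implementation:

    import numpy as np

    def oja_step(u, x, step, principal_component_ema_factor=8):
      # u: [num_dims] current principal-component estimate.
      # x: [num_chains, num_dims] batch of (approximately centered) states.
      learning_rate = principal_component_ema_factor / step
      # Oja's rule: move u towards the direction of largest variance in x,
      # then renormalize to unit length.
      u = u + learning_rate * np.mean((x @ u)[:, None] * x, axis=0)
      return u / np.linalg.norm(u)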
The mass matrix is learned using a variant of Welford's algorithm/exponential moving average, with the decay rate set to step // state_ema_factor / (step // state_ema_factor + 1).
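As a rough sketch of that schedule (illustrative only, not the kernel's actual code), a per-dimension variance estimate would be updated as:

    def update_variance_ema(ema_variance, batch_variance, step,
                            state_ema_factor=8):
      # The decay rate approaches 1 as steps accumulate, so early estimates are
      # forgotten quickly while later ones change the average only slowly.
      n = step // state_ema_factor
      decay = n / (n + 1)
      return decay * ema_variance + (1 - decay) * batch_variance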
Learning the step size is a necessary component of a good HMC sampler, but it is not handled by this kernel. That adaptation can be provided by, for example, tfp.mcmc.SimpleStepSizeAdaptation or tfp.mcmc.DualAveragingStepSizeAdaptation.
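For example, the kernel can simply be wrapped by a step-size adapter (a minimal sketch; target_log_prob_fn and num_adaptation_steps are placeholders, and the full example below uses tfp.mcmc.DualAveragingStepSizeAdaptation in the same way):

    kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
        target_log_prob_fn,
        step_size=1e-2,
        num_adaptation_steps=num_adaptation_steps)
    # Adapt the step size over the same adaptation window.
    kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        kernel, num_adaptation_steps=num_adaptation_steps)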
To aid algorithm stability, the first few steps are taken with the number of leapfrog steps set to 1, turning the algorithm into the Metropolis Adjusted Langevin Algorithm (MALA). This is controlled by the num_mala_steps argument.
Unlike some classical MCMC algorithms, this algorithm behaves best when the chains are initialized with very low variance. Initializing them all at one point is recommended.
SNAPER-HMC requires at least two chains to function.
Examples
Here we apply this kernel to a target with a known covariance structure and show that it recovers the principal component and the variances.
    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    # `unnest` is an internal TFP helper used here to pull fields out of
    # nested kernel results.
    from tensorflow_probability.python.internal import unnest

    tfd = tfp.distributions

    num_dims = 8
    num_burnin_steps = 1000
    num_adaptation_steps = int(num_burnin_steps * 0.8)
    num_results = 500
    num_chains = 64
    step_size = 1e-2
    num_mala_steps = 100

    # Construct a covariance with a known spectrum and principal component.
    eigenvalues = np.exp(np.linspace(0., 3., num_dims))
    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
    q *= np.sign(np.diag(r))
    covariance = (q * eigenvalues).dot(q.T).astype(np.float32)
    _, eigenvectors = np.linalg.eigh(covariance)
    principal_component = eigenvectors[:, -1]

    gaussian = tfd.MultivariateNormalTriL(
        loc=tf.zeros(num_dims),
        scale_tril=tf.linalg.cholesky(covariance),
    )

    kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
        gaussian.log_prob,
        step_size=step_size,
        num_adaptation_steps=num_adaptation_steps,
        num_mala_steps=num_mala_steps,
    )
    kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        kernel, num_adaptation_steps=num_adaptation_steps)

    def trace_fn(_, pkr):
      return {
          'principal_component':
              unnest.get_innermost(pkr, 'ema_principal_component'),
          'variance':
              unnest.get_innermost(pkr, 'ema_variance'),
      }

    # Initialize all chains at a single point.
    init_x = tf.zeros([num_chains, num_dims])
    chain, trace = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=init_x,
        kernel=kernel,
        trace_fn=trace_fn)

    # Close to `np.diag(covariance)`.
    trace['variance'][-1]
    # Close to `principal_component`, up to a sign.
    trace['principal_component'][-1]

    # Compute sampler diagnostics.
    tfp.mcmc.effective_sample_size(chain, cross_chain_dims=1)
    tfp.mcmc.potential_scale_reduction(chain)

    # Compute downstream statistics.
    tf.reduce_mean(chain, [0, 1])
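In practice the sampling call is typically wrapped in tf.function so the whole chain runs as a single compiled computation; the sketch below reuses the variables defined above (XLA compilation via jit_compile is optional and assumes a supported backend):

    @tf.function(jit_compile=True)
    def run_chain():
      return tfp.mcmc.sample_chain(
          num_results=num_results,
          num_burnin_steps=num_burnin_steps,
          current_state=init_x,
          kernel=kernel,
          trace_fn=trace_fn)

    chain, trace = run_chain()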
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>
[2]: Andrieu, Christophe, Thoms, Johannes. A tutorial on adaptive MCMC. Statistics and Computing, 2008. <https://people.eecs.berkeley.edu/~jordan/sail/readings/andrieu-thoms.pdf>.
Attributes | |
---|---|
experimental_reduce_chain_axis_names | |
experimental_shard_axis_names | The shard axis names for members of the state. |
gradient_based_trajectory_length_adaptation_kwargs | |
is_calibrated | Returns True if Markov chain converges to specified distribution. |
max_leapfrog_steps | |
name | |
num_adaptation_steps | |
num_mala_steps | |
parameters | |
preconditioned_hamiltonian_monte_carlo_kwargs | |
principal_component_ema_factor | |
state_ema_factor | |
step_size | |
target_log_prob_fn | |
trajectory_length_adaptation_rate | |
validate_args | |
Methods
bootstrap_results

    bootstrap_results(init_state)

Returns an object with the same type as returned by one_step(...)[1].

Args | |
---|---|
init_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |

Returns | |
---|---|
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
copy

    copy(**override_parameter_kwargs)

Non-destructively creates a deep copy of the kernel.

Args | |
---|---|
**override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |

Returns | |
---|---|
new_kernel | TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
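For instance, assuming snaper_kernel is an existing SNAPERHamiltonianMonteCarlo instance, a copy with a different step size (illustrative value only) is:

    smaller_step_kernel = snaper_kernel.copy(step_size=1e-3)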
experimental_with_shard_axes

    experimental_with_shard_axes(shard_axis_names)

Returns a copy of the kernel with the provided shard axis names.

Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |

Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step

    one_step(current_state, previous_kernel_results, seed=None)

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). |
previous_kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |

Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s). |
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
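As a rough sketch of how bootstrap_results and one_step fit together (a hand-rolled sampling loop; tfp.mcmc.sample_chain does this, plus tracing, for you), reusing kernel and init_x from the example above:

    seed = tfp.random.sanitize_seed(42)  # Stateless PRNG seed.
    state = init_x
    kernel_results = kernel.bootstrap_results(state)
    states = []
    for _ in range(10):
      # Derive a fresh seed for every step.
      seed, step_seed = tfp.random.split_seed(seed)
      state, kernel_results = kernel.one_step(
          state, kernel_results, seed=step_seed)
      states.append(state)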