Adapts the inner kernel's step_size based on log_accept_prob.
Inherits From: TransitionKernel
```python
tfp.substrates.jax.mcmc.SimpleStepSizeAdaptation(
    inner_kernel,
    num_adaptation_steps,
    target_accept_prob=0.75,
    adaptation_rate=0.01,
    step_size_setter_fn=hmc_like_step_size_setter_fn,
    step_size_getter_fn=hmc_like_step_size_getter_fn,
    log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn,
    reduce_fn=tfp.substrates.jax.math.reduce_logmeanexp,
    experimental_reduce_chain_axis_names=None,
    validate_args=False,
    name=None
)
```
The simple policy multiplicatively increases or decreases the step_size of the inner kernel based on the value of log_accept_prob. It is based on [equation 19 of Andrieu and Thoms (2008)][1]. Given enough steps and a small enough adaptation_rate, the median of the distribution of the acceptance probability will converge to the target_accept_prob. A good target acceptance probability depends on the inner kernel. If this kernel is HamiltonianMonteCarlo, then 0.6-0.9 is a good range to aim for. For RandomWalkMetropolis this should be closer to 0.25. See the individual kernels' docstrings for guidance.
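The policy can be sketched as follows. This is a minimal illustration under one assumption: when the acceptance probability is above target, the update multiplies the step size by (1 + adaptation_rate), and otherwise it divides by that factor. The helper simple_step_size_update and the toy acceptance curve exp(-h**2) are hypothetical and are not part of TFP.

```python
import jax.numpy as jnp


def simple_step_size_update(step_size, log_accept_prob,
                            target_accept_prob=0.75, adaptation_rate=0.01):
    """Hypothetical sketch of one multiplicative step size update."""
    log_target = jnp.log(target_accept_prob)
    # Accepting too often: grow the step size. Too rarely: shrink it.
    return jnp.where(log_accept_prob > log_target,
                     step_size * (1.0 + adaptation_rate),
                     step_size / (1.0 + adaptation_rate))


def toy_log_accept_prob(h):
    # Hypothetical acceptance curve: p(h) = exp(-h**2) falls as h grows.
    return -h ** 2


h = jnp.asarray(0.1)
for _ in range(1000):
    h = simple_step_size_update(h, toy_log_accept_prob(h))

# Step size at which the toy acceptance probability equals the target.
h_star = jnp.sqrt(-jnp.log(0.75))
print(float(h), float(h_star))
```

In this deterministic toy, the step size ends up bouncing within roughly a factor of 1 + adaptation_rate of h_star. With real, noisy acceptance probabilities, the update pushes up and down equally often once the acceptance probability exceeds the target half of the time. This is why the convergence statement above is about the median.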
In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires num_adaptation_steps to be set to a value somewhat smaller than the number of burnin steps. However, it may sometimes be helpful to set num_adaptation_steps to a larger value during development in order to inspect the behavior of the chain during adaptation.
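The effect of num_adaptation_steps on the schedule can be sketched with jax.lax.scan. This is a hypothetical illustration, not the kernel's code. It assumes the step size is updated only while the step counter is below num_adaptation_steps and is frozen afterwards. It also pretends every step wants to grow the step size, so the freeze is easy to see.

```python
import jax
import jax.numpy as jnp


def step_size_schedule(num_steps, num_adaptation_steps, step_size=0.1,
                       adaptation_rate=0.01):
    """Hypothetical trace of the step size over num_steps iterations."""

    def body(h, step):
        # Pretend the chain always accepts more often than the target,
        # so the adaptive update would always grow the step size.
        proposed = h * (1.0 + adaptation_rate)
        # Only adapt during the first num_adaptation_steps iterations.
        h = jnp.where(step < num_adaptation_steps, proposed, h)
        return h, h

    init = jnp.asarray(step_size, dtype=jnp.float32)
    _, trace = jax.lax.scan(body, init, jnp.arange(num_steps))
    return trace


# 500 burn-in steps, adapting for the first 80% of them.
trace = step_size_schedule(num_steps=500, num_adaptation_steps=400)
print(trace[398], trace[399], trace[-1])
```

With num_burnin_steps = 500 and num_adaptation_steps = 400, the last 100 burn-in steps run with a fixed step size. The chain can then settle toward the stationary distribution of a fixed kernel before any results are kept.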
The step size is assumed to broadcast with the chain state, potentially having leading dimensions corresponding to multiple chains. When there are fewer of those leading dimensions than there are chain dimensions, the corresponding dimensions in the log_accept_prob are averaged (in the direct space, rather than the log space) before being used to adjust the step size. This means that this kernel can do either cross-chain or per-chain step size adaptation, depending on the shape of the step size.
For example, if your problem has a state with shape [S], your chain state has shape [C0, C1, S] (meaning that there are C0 * C1 total chains) and log_accept_prob has shape [C0, C1] (one acceptance probability per chain), then depending on the shape of the step size, the following will happen:
- Step size has shape [], [S] or [1]: log_accept_prob will be averaged across its C0 and C1 dimensions. This means that you will learn a shared step size based on the mean acceptance probability across all chains. This can be useful if you don't have a lot of steps to adapt and want to average away the noise.
- Step size has shape [C1, 1] or [C1, S]: log_accept_prob will be averaged across its C0 dimension. This means that you will learn a shared step size based on the mean acceptance probability across chains that share the coordinate across the C1 dimension. This can be useful when the C1 dimension indexes different distributions, while C0 indexes replicas of a single distribution, all sampled in parallel.
- Step size has shape [C0, C1, 1] or [C0, C1, S]: no averaging will happen. This means that each chain will learn its own step size. This can be useful when all chains are sampling from different distributions. Even when all chains are for the same distribution, this can help during the initial warmup period.
- Step size has shape [C0, 1, 1] or [C0, 1, S]: log_accept_prob will be averaged across its C1 dimension. This means that you will learn a shared step size based on the mean acceptance probability across chains that share the coordinate across the C0 dimension. This can be useful when the C0 dimension indexes different distributions, while C1 indexes replicas of a single distribution, all sampled in parallel.
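The sketch below is a simplified version of these shape rules. It assumes that the last axis of the step size is the state axis (1 or S) and that log_accept_prob has shape [C0, C1]. The helpers logmeanexp and reduce_for_step_size are hypothetical illustrations, not TFP's implementation. Like the kernel, they average in direct space.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logmeanexp(x, axis, keepdims=False):
    """log(mean(exp(x))) over the given axes: a direct-space mean, in log space."""
    count = 1
    for a in axis:
        count *= x.shape[a]
    return logsumexp(x, axis=axis, keepdims=keepdims) - jnp.log(count)


def reduce_for_step_size(log_accept_prob, step_size_shape):
    """Hypothetical: average log_accept_prob until it matches the step size."""
    # Drop the trailing state axis ([1] or [S]); a scalar step size has none.
    batch_shape = tuple(step_size_shape[:-1])
    lap = log_accept_prob
    # 1) Average away leading chain axes the step size does not have.
    num_leading = lap.ndim - len(batch_shape)
    if num_leading > 0:
        lap = logmeanexp(lap, axis=tuple(range(num_leading)))
    # 2) Average (keeping the axis) wherever the step size has size 1.
    size_one_axes = tuple(i for i, (a, b) in enumerate(zip(lap.shape, batch_shape))
                          if b == 1 and a != 1)
    if size_one_axes:
        lap = logmeanexp(lap, axis=size_one_axes, keepdims=True)
    return lap


# Acceptance probabilities for C0 = 4, C1 = 3 chains.
p = jnp.array([[0.2, 0.4, 0.6],
               [0.4, 0.8, 1.0],
               [0.6, 0.3, 0.9],
               [0.8, 0.5, 0.5]])
log_p = jnp.log(p)
for shape in [(), (5,), (3, 1), (4, 1, 1), (4, 3, 1)]:
    print(shape, reduce_for_step_size(log_p, shape).shape)
```

In each case, the result holds one averaged acceptance probability per independently adapted step size. A scalar or [S] step size gets a single overall mean, and a [C0, C1, S] step size gets the per-chain values unchanged.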
By default, the averaging function used above is the arithmetic mean, which is not robust to stuck chains. For example, the average of one chain with p_accept = 0 and three chains with p_accept = 1 is p_accept = 0.75. This will cause this kernel to keep the step size roughly the same rather than reducing it to unstick the stuck chain. A more robust choice would be to set the reduce_fn argument to tfp.math.reduce_log_harmonic_mean_exp [3]. Note, however, that the harmonic mean of a set of numbers is usually smaller than the arithmetic mean, so its use will typically produce smaller than optimal step sizes even for well-behaved target distributions.
Examples
```python
import jax
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

target_log_prob_fn = tfd.Normal(loc=0., scale=1.).log_prob
num_burnin_steps = 500
num_results = 500
num_chains = 64
step_size = 0.1
# Or, if you want per-chain step size:
# step_size = tf.fill([num_chains], step_size)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    num_leapfrog_steps=2,
    step_size=step_size)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, log_accept_ratio] = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=tf.zeros(num_chains),
    kernel=kernel,
    trace_fn=lambda _, pkr: [pkr.inner_results.accepted_results.step_size,
                             pkr.inner_results.log_accept_ratio],
    # TFP on JAX expects an explicit PRNG key for sampling.
    seed=jax.random.PRNGKey(0))

# ~0.75
p_accept = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))
```
References
[1]: Andrieu, C., & Thoms, J. A tutorial on adaptive MCMC. Statistics and Computing, 2008. https://people.eecs.berkeley.edu/~jordan/sail/readings/andrieu-thoms.pdf
[3]: Hoffman, M., Radul, A., & Sountsov, P. An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo, 2020. In preparation.
Attributes | |
---|---|
experimental_reduce_chain_axis_names | |
experimental_shard_axis_names | The shard axis names for members of the state. |
inner_kernel | |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | |
num_adaptation_steps | |
parameters | Return dict of __init__ arguments and their values. |
Methods
bootstrap_results
```python
bootstrap_results(
    init_state
)
```
Returns an object with the same type as returned by one_step(...)[1].

Args | |
---|---|
init_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |

Returns | |
---|---|
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
copy
```python
copy(
    **override_parameter_kwargs
)
```
Non-destructively creates a deep copy of the kernel.

Args | |
---|---|
**override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |

Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
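The override rule above is ordinary dict merging. The sketch below uses a hypothetical parameters dict in place of a real kernel's self.parameters. With an actual kernel, the equivalent call would be kernel.copy(target_accept_prob=0.9).

```python
# Hypothetical stand-in for self.parameters of a SimpleStepSizeAdaptation kernel.
parameters = {
    "num_adaptation_steps": 400,
    "target_accept_prob": 0.75,
    "adaptation_rate": 0.01,
}
override_parameter_kwargs = {"target_accept_prob": 0.9}

# Shared keys take the override's value; all other keys are kept.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters)
```

The original dict is left untouched, which mirrors the "non-destructive" behavior of copy.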
experimental_with_shard_axes
```python
experimental_with_shard_axes(
    shard_axis_names
)
```
Returns a copy of the kernel with the provided shard axis names.

Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |

Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
log_accept_prob_getter_fn
```python
log_accept_prob_getter_fn(
    kernel_results
)
```
one_step
```python
one_step(
    current_state, previous_kernel_results, seed=None
)
```
Takes one step of the TransitionKernel.
Must be overridden by subclasses.

Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). |
previous_kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |

Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s). |
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
reduce_fn
```python
reduce_fn(
    input_tensor, axis, keepdims, experimental_named_axis=None
)
```
step_size_getter_fn
```python
step_size_getter_fn(
    kernel_results
)
```
step_size_setter_fn
```python
step_size_setter_fn(
    kernel_results, new_step_size
)
```
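The three getter and setter hooks let the adaptation work with inner kernels whose results are laid out differently from HMC's. The sketch below assumes a hypothetical flat results type, MyKernelResults, with step_size and log_accept_ratio fields. Its functions follow the signatures listed above, so they could be passed as step_size_getter_fn, step_size_setter_fn and log_accept_prob_getter_fn. They are not the library's hmc_like_* defaults.

```python
from typing import NamedTuple

import jax.numpy as jnp


class MyKernelResults(NamedTuple):
    """Hypothetical inner-kernel results with a flat layout."""
    step_size: jnp.ndarray
    log_accept_ratio: jnp.ndarray


def my_step_size_getter_fn(kernel_results):
    return kernel_results.step_size


def my_step_size_setter_fn(kernel_results, new_step_size):
    # NamedTuples are immutable, so return an updated copy.
    return kernel_results._replace(step_size=new_step_size)


def my_log_accept_prob_getter_fn(kernel_results):
    # Clip the log ratio at 0 so the result is a log probability (<= 0).
    return jnp.minimum(kernel_results.log_accept_ratio, 0.0)


kr = MyKernelResults(step_size=jnp.asarray(0.1),
                     log_accept_ratio=jnp.array([0.3, -1.0]))
kr_new = my_step_size_setter_fn(kr, my_step_size_getter_fn(kr) * 1.01)
print(kr_new.step_size, my_log_accept_prob_getter_fn(kr))
```

Clipping at zero mirrors the tf.minimum(log_accept_ratio, 0.) used in the Examples section.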