tfp.substrates.jax.mcmc.SimpleStepSizeAdaptation

Adapts the inner kernel's step_size based on log_accept_prob.

Inherits From: TransitionKernel

The simple policy multiplicatively increases or decreases the step_size of the inner kernel based on the value of log_accept_prob. It is based on [equation 19 of Andrieu and Thoms (2008)][1]. Given enough steps and a small enough adaptation_rate, the median of the distribution of the acceptance probability will converge to the target_accept_prob.

A good target acceptance probability depends on the inner kernel. If this kernel is HamiltonianMonteCarlo, then 0.6-0.9 is a good range to aim for. For RandomWalkMetropolis, this should be closer to 0.25. See the individual kernels' docstrings for guidance.
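The multiplicative policy described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the helper name adapt_step_size is hypothetical, adaptation_rate=0.01 is a value picked for the example, and the sketch assumes the step size grows by a factor of 1 + adaptation_rate when acceptance is at or above the target and shrinks by the same factor otherwise.

```python
import math


def adapt_step_size(step_size, log_accept_prob, target_accept_prob=0.75,
                    adaptation_rate=0.01):
    """Hypothetical helper: one multiplicative step-size update (assumed rule)."""
    if log_accept_prob < math.log(target_accept_prob):
        # Acceptance is below the target: shrink the step size.
        return step_size / (1.0 + adaptation_rate)
    # Acceptance is at or above the target: grow the step size.
    return step_size * (1.0 + adaptation_rate)


step_size = 0.1
grown = adapt_step_size(step_size, math.log(0.95))   # accepting too often
shrunk = adapt_step_size(step_size, math.log(0.30))  # accepting too rarely
```

Under this assumed rule only the sign of the comparison matters, not its size, so the step size settles where acceptance falls above the target about as often as below it. That is consistent with the statement that the median, rather than the mean, of the acceptance probability converges to target_accept_prob.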

In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires num_adaptation_steps be set to a value somewhat smaller than the number of burnin steps. However, it may sometimes be helpful to set num_adaptation_steps to a larger value during development in order to inspect the behavior of the chain during adaptation.

The step size is assumed to broadcast with the chain state, potentially having leading dimensions corresponding to multiple chains. When there are fewer of those leading dimensions than there are chain dimensions, the corresponding dimensions in the log_accept_prob are averaged (in the direct space, rather than the log space) before being used to adjust the step size. This means that this kernel can do both cross-chain adaptation, or per-chain step size adaptation, depending on the shape of the step size.
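To make "averaged in the direct space, rather than the log space" concrete, the sketch below compares the two kinds of average for two chains. It uses only the Python standard library; the acceptance probabilities 0.9 and 0.1 are made-up values for illustration.

```python
import math

# Hypothetical per-chain log acceptance probabilities.
log_accept_prob = [math.log(0.9), math.log(0.1)]

# Direct-space mean: exponentiate, average, then (if needed) take the log again.
direct_mean = sum(math.exp(v) for v in log_accept_prob) / len(log_accept_prob)
log_of_direct_mean = math.log(direct_mean)

# Log-space mean: average the logs, which is the log of the geometric mean.
log_space_mean = sum(log_accept_prob) / len(log_accept_prob)
geometric_mean = math.exp(log_space_mean)
```

Here the direct-space mean is about 0.5 while the log-space (geometric) mean is about 0.3, so the two conventions can lead to noticeably different step size adjustments.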

For example, if your problem has a state with shape [S], your chain state has shape [C0, C1, S] (meaning that there are C0 * C1 total chains) and log_accept_prob has shape [C0, C1] (one acceptance probability per chain), then depending on the shape of the step size, the following will happen:

  • If the step size has shape [], [S] or [1], log_accept_prob is averaged across its C0 and C1 dimensions. You learn a single shared step size based on the mean acceptance probability across all chains. This can be useful if you don't have many adaptation steps and want to average away the noise.

  • If the step size has shape [C1, 1] or [C1, S], log_accept_prob is averaged across its C0 dimension. You learn a shared step size based on the mean acceptance probability across chains that share the same coordinate along the C1 dimension. This can be useful when the C1 dimension indexes different distributions, while C0 indexes replicas of a single distribution, all sampled in parallel.

  • If the step size has shape [C0, C1, 1] or [C0, C1, S], no averaging happens. Each chain learns its own step size. This can be useful when all chains are sampling from different distributions. Even when all chains are for the same distribution, this can help during the initial warmup period.

  • If the step size has shape [C0, 1, 1] or [C0, 1, S], log_accept_prob is averaged across its C1 dimension. You learn a shared step size based on the mean acceptance probability across chains that share the same coordinate along the C0 dimension. This can be useful when the C0 dimension indexes different distributions, while C1 indexes replicas of a single distribution, all sampled in parallel.
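The cases above can be summarized by a small shape calculation. The helper below is hypothetical and is not how the kernel is implemented; it assumes the step size shape is right-aligned against the chain state shape, as in ordinary broadcasting, and reports the chain axes where the step size is either missing or has size 1. The concrete sizes for C0, C1 and S are arbitrary.

```python
def averaged_axes(step_size_shape, chain_state_shape, num_chain_dims):
    """Hypothetical helper: chain axes of log_accept_prob that get averaged."""
    rank = len(chain_state_shape)
    # Right-align the step size shape, padding missing leading axes with 1.
    padded = (1,) * (rank - len(step_size_shape)) + tuple(step_size_shape)
    return tuple(axis for axis in range(num_chain_dims) if padded[axis] == 1)


C0, C1, S = 4, 3, 5
chain_state_shape = (C0, C1, S)

shapes = [(), (S,), (1,), (C1, 1), (C1, S), (C0, 1, 1), (C0, 1, S),
          (C0, C1, 1), (C0, C1, S)]
# Map each step size shape to the axes averaged over: 0 is C0, 1 is C1.
results = {shape: averaged_axes(shape, chain_state_shape, num_chain_dims=2)
           for shape in shapes}
```

An empty tuple in results means no averaging, that is, a separate step size per chain.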

By default, the averaging function used above is the arithmetic mean, which is not robust to stuck chains. For example, averaging one chain with p_accept = 0 and three chains with p_accept = 1 gives p_accept = 0.75, which will cause this kernel to keep the step size roughly the same rather than reducing it to unstick the stuck chain. A more robust choice is to set the reduce_fn argument to tfp.math.reduce_log_harmonic_mean_exp [3]. Note, however, that the harmonic mean of a set of positive numbers is never larger, and usually smaller, than the arithmetic mean, so its use will typically produce smaller-than-optimal step sizes even for well-behaved target distributions.
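The stuck-chain example can be checked with a few lines of standard-library Python. This sketch works on plain probabilities with the statistics module rather than on log probabilities with TFP functions, and the healthy acceptance rates are made-up values.

```python
import statistics

# One stuck chain (p_accept = 0) and three chains that always accept.
p_accept = [0.0, 1.0, 1.0, 1.0]

arithmetic = statistics.mean(p_accept)         # hides the stuck chain
harmonic = statistics.harmonic_mean(p_accept)  # dominated by the stuck chain

# For well-behaved chains the harmonic mean is still a little lower.
healthy = [0.7, 0.8, 0.9]
healthy_arithmetic = statistics.mean(healthy)
healthy_harmonic = statistics.harmonic_mean(healthy)
```

The arithmetic mean is 0.75, right at the default target, while the harmonic mean is 0, which would push the step size down. For the healthy chains the harmonic mean is slightly below the arithmetic mean, which illustrates the bias towards smaller step sizes.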

Examples

import jax
from tensorflow_probability.python.internal.backend import jax as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

target_log_prob_fn = tfd.Normal(loc=0., scale=1.).log_prob
num_burnin_steps = 500
num_results = 500
num_chains = 64
step_size = 0.1
# Or, if you want per-chain step size:
# step_size = tf.fill([num_chains], step_size)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    num_leapfrog_steps=2,
    step_size=step_size)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, log_accept_ratio] = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=tf.zeros(num_chains),
    kernel=kernel,
    trace_fn=lambda _, pkr: [pkr.inner_results.accepted_results.step_size,
                             pkr.inner_results.log_accept_ratio],
    # The JAX substrate needs an explicit PRNG key.
    seed=jax.random.PRNGKey(0))

# ~0.75
p_accept = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))

References

[1]: Andrieu, Christophe, Thoms, Johannes. A tutorial on adaptive MCMC. Statistics and Computing, 2008. https://people.eecs.berkeley.edu/~jordan/sail/readings/andrieu-thoms.pdf

[3]: Hoffman, M., Radul, A., & Sountsov, P. An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo, 2020. In preparation.

Args
inner_kernel TransitionKernel-like object.
num_adaptation_steps Scalar int Tensor number of initial steps during which to adjust the step size. This may be greater than, less than, or equal to the number of burnin steps.
target_accept_prob A floating point Tensor representing desired acceptance probability. Must be a positive number less than 1. This can either be a scalar, or have shape [num_chains]. Default value: 0.75 (the [center of asymptotically optimal rate for HMC][1]).
adaptation_rate Tensor representing amount to scale the current step_size.
step_size_setter_fn A callable with the signature (kernel_results, new_step_size) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_step_size is a Tensor or a nested collection of Tensors with the same structure as returned by the step_size_getter_fn, and new_kernel_results are a copy of kernel_results with the step size(s) set.
step_size_getter_fn A callable with the signature (kernel_results) -> step_size where kernel_results are the results of the inner_kernel, and step_size is a floating point Tensor or a nested collection of such Tensors.
log_accept_prob_getter_fn A callable with the signature (kernel_results) -> log_accept_prob where kernel_results are the results of the inner_kernel, and log_accept_prob is a floating point Tensor. log_accept_prob can either be a scalar, or have shape [num_chains]. If it's the latter, step_size should also have the same leading dimension.
reduce_fn A callable with signature (input_tensor, axis, keepdims) -> tensor that returns a log-reduction of log_accept_prob, typically some sort of mean. By default, this performs an arithmetic mean.
experimental_reduce_chain_axis_names A str or list of strs indicating the named axes that should additionally be reduced during the log-reduction of log_accept_prob.
validate_args Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this class. Default: 'simple_step_size_adaptation'.
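The callable signatures listed above can be illustrated with a toy results structure. Everything in this sketch is hypothetical: ToyKernelResults stands in for whatever results object the inner kernel actually produces, and the three functions only show the expected shape of each callable, with plain floats in place of Tensors.

```python
import collections

# Hypothetical stand-in for an inner kernel's results. Real kernels define
# their own, often nested, result structures.
ToyKernelResults = collections.namedtuple(
    'ToyKernelResults', ['step_size', 'log_accept_ratio'])


def step_size_getter_fn(kernel_results):
    """(kernel_results) -> step_size"""
    return kernel_results.step_size


def step_size_setter_fn(kernel_results, new_step_size):
    """(kernel_results, new_step_size) -> new_kernel_results (a copy)."""
    return kernel_results._replace(step_size=new_step_size)


def log_accept_prob_getter_fn(kernel_results):
    """(kernel_results) -> log_accept_prob, capped at log(1) = 0."""
    return min(kernel_results.log_accept_ratio, 0.0)


results = ToyKernelResults(step_size=0.1, log_accept_ratio=0.3)
updated = step_size_setter_fn(results, 0.2)
```

Note that the setter returns a modified copy and leaves the original results untouched, matching the description of new_kernel_results as a copy of kernel_results.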

Attributes
experimental_reduce_chain_axis_names

experimental_shard_axis_names The shard axis names for members of the state.
inner_kernel

is_calibrated Returns True if Markov chain converges to specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

name

num_adaptation_steps

parameters Return dict of __init__ arguments and their values.

Methods

bootstrap_results

Returns an object with the same type as returned by one_step(...)[1].

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
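The parameter merge described for copy is ordinary Python dictionary behavior. The sketch below uses made-up parameter values in a plain dict standing in for kernel.parameters; it does not construct a kernel.

```python
# Hypothetical values standing in for kernel.parameters.
parameters = {'num_adaptation_steps': 400, 'target_accept_prob': 0.75}
override_parameter_kwargs = {'target_accept_prob': 0.65}

# Shared keys take the override value; all other keys are carried over.
new_parameters = dict(parameters, **override_parameter_kwargs)
```

The original dictionary is left unchanged, which mirrors the non-destructive behavior described for copy.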

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names A structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

log_accept_prob_getter_fn

one_step

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.

reduce_fn

step_size_getter_fn

step_size_setter_fn
