tfp.experimental.mcmc.KernelBuilder

Convenience constructor for common MCMC transition kernels.

KernelBuilder gives an alternative interface for building MCMC transition kernels. It wraps the base tfp.mcmc library, offering more convenience at the cost of some power and flexibility.

It is designed to work together with KernelOutputs, the object returned by sample().

Example usage:

```python
import tensorflow_probability as tfp

KernelBuilder = tfp.experimental.mcmc.KernelBuilder

# `target_log_prob_fn`, `my_bijector`, `num_steps` and `initial_state`
# are user-supplied placeholders.

# Initialize builder with `target_log_prob_fn`.
builder = KernelBuilder.make(target_log_prob_fn)
# Configure initial transition kernel.
builder = (
  builder
  # Use Hamiltonian Monte Carlo
  .hmc(num_leapfrog_steps=3)
  # with step size adaptation (which needs `num_steps` up front; see `build`)
  .dual_averaging_adaptation()
  # and a transformation
  .transform(my_bijector))

# Do sampling...
outputs = builder.sample(num_steps, initial_state)

# Start from the previous `KernelBuilder` configuration.
builder = (
  builder
  # Continue using HMC...
  # Still use `my_bijector` transform
  # But with no step size adaptation:
  .clear_step_adapter()
  # set a static step size
  .set_step_size(outputs.new_step_size))

# More sampling starting from where we left off.
outputs2 = builder.sample(num_steps, outputs.current_state)

# Etc ...
```

All methods except build(), sample(), get_step_size() and get_target_accept_prob() return a modified copy of the builder for further method chaining. The builder itself is an immutable namedtuple, so it is safe to use inside graphs.
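
The copy-on-write behaviour described above is the standard namedtuple idiom. The sketch below assumes a hypothetical three-field `ToyBuilder` (not the real KernelBuilder, which has sixteen fields) and shows why a chained call never mutates the builder it started from: each method returns a new tuple via `_replace`.

```python
import collections

# Hypothetical three-field stand-in for KernelBuilder's namedtuple base.
_ToyFields = collections.namedtuple(
    "_ToyFields", ["target_log_prob_fn", "core_params", "step_size"])


class ToyBuilder(_ToyFields):
    __slots__ = ()  # no per-instance __dict__; fields stay read-only

    @classmethod
    def make(cls, target_log_prob_fn):
        # Start from empty defaults.
        return cls(target_log_prob_fn, None, None)

    def hmc(self, num_leapfrog_steps):
        # _replace builds a new tuple; `self` is never modified.
        return self._replace(
            core_params={"num_leapfrog_steps": num_leapfrog_steps})

    def set_step_size(self, step_size):
        return self._replace(step_size=step_size)


def log_prob(x):
    return -0.5 * x * x


base = ToyBuilder.make(log_prob)
configured = base.hmc(num_leapfrog_steps=3).set_step_size(0.1)
```

Because `base` is untouched, the same starting configuration can be branched into several differently configured builders.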

KernelBuilder builds kernels with the following parts (in order):

  1. A core transition kernel
  2. Optional step size adaptation or replica exchange
  3. Transforming bijector
  4. Thinning
  5. Streaming reductions
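
The ordering above can be pictured as successive wrapping: each optional layer wraps everything built before it, so the core kernel is innermost and the reductions are outermost. The sketch only composes descriptive strings; `describe_stack` is a hypothetical helper, not part of TFP, and it models the order of the layers, not how each wrapper is constructed. Layer 2 may be replica exchange instead of step size adaptation.

```python
def describe_stack(core, step_adapter=None, transform=None,
                   thinning=None, reductions=None):
    """Return a nested description of the kernel stack (innermost = core)."""
    kernel = core
    # Apply the optional layers in the documented order, parts 2 to 5.
    for layer in (step_adapter, transform, thinning, reductions):
        if layer is not None:
            kernel = f"{layer}({kernel})"
    return kernel


stack = describe_stack(
    "HamiltonianMonteCarlo",
    step_adapter="DualAveragingStepSizeAdaptation",
    transform="TransformedTransitionKernel",
    reductions="WithReductions")
```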

The core kernels can be either HamiltonianMonteCarlo, NoUTurnSampler, PreconditionedHamiltonianMonteCarlo, MetropolisAdjustedLangevinAlgorithm, or RandomWalkMetropolis. Support for additional core kernels may be added in the future.

Step size adaptation is performed by SimpleStepSizeAdaptation or DualAveragingStepSizeAdaptation. Note that not all core kernels are currently compatible with step size adaptation.

KernelBuilder maintains some state between kernel builds which can be reused or overridden:

  1. Target log prob function
  2. Core kernel class
  3. Step size (initial)
  4. Step size adapter class (optional)
     4a. Number of adaptation steps
     4b. Target acceptance probability
  5. Replica exchange parameters (optional)
  6. TransformedTransitionKernel bijector/params (optional)
  7. Thinning: number of steps between results (optional)
  8. Tracing parameters for TracingReducer / auto-tracing.
  9. Show progress boolean.
  10. Reductions for WithReductions

See instance method documentation for more information.

Attributes
target_log_prob_fn A namedtuple alias for field number 0
core_class A namedtuple alias for field number 1
core_params A namedtuple alias for field number 2
default_step_size_on A namedtuple alias for field number 3
step_size A namedtuple alias for field number 4
step_adapter_class A namedtuple alias for field number 5
step_adapter_params A namedtuple alias for field number 6
default_target_accept_prob_on A namedtuple alias for field number 7
target_accept_prob A namedtuple alias for field number 8
replica_exchange_params A namedtuple alias for field number 9
transform_params A namedtuple alias for field number 10
num_steps_between_results A namedtuple alias for field number 11
auto_tracing_on A namedtuple alias for field number 12
tracing_params A namedtuple alias for field number 13
show_progress A namedtuple alias for field number 14
user_reducer A namedtuple alias for field number 15

Methods

build

Build and return the specified kernel.

Args
num_steps An integer. Some kernel pieces (step adaptation) require knowing the number of steps to sample in advance; pass that in here.

Returns
kernel The configured TransitionKernel.

clear_reducer

Remove previously set reductions.

clear_replica_exchange

Remove previously set replica exchange.

clear_step_adapter

Removes step adaptation.

clear_tracing

Remove TracingReducer.

clear_transform

Remove previously set TransformedTransitionKernel.

dual_averaging_adaptation

Use DualAveragingStepSizeAdaptation.

See the DualAveragingStepSizeAdaptation docs for more details.

Args
exploration_shrinkage Floating point scalar Tensor. How strongly the exploration rate is biased towards the shrinkage target.
shrinkage_target Tensor or list of tensors. Value the exploration step size(s) is/are biased towards. As num_adaptation_steps approaches infinity, this bias goes to zero. Defaults to 10 times the initial step size.
step_count_smoothing Int32 scalar Tensor. Number of "pseudo-steps" added to the number of steps taken to prevent noisy exploration during the early samples.
decay_rate Floating point scalar Tensor. How much to favor recent iterations over earlier ones. A value of 1 gives equal weight to all history. A value of 0 gives weight only to the most recent iteration.
step_size_setter_fn A callable with the signature (kernel_results, new_step_size) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_step_size is a Tensor or a nested collection of Tensors with the same structure as returned by the step_size_getter_fn, and new_kernel_results are a copy of kernel_results with the step size(s) set.
step_size_getter_fn A callable with the signature (kernel_results) -> step_size where kernel_results are the results of the inner_kernel, and step_size is a floating point Tensor or a nested collection of such Tensors.
log_accept_prob_getter_fn A callable with the signature (kernel_results) -> log_accept_prob where kernel_results are the results of the inner_kernel, and log_accept_prob is a floating point Tensor. log_accept_prob can either be a scalar, or have shape [num_chains]. If it's the latter, step_size should also have the same leading dimension.
validate_args Python bool. When True kernel parameters are checked for validity. When False invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'dual_averaging_step_size_adaptation').

Returns
self Returns the builder for more method chaining.
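
To see how exploration_shrinkage, shrinkage_target, step_count_smoothing and decay_rate interact, here is a plain-Python sketch of the dual-averaging recursion of Hoffman and Gelman (2014), on which DualAveragingStepSizeAdaptation is based. It is not TFP's Tensor implementation: the defaults (0.05, 10, 0.75) are assumed from that paper, and the toy acceptance curve `toy_accept_prob` is a hypothetical stand-in for real kernel results.

```python
import math


def dual_averaging(accept_prob_fn, initial_step_size, num_adaptation_steps,
                   target_accept_prob=0.75, exploration_shrinkage=0.05,
                   step_count_smoothing=10, decay_rate=0.75):
    """Return the averaged step size after `num_adaptation_steps` updates."""
    # shrinkage_target: 10x the initial step size, as documented above.
    log_shrinkage_target = math.log(10.0 * initial_step_size)
    error_sum = 0.0  # running sum of (target - observed) acceptance
    log_step = math.log(initial_step_size)
    log_averaged_step = 0.0
    for t in range(1, num_adaptation_steps + 1):
        accept_prob = accept_prob_fn(math.exp(log_step))
        error_sum += target_accept_prob - accept_prob
        # Exploration iterate: pulled toward the shrinkage target and pushed
        # away from it by the accumulated acceptance error.
        log_step = log_shrinkage_target - (
            math.sqrt(t) / exploration_shrinkage
            * error_sum / (t + step_count_smoothing))
        # Averaged iterate: the newest value gets weight t**-decay_rate, so
        # decay_rate=1 weights all history equally and 0 keeps only the latest.
        weight = t ** -decay_rate
        log_averaged_step = (1.0 - weight) * log_averaged_step + weight * log_step
    return math.exp(log_averaged_step)


def toy_accept_prob(step_size):
    # Hypothetical acceptance curve: acceptance falls as step size grows.
    return math.exp(-step_size)


adapted = dual_averaging(toy_accept_prob, initial_step_size=1.0,
                         num_adaptation_steps=2000)
```

For this toy curve the averaged step size settles near -log(0.75), where the acceptance probability equals the 0.75 target.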

get_step_size

Return the set or default step size.

get_target_accept_prob

Return the set target_accept_prob or the default for the core kernel.

hmc

Use the HamiltonianMonteCarlo core transition kernel.

See the HamiltonianMonteCarlo docs for more details.

Args
num_leapfrog_steps Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps.
state_gradients_are_stopped Python bool indicating that the proposed new state be run through tf.stop_gradient. This is particularly useful when combining optimization over samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient).
store_parameters_in_results If True, then step_size and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'hmc_kernel').

Returns
self Returns the builder for more method chaining.
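
The remark that progress per HMC step is roughly proportional to step_size * num_leapfrog_steps comes from the leapfrog integrator simulating Hamiltonian dynamics for a time of about that product. A plain-Python sketch of the textbook integrator for a one-dimensional standard normal target (not TFP's implementation; all names are hypothetical):

```python
import math


def leapfrog(q, p, grad_log_prob, step_size, num_leapfrog_steps):
    """Run `num_leapfrog_steps` leapfrog steps and return (q, p)."""
    p = p + 0.5 * step_size * grad_log_prob(q)      # initial half momentum step
    for i in range(num_leapfrog_steps):
        q = q + step_size * p                       # full position step
        if i != num_leapfrog_steps - 1:
            p = p + step_size * grad_log_prob(q)    # full momentum step
    p = p + 0.5 * step_size * grad_log_prob(q)      # final half momentum step
    return q, p


def grad_log_prob(q):
    # Standard normal: log p(q) = -q**2 / 2 + const.
    return -q


def hamiltonian(q, p):
    return 0.5 * q * q + 0.5 * p * p


q0, p0 = 1.0, 0.5
# 10 steps of size 0.1 integrate the dynamics for a total time of 1.0.
q1, p1 = leapfrog(q0, p0, grad_log_prob, step_size=0.1, num_leapfrog_steps=10)
exact_q1 = q0 * math.cos(1.0) + p0 * math.sin(1.0)  # exact solution at time 1.0
```

Doubling either step_size or num_leapfrog_steps roughly doubles the simulated time; smaller steps trade extra gradient evaluations for a smaller energy error, and hence a higher acceptance rate.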

make

Construct a KernelBuilder with empty defaults.

mala

Use the MetropolisAdjustedLangevinAlgorithm core transition kernel.

See the MetropolisAdjustedLangevinAlgorithm docs for more details.

Args
volatility_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns the volatility value at current_state. Should return a Tensor or Python list of Tensors that must broadcast with the shape of current_state. Defaults to the identity function.
parallel_iterations The number of coordinates for which the gradients of the volatility matrix volatility_fn can be computed in parallel.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mala_kernel').

Returns
self Returns the builder for more method chaining.

nuts

Use the NoUTurnSampler core kernel.

See the NoUTurnSampler docs for more details.

Args
max_tree_depth Maximum depth of the tree implicitly built by NUTS. The maximum number of leapfrog steps is bounded by 2**max_tree_depth, i.e., the number of nodes in a binary tree max_tree_depth nodes deep. The default setting of 10 takes up to 1024 leapfrog steps.
max_energy_diff Scalar threshold of energy differences at each leapfrog step; divergent samples are defined as leapfrog steps that exceed this threshold. Defaults to 1000.
unrolled_leapfrog_steps The number of leapfrogs to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by max_tree_depth. Defaults to 1.
parallel_iterations The number of iterations allowed to run in parallel. It must be a positive integer. See tf.while_loop for more details.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'nuts_kernel').

Returns
self Returns the builder for more method chaining.
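
The per-iteration leapfrog budget implied by max_tree_depth and unrolled_leapfrog_steps is simple arithmetic: the two factors multiply. A one-function sketch (hypothetical name, not part of TFP) that reproduces the documented default of 1024:

```python
def max_leapfrog_steps(max_tree_depth=10, unrolled_leapfrog_steps=1):
    """Upper bound on leapfrog steps in one NUTS iteration, per the Args above."""
    # The tree bound is 2**max_tree_depth; unrolling multiplies it linearly.
    return unrolled_leapfrog_steps * 2 ** max_tree_depth


default_budget = max_leapfrog_steps()
```

Lowering max_tree_depth is therefore the main lever for capping the worst-case cost of a single NUTS step.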

phmc

Use the PreconditionedHamiltonianMonteCarlo core transition kernel.

See the PreconditionedHamiltonianMonteCarlo docs for more details.

Args
num_leapfrog_steps Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps.
momentum_distribution A tfp.distributions.Distribution instance to draw momentum from. Defaults to isotropic normal distributions.
state_gradients_are_stopped Python bool indicating that the proposed new state be run through tf.stop_gradient. This is particularly useful when combining optimization over samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient).
store_parameters_in_results If True, then step_size, momentum_distribution, and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. If this is True, the momentum_distribution must be a CompositeTensor. See tfp.experimental.auto_composite. This is incompatible with step_size_update_fn, which must be set to None.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'hmc_kernel').

Returns
self Returns the builder for more method chaining.

replica_exchange

Use ReplicaExchangeMC.

See the ReplicaExchangeMC docs for more details.

Args
inverse_temperatures Tensor of inverse temperatures to temper each replica. The leftmost dimension has size num_replica; the second through rightmost dimensions can provide different temperatures to different batch members, using a left-justified broadcast.
swap_proposal_fn Python callable which takes a number of replicas and returns swaps, a shape [num_replica] + batch_shape Tensor, where axis 0 indexes a permutation of {0,..., num_replica-1}, designating replicas to swap.
state_includes_replicas Boolean indicating whether the leftmost dimension of each state sample should index replicas. If True, the leftmost dimension of the current_state kwarg to tfp.mcmc.sample_chain will be interpreted as indexing replicas.
validate_args Python bool, default False. When True distribution parameters are checked for validity despite possibly degrading runtime performance. When False invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., "remc_kernel").

Raises
ValueError inverse_temperatures doesn't have statically known 1D shape.

Returns
self Returns the builder for more method chaining.
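
For intuition about inverse_temperatures, the sketch below uses the textbook parallel-tempering rule, assuming replica i targets the tempered density p(x)**beta_i. Swapping the states held by replicas i and j changes the joint tempered log density by (beta_i - beta_j) * (log p(x_j) - log p(x_i)), and the swap is accepted with probability min(1, exp(that difference)). This is plain Python, not TFP's implementation; `swap_accept_prob` is a hypothetical name.

```python
import math


def swap_accept_prob(beta_i, beta_j, log_prob_i, log_prob_j):
    """Metropolis probability of exchanging the states of replicas i and j.

    log_prob_i, log_prob_j: untempered target log probs of the states
    currently held by replicas i and j.
    """
    # Tempered joint log density after the swap minus before the swap.
    log_ratio = (beta_i - beta_j) * (log_prob_j - log_prob_i)
    return math.exp(min(0.0, log_ratio))


# An example ladder along the leftmost (num_replica) dimension;
# 1.0 is the untempered target.
inverse_temperatures = [1.0, 0.5, 0.25]
```

A swap that moves a higher-probability state to the colder (larger beta) replica is always accepted; the reverse move is accepted only occasionally, which is what lets hot replicas feed good states down the ladder.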

rwm

Use the RandomWalkMetropolis core kernel.

See the RandomWalkMetropolis docs for more details.

Args
new_state_fn Python callable which takes a list of state parts and a seed; returns a same-type list of Tensors, each being a perturbation of the input state parts. The perturbation distribution is assumed to be a symmetric distribution centered at the input state part. Default value: None which is mapped to tfp.mcmc.random_walk_normal_fn().
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'rwm_kernel').

Returns
self Returns the builder for more method chaining.
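
A plain-Python sketch of a random-walk Metropolis chain on a one-dimensional standard normal. The hypothetical `random_walk_normal` plays the role of new_state_fn, except that it receives a `random.Random` instance in place of a TFP seed. Because the perturbation is symmetric and centred on the current state, the acceptance test needs only the difference in target log probs. This is not TFP's implementation.

```python
import math
import random


def random_walk_normal(scale):
    """Hypothetical new_state_fn: Gaussian perturbation of each state part."""
    def new_state_fn(state_parts, rng):
        return [x + rng.gauss(0.0, scale) for x in state_parts]
    return new_state_fn


def rwm_chain(target_log_prob_fn, initial_state, num_steps, new_state_fn,
              seed=0):
    rng = random.Random(seed)
    state = list(initial_state)
    log_prob = target_log_prob_fn(state)
    samples = []
    for _ in range(num_steps):
        proposal = new_state_fn(state, rng)
        proposal_log_prob = target_log_prob_fn(proposal)
        # Symmetric proposal: the Hastings correction cancels.
        if rng.random() < math.exp(min(0.0, proposal_log_prob - log_prob)):
            state, log_prob = proposal, proposal_log_prob
        samples.append(state[0])
    return samples


samples = rwm_chain(lambda s: -0.5 * s[0] ** 2, [3.0], 20000,
                    random_walk_normal(2.0), seed=1)
```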

sample

Sample from the configured kernel.

Args
num_steps Integer number of Markov chain steps.
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A Tensor or a nested collection of Tensors. Warm-start for the auxiliary state needed by the given kernel. If not supplied, step_kernel will cold-start with kernel.bootstrap_results.

Returns
outputs A KernelOutputs object containing the states, trace, etc.

set_auto_tracing

Add smart tracing.

set_num_steps_between_results

Thin sampling by num_steps_between_results.
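
A minimal sketch of thinning, assuming the same convention as the num_steps_between_results argument of tfp.mcmc.sample_chain, where only one out of every num_steps_between_results + 1 steps is kept. Whether KernelBuilder follows exactly this convention is an assumption; `thin` is a hypothetical helper applied to an already materialized list, whereas the real kernel discards steps as it goes.

```python
def thin(results, num_steps_between_results):
    """Keep one result out of every num_steps_between_results + 1 steps."""
    stride = num_steps_between_results + 1
    # Keep the last step of each block of `stride` steps.
    return results[stride - 1::stride]


kept = thin(list(range(1, 11)), 1)
```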

set_reducer

Use tfp.experimental.mcmc.WithReductions.

See the WithReductions docs for more details.

Args
reducer A (possibly nested) structure of Reducers to be evaluated on the inner_kernel's samples.

Returns
self Returns the builder for more method chaining.
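
A reducer summarises samples in a streaming fashion instead of storing the whole chain. The plain-Python class below imitates the initialize / one_step / finalize life cycle of tfp.experimental.mcmc.Reducer to compute a running mean. It does not subclass the real Reducer, and its simplified method signatures (no kernel results, no axis) are assumptions made for illustration.

```python
class StreamingMean:
    """Hypothetical running-mean reducer over scalar chain states."""

    def initialize(self, initial_chain_state):
        # Reducer state: (number of samples seen, running mean).
        return (0, 0.0)

    def one_step(self, new_chain_state, current_reducer_state):
        count, mean = current_reducer_state
        count += 1
        # Incremental update: m_n = m_{n-1} + (x_n - m_{n-1}) / n.
        mean += (new_chain_state - mean) / count
        return (count, mean)

    def finalize(self, final_reducer_state):
        return final_reducer_state[1]


reducer = StreamingMean()
reducer_state = reducer.initialize(0.0)
for x in [1.0, 2.0, 3.0, 4.0]:
    reducer_state = reducer.one_step(x, reducer_state)
mean = reducer.finalize(reducer_state)
```

Memory use stays constant in the number of steps, which is the point of streaming reductions for long chains.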

set_show_progress

Set whether to show progress during sampling.

set_step_size

Set the step size (for core kernels that have a step size).

set_target_accept_prob

Set the target acceptance probability for step adaptation kernels.

Args
target_accept_prob A floating point Tensor representing desired acceptance probability. Must be a positive number less than 1. This can either be a scalar, or have shape [num_chains]. By default, this is 0.25 for RandomWalkMetropolis and 0.75 for HamiltonianMonteCarlo, MetropolisAdjustedLangevinAlgorithm and NoUTurnSampler.

Returns
self Returns the builder for more method chaining.
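
The defaults quoted for target_accept_prob amount to a per-core lookup. The sketch below assumes the resolution rule suggested by get_target_accept_prob ("the set target_accept_prob or the default for the core kernel"): an explicitly set value wins, otherwise the per-core default applies. `resolve_target_accept_prob` is a hypothetical helper; the description lists no default for PreconditionedHamiltonianMonteCarlo, so this lookup raises KeyError for it.

```python
# Per-core defaults as listed in the target_accept_prob description above.
DEFAULT_TARGET_ACCEPT_PROB = {
    "RandomWalkMetropolis": 0.25,
    "HamiltonianMonteCarlo": 0.75,
    "MetropolisAdjustedLangevinAlgorithm": 0.75,
    "NoUTurnSampler": 0.75,
}


def resolve_target_accept_prob(core_class_name, target_accept_prob=None):
    """An explicitly set value wins; otherwise use the per-core default."""
    if target_accept_prob is not None:
        if not 0.0 < target_accept_prob < 1.0:
            raise ValueError("target_accept_prob must be in (0, 1)")
        return target_accept_prob
    # KeyError for cores without a documented default.
    return DEFAULT_TARGET_ACCEPT_PROB[core_class_name]
```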

set_tracing

Trace sampling state and results.

See the TracingReducer docs for more details.

Args
trace_fn A callable that takes in the current chain state and the previous kernel results and returns a Tensor or a nested collection of Tensors that is accumulated across samples.
size Integer or scalar Tensor denoting the size of the accumulated TensorArray. If this is None (which is the default), a dynamic-shaped TensorArray will be used.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'tracing_reducer').

Returns
self Returns the builder for more method chaining.

simple_adaptation

Use SimpleStepSizeAdaptation.

See the SimpleStepSizeAdaptation docs for more details.

Args
adaptation_rate Tensor representing the amount by which to scale the current step_size.
step_size_setter_fn A callable with the signature (kernel_results, new_step_size) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_step_size is a Tensor or a nested collection of Tensors with the same structure as returned by the step_size_getter_fn, and new_kernel_results are a copy of kernel_results with the step size(s) set.
step_size_getter_fn A callable with the signature (kernel_results) -> step_size where kernel_results are the results of the inner_kernel, and step_size is a floating point Tensor or a nested collection of such Tensors.
log_accept_prob_getter_fn A callable with the signature (kernel_results) -> log_accept_prob where kernel_results are the results of the inner_kernel, and log_accept_prob is a floating point Tensor. log_accept_prob can either be a scalar, or have shape [num_chains]. If it's the latter, step_size should also have the same leading dimension.
validate_args Python bool. When True kernel parameters are checked for validity. When False invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this class. Default: 'simple_step_size_adaptation'.

Returns
self Returns the builder for more method chaining.
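
SimpleStepSizeAdaptation uses a multiplicative rule: when the acceptance probability is below the target the step size is divided by (1 + adaptation_rate), otherwise it is multiplied by it. The plain-Python sketch below shows that rule on a hypothetical toy acceptance curve. The adaptation_rate default of 0.01 is an assumption, and TFP's version works on Tensors and averages acceptance across chains.

```python
import math


def simple_adaptation(accept_prob_fn, step_size, num_adaptation_steps,
                      target_accept_prob=0.75, adaptation_rate=0.01):
    """Multiplicative step-size adaptation (plain-Python sketch)."""
    for _ in range(num_adaptation_steps):
        if accept_prob_fn(step_size) < target_accept_prob:
            # Accepting too rarely: shrink the step.
            step_size /= 1.0 + adaptation_rate
        else:
            # Accepting often enough: grow the step.
            step_size *= 1.0 + adaptation_rate
    return step_size


adapted = simple_adaptation(lambda eps: math.exp(-eps), step_size=1.0,
                            num_adaptation_steps=1000)
```

Unlike dual averaging, this rule never settles exactly; it oscillates within a factor of (1 + adaptation_rate) around the step size that hits the target.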

transform

Use TransformedTransitionKernel.

See TransformedTransitionKernel docs for more details.

Args
bijector tfp.bijectors.Bijector or list of tfp.bijectors.Bijectors. These bijectors use forward to map the inner_kernel state space to the state expected by inner_kernel.target_log_prob_fn.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., "transformed_kernel").

Returns
self Returns the builder for more method chaining.
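
Under a transform, the inner kernel samples an unconstrained variable u. Its target is the original log prob evaluated at bijector.forward(u) plus the forward log-determinant of the Jacobian (the change-of-variables correction). A plain-Python sketch for an Exponential(1) target with an exp bijector; the function names are hypothetical and no TFP objects are used.

```python
import math


def exponential_log_prob(x):
    """Exponential(rate=1) log density, valid for x > 0."""
    return -x


def exp_forward(u):
    """Hypothetical Exp bijector: unconstrained u -> positive x."""
    return math.exp(u)


def exp_forward_log_det_jacobian(u):
    # d/du exp(u) = exp(u), so log |Jacobian| = u.
    return u


def transformed_log_prob(u):
    """Target seen by the inner kernel, which works in unconstrained u-space."""
    return exponential_log_prob(exp_forward(u)) + exp_forward_log_det_jacobian(u)
```

The inner kernel can then propose any real u without ever leaving the support of the original target.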

use_default_step_size

Use the default step size (or not).

use_default_target_accept_prob

Use the per-core-class default target acceptance probability (or not).