tfp.substrates.jax.mcmc.HamiltonianMonteCarlo

Runs one step of Hamiltonian Monte Carlo.

Inherits From: TransitionKernel

Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm that takes a series of gradient-informed steps to produce a Metropolis proposal. This class implements one random HMC step from a given current_state. Mathematical details and derivations can be found in [Neal (2011)][1].

The one_step function can update multiple chains in parallel. It assumes that all leftmost dimensions of current_state index independent chain states (and are therefore updated independently). The output of target_log_prob_fn(*current_state) should sum log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, current_state[0, :] could have a different target distribution from current_state[1, :]. These semantics are governed by target_log_prob_fn(*current_state). (The number of independent chains is tf.size(target_log_prob_fn(*current_state)).)
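For example, a minimal sketch of these batch semantics (the function and shapes here are illustrative, not part of the API):

import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.internal.backend import jax as tf

# Three independent chains, each over a 2-dimensional standard normal.
# The leftmost dimension indexes chains; the rightmost is the event
# dimension, which `target_log_prob_fn` must sum over.
def target_log_prob_fn(x):
  return -0.5 * tf.reduce_sum(x**2., axis=-1)  # shape [3]

current_state = tf.zeros([3, 2])  # 3 chains, each a 2-vector.
# tf.size(target_log_prob_fn(current_state)) == 3 independent chains.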

Examples:

Simple chain with warm-up.

In this example we sample from a standard univariate normal distribution using HMC with adaptive step size.

from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax

# Target distribution is proportional to: `exp(-x (1 + x))`.
def unnormalized_log_prob(x):
  return -x - x**2.

# Initialize the HMC transition kernel.
num_results = int(10e3)
num_burnin_steps = int(1e3)
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_prob,
        num_leapfrog_steps=3,
        step_size=1.),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

# Run the chain (with burn-in).
@tf.function
def run_chain():
  samples, is_accepted = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=1.,
      kernel=adaptive_hmc,
      trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
      seed=42)  # The JAX substrate requires an explicit PRNG seed.

  sample_mean = tf.reduce_mean(samples)
  sample_stddev = tf.math.reduce_std(samples)
  is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32))
  return sample_mean, sample_stddev, is_accepted

sample_mean, sample_stddev, is_accepted = run_chain()

print('mean:{:.4f}  stddev:{:.4f}  acceptance:{:.4f}'.format(
    float(sample_mean), float(sample_stddev), float(is_accepted)))

Since -x - x**2 = -(x + 1/2)**2 + 1/4, the target is Normal(loc=-0.5, scale=2**-0.5 ~= 0.707), so the printed mean and stddev should be close to those values.

Estimate parameters of a more complicated posterior.

In this example, we'll use Monte-Carlo EM to find best-fit parameters. See [Convergence of a stochastic approximation version of the EM algorithm][2] for more details.

More precisely, we use HMC to form a chain conditioned on parameter sigma and training data { (x[i], y[i]) : i=1...n }. We then take one maximum-likelihood gradient step to improve the sigma estimate, and repeat the process until convergence. (This procedure is a Robbins--Monro algorithm.)

The generative assumptions are:

  W ~ MVN(loc=0, scale=sigma * eye(dims))
  for i=1...num_samples:
      X[i] ~ MVN(loc=0, scale=eye(dims))
    eps[i] ~ Normal(loc=0, scale=1)
      Y[i] = X[i].T * W + eps[i]
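
For reference, the same generative assumptions can be written with tfd.JointDistributionNamed; this helper is purely illustrative (the example below instead simulates the training data directly with NumPy):

import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.internal.backend import jax as tf
tfd = tfp.distributions

def make_generative_model(dims, num_samples, sigma):
  # Hypothetical helper mirroring the generative assumptions above.
  return tfd.JointDistributionNamed(dict(
      w=tfd.MultivariateNormalDiag(
          loc=tf.zeros([dims]), scale_diag=sigma * tf.ones([dims])),
      x=tfd.Independent(
          tfd.Normal(loc=tf.zeros([dims, num_samples]), scale=1.),
          reinterpreted_batch_ndims=2),
      # y[i] = x[:, i] . w + eps[i], with unit observation noise.
      y=lambda w, x: tfd.Independent(
          tfd.Normal(loc=tf.reduce_sum(w[:, tf.newaxis] * x, axis=0),
                     scale=1.),
          reinterpreted_batch_ndims=1)))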

We now implement a stochastic approximation of Expectation Maximization (SAEM) using tensorflow_probability intrinsics. [Delyon et al. (1999)][2]

from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
import numpy as np

tfd = tfp.distributions

def make_training_data(num_samples, dims, sigma):
  dt = np.asarray(sigma).dtype
  x = np.random.randn(dims, num_samples).astype(dt)
  w = sigma * np.random.randn(1, dims).astype(dt)
  noise = np.random.randn(num_samples).astype(dt)
  y = w.dot(x) + noise
  return y[0], x, w[0]

def make_weights_prior(dims, log_sigma):
  return tfd.MultivariateNormalDiag(
      loc=tf.zeros([dims], dtype=log_sigma.dtype),
      scale_diag=tf.math.exp(log_sigma) *
                 tf.ones([dims], dtype=log_sigma.dtype))

def make_response_likelihood(w, x):
  if len(w.shape) == 1:
    y_bar = tf.matmul(w[tf.newaxis], x)[0]
  else:
    y_bar = tf.matmul(w, x)
  return tfd.Normal(loc=y_bar, scale=tf.ones_like(y_bar))  # [n]

# Setup assumptions.
dtype = np.float32
num_samples = 500
dims = 10
np.random.seed(10014)

weights_prior_true_scale = np.array(0.3, dtype)
y, x, _ = make_training_data(
    num_samples, dims, weights_prior_true_scale)

log_sigma = tf.Variable(0., dtype=dtype, name='log_sigma')

optimizer = tf.optimizers.SGD(learning_rate=0.01)

@tf.function
def mcem_iter(weights_chain_start, step_size):
  with tf.GradientTape() as tape:
    tape.watch(log_sigma)
    prior = make_weights_prior(dims, log_sigma)

    def unnormalized_posterior_log_prob(w):
      likelihood = make_response_likelihood(w, x)
      return (
          prior.log_prob(w) +
          tf.reduce_sum(likelihood.log_prob(y), axis=-1))  # [m]

    def trace_fn(_, pkr):
      return (
          pkr.inner_results.log_accept_ratio,
          pkr.inner_results.accepted_results.target_log_prob,
          pkr.inner_results.accepted_results.step_size)

    num_results = 2
    weights, (
        log_accept_ratio, target_log_prob, step_size) = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=0,
        current_state=weights_chain_start,
        kernel=tfp.mcmc.SimpleStepSizeAdaptation(
            tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=unnormalized_posterior_log_prob,
                num_leapfrog_steps=2,
                step_size=step_size,
                state_gradients_are_stopped=True,
            ),
            # Adapt for the entirety of the trajectory.
            num_adaptation_steps=2),
        trace_fn=trace_fn,
        seed=123)

    # Use the target log-prob from the two HMC steps (which propagated
    # `weights`) as the loss for one optimization step on `log_sigma`.
    loss = -tf.reduce_mean(target_log_prob)

  avg_acceptance_ratio = tf.math.exp(
      tfp.math.reduce_logmeanexp(tf.minimum(log_accept_ratio, 0.)))

  optimizer.apply_gradients(
      [[tape.gradient(loss, log_sigma), log_sigma]])

  weights_prior_estimated_scale = tf.math.exp(log_sigma)
  return (weights_prior_estimated_scale, weights[-1], loss,
          step_size[-1], avg_acceptance_ratio)

num_iters = int(40)

weights_prior_estimated_scale_ = np.zeros(num_iters, dtype)
weights_ = np.zeros([num_iters + 1, dims], dtype)
loss_ = np.zeros([num_iters], dtype)
weights_[0] = np.random.randn(dims).astype(dtype)
step_size_ = 0.03

for iter_ in range(num_iters):
  [
      weights_prior_estimated_scale_[iter_],
      weights_[iter_ + 1],
      loss_[iter_],
      step_size_,
      avg_acceptance_ratio_,
  ] = mcem_iter(weights_[iter_], step_size_)
  print(('iter:{:>2}  loss:{: 9.3f}  scale:{:.3f}  '
         'step_size:{:.4f}  avg_acceptance_ratio:{:.4f}').format(
             iter_, loss_[iter_], weights_prior_estimated_scale_[iter_],
             float(step_size_), float(avg_acceptance_ratio_)))

# Should converge to ~0.22.
import matplotlib.pyplot as plt
plt.plot(weights_prior_estimated_scale_)
plt.ylabel('weights_prior_estimated_scale')
plt.xlabel('iteration')

References

[1]: Radford Neal. MCMC Using Hamiltonian Dynamics. Handbook of Markov Chain Monte Carlo, 2011. https://arxiv.org/abs/1206.1901

[2]: Bernard Delyon, Marc Lavielle, Eric Moulines. Convergence of a stochastic approximation version of the EM algorithm. Ann. Statist. 27 (1999), no. 1, 94--128. https://projecteuclid.org/euclid.aos/1018031103

Args

target_log_prob_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable (see the sketch after this list).
num_leapfrog_steps Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps.
state_gradients_are_stopped Python bool indicating whether the proposed new state should be run through tf.stop_gradient. This is particularly useful when combining optimization with sampling from the HMC chain. Default value: False (i.e., do not apply stop_gradient).
store_parameters_in_results If True, then step_size and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly.
experimental_shard_axis_names A structure of string names indicating how members of the state are sharded.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'hmc_kernel').
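
As an illustration of the per-variable step-size advice above, a sketch with a hypothetical two-part state whose parts have very different scales (the target and numbers here are illustrative):

import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.internal.backend import jax as tf

# Part 0 has marginal stddev 10; part 1 has marginal stddev 1.
def target_log_prob_fn(loc, log_scale):
  return (-0.5 * tf.reduce_sum((loc / 10.)**2.)
          - 0.5 * tf.reduce_sum(log_scale**2.))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    # One step size per state part, roughly 0.1x each part's stddev.
    step_size=[1.0, 0.1],
    num_leapfrog_steps=3)
# current_state would be a matching two-part list,
# e.g. [tf.zeros([5]), tf.zeros([5])].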

Attributes

experimental_shard_axis_names The shard axis names for members of the state.
is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel, as sketched below.
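
For instance, this kernel is effectively its uncalibrated counterpart wrapped in Metropolis-Hastings, a composition you can write directly; a minimal sketch, assuming a target_log_prob_fn is in scope:

kernel = tfp.mcmc.MetropolisHastings(
    tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.1,
        num_leapfrog_steps=3))
# The accept/reject step makes the composed kernel target
# `target_log_prob_fn` exactly, so:
assert kernel.is_calibrated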

name

num_leapfrog_steps Returns the num_leapfrog_steps parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of num_leapfrog_steps placed in the kernel results by the bootstrap_results method. The actual num_leapfrog_steps in that situation is governed by the previous_kernel_results argument to the one_step method.

parameters Return dict of __init__ arguments and their values.
state_gradients_are_stopped

step_size Returns the step_size parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of the step_size placed in the kernel results by the bootstrap_results method. The actual step size in that situation is governed by the previous_kernel_results argument to the one_step method (see the sketch below).

target_log_prob_fn
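
When store_parameters_in_results is True, the live step_size and num_leapfrog_steps travel inside the kernel results rather than on the kernel object, as noted for the two attributes above. A minimal sketch; the accepted_results field path reflects this kernel wrapping Metropolis-Hastings, so verify it against your TFP version:

import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.internal.backend import jax as tf

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -x**2.,
    step_size=0.1,
    num_leapfrog_steps=3,
    store_parameters_in_results=True)
pkr = kernel.bootstrap_results(tf.zeros([]))
# The stored parameters live on the accepted results of the underlying
# Metropolis-Hastings kernel results; wrapper kernels such as
# SimpleStepSizeAdaptation adjust these fields between one_step calls.
pkr.accepted_results.step_size           # ==> 0.1
pkr.accepted_results.num_leapfrog_steps  # ==> 3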

Methods

bootstrap_results

Creates initial previous_kernel_results using a supplied state.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
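
For example, a hypothetical override of a single parameter:

base_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1,
    num_leapfrog_steps=3)
# Identical kernel except for the smaller step size; all other
# parameters are carried over from `base_kernel.parameters`.
careful_kernel = base_kernel.copy(step_size=0.01)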

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names A structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

one_step

Runs one iteration of Hamiltonian Monte Carlo.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)).
previous_kernel_results collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state.
kernel_results collections.namedtuple of internal calculations used to advance the chain.

Raises
ValueError if step_size is neither a single value nor a list with the same length as current_state.
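
Putting bootstrap_results and one_step together, a minimal sketch of driving the kernel by hand in the JAX substrate (the target and chain count here are illustrative):

import jax
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
from tensorflow_probability.python.internal.backend import jax as tf

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -x**2.,
    step_size=0.1,
    num_leapfrog_steps=3)

state = tf.zeros([4])  # Four independent chains.
pkr = kernel.bootstrap_results(state)
for seed in jax.random.split(jax.random.PRNGKey(0), 100):
  # Each call advances every chain by exactly one HMC step.
  state, pkr = kernel.one_step(state, pkr, seed=seed)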