

Runs one step of the Metropolis-adjusted Langevin algorithm.

Inherits From: TransitionKernel

The Metropolis-adjusted Langevin algorithm (MALA) is a Markov chain Monte Carlo (MCMC) algorithm that uses one step of a discretised Langevin diffusion as its proposal. This class implements one step of MALA using the Euler-Maruyama method for a given current_state and a diagonal preconditioning volatility matrix. Mathematical details and derivations can be found in [Roberts and Rosenthal (1998)][1] and [Xifara et al. (2013)][2].

See UncalibratedLangevin class description below for details on the proposal generating step of the algorithm.
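As a rough illustration of the idea described above, the following sketch implements a textbook MALA step for a scalar state with unit volatility, using only the Python standard library. It is not this class's implementation: the helper names (`log_prob`, `grad_log_prob`, `mala_step`, `run_chain`) are hypothetical, the target is assumed to be a standard normal, and the proposal is assumed to be `x + 0.5 * step_size * grad + sqrt(step_size) * noise`.

```python
import math
import random


def log_prob(x):
    # Unnormalised log-density of the assumed standard normal target.
    return -0.5 * x * x


def grad_log_prob(x):
    # Gradient of log_prob with respect to x.
    return -x


def mala_step(x, step_size, rng):
    """One textbook MALA step with unit volatility (hypothetical helper)."""

    def proposal_mean(z):
        # Euler-Maruyama drift: move along the gradient of the log-density.
        return z + 0.5 * step_size * grad_log_prob(z)

    def log_q(to, frm):
        # Log-density (up to a constant) of proposing `to` from `frm`.
        return -((to - proposal_mean(frm)) ** 2) / (2.0 * step_size)

    proposal = proposal_mean(x) + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
    # Metropolis-Hastings correction for the asymmetric proposal.
    log_accept = (log_prob(proposal) + log_q(x, proposal)
                  - log_prob(x) - log_q(proposal, x))
    if rng.random() < math.exp(min(0.0, log_accept)):
        return proposal, True
    return x, False


def run_chain(num_results, num_burnin, step_size, seed):
    """Run a single chain and return the post-warm-up samples."""
    rng = random.Random(seed)
    x = 1.0
    samples = []
    for i in range(num_burnin + num_results):
        x, _ = mala_step(x, step_size, rng)
        if i >= num_burnin:
            samples.append(x)
    return samples


samples = run_chain(num_results=20000, num_burnin=500, step_size=0.75, seed=0)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print('sample mean', mean)
print('sample variance', var)
```

The accept/reject test is what distinguishes MALA from an unadjusted Langevin proposal: without it, the discretisation error would bias the chain away from the target.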

The one_step function can update multiple chains in parallel. It assumes that all leftmost dimensions of current_state index independent chain states (and are therefore updated independently). The output of target_log_prob_fn(*current_state) should reduce log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, current_state[0, :] could have a different target distribution from current_state[1, :]. These semantics are governed by target_log_prob_fn(*current_state). (The number of independent chains is tf.size(target_log_prob_fn(*current_state)).)
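The batching convention above can be pictured with a small plain-Python sketch. Nested lists stand in for tensors, and `target_log_prob` is a hypothetical function for a standard normal target; the point is only that the event dimension is summed away, leaving one log-density per chain.

```python
def target_log_prob(state):
    """Return one log-density per chain.

    `state` is a list of chains; each chain is a list of event coordinates.
    Summing over the coordinates reduces the event dimension.
    """
    return [sum(-0.5 * v * v for v in chain) for chain in state]


# Four independent chains, each with a 3-dimensional event.
state = [[0.0, 0.0, 0.0],
         [1.0, 0.0, 0.0],
         [1.0, 1.0, 0.0],
         [1.0, 1.0, 1.0]]

log_probs = target_log_prob(state)
num_chains = len(log_probs)  # plays the role of tf.size(...) in the text
print(num_chains, log_probs)
```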


Simple chain with warm-up.

In this example we sample from a standard univariate normal distribution using MALA with step_size equal to 0.75.

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt

tfd = tfp.distributions
dtype = np.float32

# Target distribution is Standard Univariate Normal
target = tfd.Normal(loc=dtype(0), scale=dtype(1))

def target_log_prob(x):
  return target.log_prob(x)

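# Define MALA sampler with `step_size` equal to 0.75.
# The chain length, warm-up length, initial state and seed are
# illustrative choices.
samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target_log_prob,
        step_size=0.75),
    num_burnin_steps=500,
    trace_fn=None,
    seed=42)

sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))

print('sample mean', sample_mean)
print('sample standard deviation', sample_std)

plt.plot(samples.numpy(), 'b')
plt.show()

Sample from a 3-D Multivariate Normal distribution.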

In this example we also consider a non-constant volatility function.

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions
dtype = np.float32
true_mean = dtype([0, 0, 0])
true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
num_results = 500
num_chains = 500

# Target distribution is defined through the Cholesky decomposition
chol = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

# Here we define the volatility function to be non-constant
def volatility_fn(x):
  # Volatility shrinks elementwise as |x| grows
  return 1. / (0.5 + 0.1 * tf.math.abs(x))

# Initial state of the chain
init_state = np.ones([num_chains, 3], dtype=dtype)

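# Run MALA with normal proposal for `num_results` iterations for
# `num_chains` independent chains. The step size, warm-up length and
# seed are illustrative choices.
states = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_state,
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target.log_prob,
        step_size=.1,
        volatility_fn=volatility_fn),
    num_burnin_steps=200,
    trace_fn=None,
    seed=42)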

sample_mean = tf.reduce_mean(states, axis=[0, 1])
x = (states - sample_mean)[..., tf.newaxis]
sample_cov = tf.reduce_mean(
    tf.matmul(x, tf.transpose(x, [0, 1, 3, 2])), [0, 1])

print('sample mean', sample_mean.numpy())
print('sample covariance matrix', sample_cov.numpy())
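To get a feel for the non-constant volatility used in the example above, the sketch below evaluates the same formula on a few scalar inputs in plain Python. The function name `volatility` is a stand-in introduced here; it mirrors `volatility_fn` but works on Python floats rather than tensors.

```python
def volatility(x):
    # Same formula as `volatility_fn` in the example, for a single float.
    return 1.0 / (0.5 + 0.1 * abs(x))


points = (0.0, 5.0, 15.0)
values = [volatility(x) for x in points]
for x, v in zip(points, values):
    print(x, v)
```

The volatility is largest at the origin and decays as the state moves away from it, so proposals far from the origin are less spread out.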


[1]: Gareth Roberts and Jeffrey Rosenthal. Optimal Scaling of Discrete Approximations to Langevin Diffusions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 60: 255-268, 1998.

[2]: T. Xifara et al. Langevin diffusions and the Metropolis-adjusted Langevin algorithm. arXiv preprint arXiv:1309.2983, 2013.

target_log_prob_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size: Tensor or Python list of Tensors representing the step size of the Euler-Maruyama discretisation. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable.
volatility_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns the volatility value at current_state. Should return a Tensor or Python list of Tensors that must broadcast with the shape of current_state. Defaults to the identity function.
parallel_iterations: The number of coordinates for which the gradients of the volatility matrix returned by volatility_fn can be computed in parallel. Default value: None (i.e., use system default).
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mala_kernel').
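The trade-off described for step_size can be checked empirically with a small standard-library sketch. It assumes a scalar standard normal target and a textbook unit-volatility MALA step; `acceptance_rate` is a hypothetical helper, not part of this API.

```python
import math
import random


def acceptance_rate(step_size, num_steps=5000, seed=0):
    """Fraction of accepted MALA proposals for a standard normal target."""
    rng = random.Random(seed)
    x, accepted = 0.0, 0
    for _ in range(num_steps):
        # For this target, grad log p(x) = -x.
        mean_fwd = x - 0.5 * step_size * x
        y = mean_fwd + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
        mean_rev = y - 0.5 * step_size * y
        # log p(y) - log p(x) + log q(x | y) - log q(y | x)
        log_accept = (-0.5 * y * y + 0.5 * x * x
                      - (x - mean_rev) ** 2 / (2.0 * step_size)
                      + (y - mean_fwd) ** 2 / (2.0 * step_size))
        if rng.random() < math.exp(min(0.0, log_accept)):
            x, accepted = y, accepted + 1
    return accepted / num_steps


rates = {h: acceptance_rate(h) for h in (0.1, 0.75, 3.0, 10.0)}
for h, rate in rates.items():
    print('step_size', h, 'acceptance rate', rate)
```

Very small steps are almost always accepted but move the chain slowly, while very large steps are mostly rejected, so a useful step size lies somewhere in between.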

ValueError: if there isn't one step_size or a list with same length as current_state.
TypeError: if volatility_fn is not callable.

experimental_shard_axis_names: The shard axis names for members of the state.
is_calibrated: Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.



parameters: Return dict of __init__ arguments and their values.






Creates initial previous_kernel_results using a supplied state.



Non-destructively creates a deep copy of the kernel.

**override_parameter_kwargs: Python String/value dictionary of initialization arguments to override with new values.

new_kernel: TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
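The merge rule quoted above is ordinary Python dictionary behaviour and can be tried in isolation. The parameter values below are hypothetical and serve only to show which side wins when a key appears in both dictionaries.

```python
# Hypothetical parameter values, used only to illustrate the merge rule.
parameters = {'step_size': 0.75, 'volatility_fn': None, 'name': None}
override_parameter_kwargs = {'step_size': 0.1}

# Keys present in both dictionaries take the overriding value.
new_parameters = dict(parameters, **override_parameter_kwargs)

print(new_parameters)
print(parameters)  # the original dictionary is left unchanged
```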



Returns a copy of the kernel with the provided shard axis names.

shard_axis_names: A structure of strings indicating the shard axis names for each component of this kernel's state.

A copy of the current kernel with the shard axis information.



Runs one iteration of MALA.

current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)).
previous_kernel_results: collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed: PRNG seed; see tfp.random.sanitize_seed for details.

next_state: Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state.
kernel_results: collections.namedtuple of internal calculations used to advance the chain.

ValueError: if there isn't one step_size or a list with same length as current_state or diffusion_drift.
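The interplay between bootstrap_results and one_step can be pictured with a toy kernel written in plain Python. Everything here is an assumption made for illustration: `ToyMALAKernel` and `ToyKernelResults` are hypothetical, the state is a scalar, the target is a standard normal with unit volatility, and a `random.Random` instance is passed in place of the `seed` argument. The sketch only shows why results are threaded from one call to the next: they cache the log-density and gradient of the current state so these are not recomputed.

```python
import collections
import math
import random

# Hypothetical container; real kernels define their own result fields.
ToyKernelResults = collections.namedtuple(
    'ToyKernelResults', ['target_log_prob', 'grad', 'is_accepted'])


class ToyMALAKernel:
    """Scalar, unit-volatility MALA for a standard normal target."""

    def __init__(self, step_size):
        self.step_size = step_size

    def _log_prob_and_grad(self, x):
        return -0.5 * x * x, -x

    def bootstrap_results(self, init_state):
        # Compute the cached quantities once for the initial state.
        log_prob, grad = self._log_prob_and_grad(init_state)
        return ToyKernelResults(log_prob, grad, is_accepted=True)

    def one_step(self, current_state, previous_kernel_results, rng):
        h = self.step_size
        # Reuse the cached gradient instead of recomputing it.
        mean_fwd = current_state + 0.5 * h * previous_kernel_results.grad
        proposal = mean_fwd + math.sqrt(h) * rng.gauss(0.0, 1.0)
        log_prob, grad = self._log_prob_and_grad(proposal)
        mean_rev = proposal + 0.5 * h * grad
        log_accept = (log_prob - previous_kernel_results.target_log_prob
                      - (current_state - mean_rev) ** 2 / (2.0 * h)
                      + (proposal - mean_fwd) ** 2 / (2.0 * h))
        if rng.random() < math.exp(min(0.0, log_accept)):
            return proposal, ToyKernelResults(log_prob, grad, True)
        # On rejection the state and its cached values are kept.
        return current_state, previous_kernel_results._replace(
            is_accepted=False)


kernel = ToyMALAKernel(step_size=0.75)
rng = random.Random(0)
state = 1.0
results = kernel.bootstrap_results(state)
chain = []
for _ in range(2000):
    state, results = kernel.one_step(state, results, rng)
    chain.append(state)

chain_mean = sum(chain) / len(chain)
print('chain mean', chain_mean)
```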