Hamiltonian Monte Carlo, with given momentum distribution.
Inherits From: HamiltonianMonteCarlo, TransitionKernel
```python
tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn,
    step_size,
    num_leapfrog_steps,
    momentum_distribution=None,
    state_gradients_are_stopped=False,
    step_size_update_fn=None,
    store_parameters_in_results=False,
    experimental_shard_axis_names=None,
    name=None
)
```
See tfp.mcmc.HamiltonianMonteCarlo for details on HMC.
HMC produces samples much more efficiently when it is properly preconditioned. Preconditioning is done by choosing a momentum distribution whose covariance is the inverse of the state's covariance (in practice, the inverse of an estimate of it).
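Why this choice helps can be seen on an independent Gaussian target, where every coordinate of the Hamiltonian dynamics is a harmonic oscillator. The stdlib-only sketch below is not TFP code; its helper names are hypothetical and it assumes diagonal covariances. It computes each coordinate's oscillation frequency for a given diagonal momentum covariance. With the default identity covariance the frequencies differ by the ratio of the target's standard deviations, so a step size that is stable for the fastest coordinate makes slow progress on the slowest one. With momentum variance 1 / sigma_i**2 every frequency equals 1.

```python
import math


def oscillation_frequencies(target_stddevs, momentum_variances):
    """Per-coordinate angular frequency of HMC dynamics (hypothetical helper).

    Assumes an independent Gaussian target, so the potential energy is
    U(q) = sum_i q_i**2 / (2 * sigma_i**2), and a diagonal momentum
    covariance m_i, so the kinetic energy is K(p) = sum_i p_i**2 / (2 * m_i).
    Each coordinate then oscillates with omega_i = 1 / (sigma_i * sqrt(m_i)).
    """
    return [1.0 / (s * math.sqrt(m))
            for s, m in zip(target_stddevs, momentum_variances)]


def max_stable_step_size(frequencies):
    """Leapfrog on a harmonic oscillator is stable only if step_size * omega < 2."""
    return 2.0 / max(frequencies)


target_stddevs = [0.1, 10.0]

# Identity momentum covariance (the default): frequencies 10 and 0.1.
unit = oscillation_frequencies(target_stddevs, [1.0, 1.0])

# Momentum covariance = inverse of the state covariance: m_i = 1 / sigma_i**2.
preconditioned = oscillation_frequencies(
    target_stddevs, [1.0 / s**2 for s in target_stddevs])

print(unit, max_stable_step_size(unit))
print(preconditioned, max_stable_step_size(preconditioned))
```

With unit mass the step size must stay below 0.2. The slow coordinate (omega = 0.1) needs a trajectory of length about 2 * pi / 0.1, roughly 63, for one oscillation, which means hundreds of leapfrog steps. After preconditioning, step sizes up to 2 are stable for every coordinate.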
Examples:
Simple chain with warm-up.
In this example we can use an estimate of the target covariance to sample efficiently with HMC.
```python
import tensorflow as tf
import tensorflow_probability as tfp
tfed = tfp.experimental.distributions

# Suppose we have a target log prob fn, as well as an estimate of its
# covariance.
log_prob_fn = ...
cov_estimate = ...

# We want the mass matrix to be the *inverse* of the covariance estimate,
# so we use the covariance estimate as the momentum distribution's
# precision, with its Cholesky factor as the precision factor:
momentum_distribution = (
    tfed.MultivariateNormalPrecisionFactorLinearOperator(
        precision_factor=tf.linalg.LinearOperatorLowerTriangular(
            tf.linalg.cholesky(cov_estimate),
        ),
        precision=tf.linalg.LinearOperatorFullMatrix(cov_estimate),
    ))

# Run preconditioned HMC, adapting the step size during warm-up.
num_burnin_steps = 100
num_results = 1000
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
        target_log_prob_fn=log_prob_fn,
        momentum_distribution=momentum_distribution,
        step_size=0.3,
        num_leapfrog_steps=10),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=tf.zeros(3),  # 3 chains.
      kernel=adaptive_hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws, cross_chain_dims=1)

run_chain_and_compute_ess()  # Something close to 3 x 1000.
```
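This construction assumes that MultivariateNormalPrecisionFactorLinearOperator treats precision_factor as a factor F with precision = F F^T. The stdlib-only sketch below checks the arithmetic on a hypothetical 2x2 cov_estimate, using hypothetical helper functions rather than TFP. The Cholesky factor of cov_estimate reproduces cov_estimate as the precision, so the momentum covariance (the mass matrix) is the inverse of cov_estimate, as intended.

```python
import math


def cholesky_2x2(a):
    """Lower-triangular L with L @ L^T == a, for a 2x2 SPD matrix `a`."""
    l00 = math.sqrt(a[0][0])
    l10 = a[1][0] / l00
    l11 = math.sqrt(a[1][1] - l10 ** 2)
    return [[l00, 0.0], [l10, l11]]


def times_transpose(f):
    """Return f @ f^T for a 2x2 matrix f."""
    return [[sum(f[i][k] * f[j][k] for k in range(2)) for j in range(2)]
            for i in range(2)]


def inverse_2x2(a):
    """Inverse of a nonsingular 2x2 matrix."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [[a[1][1] / det, -a[0][1] / det],
            [-a[1][0] / det, a[0][0] / det]]


def matmul_2x2(a, b):
    """Return a @ b for 2x2 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


cov_estimate = [[4.0, 1.2], [1.2, 1.0]]        # hypothetical estimate
precision_factor = cholesky_2x2(cov_estimate)   # lower-triangular factor
precision = times_transpose(precision_factor)   # equals cov_estimate
momentum_cov = inverse_2x2(precision)           # the mass matrix

# Mass matrix times state covariance should be (numerically) the identity.
print(matmul_2x2(momentum_cov, cov_estimate))
```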
Estimate parameters of a more complicated distribution.
This demonstrates using multiple state parts, and reshaping a tfde.MultivariateNormalPrecisionFactorLinearOperator to match a state part with a non-vector event shape (in this case, [2, 3, 4]). The scalar state parts use tfd.Normal momentum distributions directly.
```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions

mvn = tfd.JointDistributionSequential([
    tfd.Normal(0., 0.1),
    tfd.Normal(0., 10.),
    tfd.Independent(tfd.Normal(tf.fill([2, 3, 4], 3.), 10.),
                    reinterpreted_batch_ndims=3)])

# `reshape_to_234` maps a length-24 vector event to shape [2, 3, 4].
# (`reshape_to_scalar` is not used below: the scalar parts use tfd.Normal.)
reshape_to_scalar = tfp.bijectors.Reshape(event_shape_out=[])
reshape_to_234 = tfp.bijectors.Reshape(event_shape_out=[2, 3, 4])

# Each momentum scale is the inverse of the matching state scale:
# stddev 0.1 -> 10, stddev 10 -> 0.1 (precision factor 10, i.e. stddev 0.1).
momentum_distribution = tfd.JointDistributionSequential([
    tfd.Normal(0., 10.),
    tfd.Normal(0., 0.1),
    reshape_to_234(
        tfde.MultivariateNormalPrecisionFactorLinearOperator(
            0., tf.linalg.LinearOperatorDiag(tf.fill([24], 10.))))
])

num_burnin_steps = 100
num_results = 1000

hmc = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=mvn.log_prob,
    momentum_distribution=momentum_distribution,
    step_size=0.3,
    num_leapfrog_steps=10)

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=mvn.sample(),
      kernel=hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws)

run_chain_and_compute_ess()  # Roughly [1000, 1000, 1000 * tf.ones([2, 3, 4])]
```
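The momentum scales in this example follow the preconditioning rule: each momentum standard deviation is the reciprocal of the corresponding target standard deviation. A small stdlib check of the three parts, using a hypothetical helper name, including the 24-element block that the Reshape bijector maps to [2, 3, 4]:

```python
import math


def matched_momentum_stddev(target_stddev):
    """Momentum stddev whose variance is the inverse of the target variance."""
    return 1.0 / target_stddev


# Target standard deviations of the three parts of `mvn` above.
target_stddevs = [0.1, 10.0, 10.0]
print([matched_momentum_stddev(s) for s in target_stddevs])  # [10.0, 0.1, 0.1]

# Third part: a diagonal precision factor d gives precision d**2, hence a
# momentum stddev of 1 / d; with d = 10 that is 0.1.
precision_factor_diag = 10.0
momentum_stddev_third = 1.0 / math.sqrt(precision_factor_diag ** 2)

# The reshaped block has 2 * 3 * 4 = 24 elements, matching tf.fill([24], 10.).
num_elements = math.prod([2, 3, 4])
```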
Args | |
---|---|
target_log_prob_fn | Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. |
step_size | Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable. |
num_leapfrog_steps | Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps. |
momentum_distribution | A tfp.distributions.Distribution instance to draw momentum from. Defaults to normal distributions with identity covariance. |
state_gradients_are_stopped | Python bool indicating that the proposed new state be run through tf.stop_gradient. This is particularly useful when combining optimization with samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient). |
step_size_update_fn | Python callable that takes the current step_size (typically a tf.Variable) and kernel_results (typically a collections.namedtuple) and returns the updated step_size (Tensors). Default value: None (i.e., do not update step_size automatically). |
store_parameters_in_results | If True, then step_size, momentum_distribution, and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. In this case, momentum_distribution must be a CompositeTensor; see tfp.experimental.auto_composite. This is incompatible with step_size_update_fn, which must be set to None. |
experimental_shard_axis_names | A structure of string names indicating how members of the state are sharded. |
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'phmc_kernel'). |
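The step_size and num_leapfrog_steps trade-off can be illustrated with a plain leapfrog integrator on a one-dimensional Gaussian target. The sketch below is stdlib-only, not TFP's integrator; its helper name is hypothetical and it assumes unit mass. The energy error, which sets the Metropolis rejection probability, stays small for a modest step size and explodes once step_size exceeds the stability limit of 2 for this target.

```python
def leapfrog_energy_error(step_size, num_leapfrog_steps, q=1.0, p=0.0,
                          sigma=1.0):
    """Run leapfrog on H(q, p) = q**2 / (2 sigma**2) + p**2 / 2.

    Returns |H(end) - H(start)|. The Metropolis acceptance probability is
    min(1, exp(-(H(end) - H(start)))), so large errors mean rejections.
    """
    def grad_u(x):
        return x / sigma**2

    def hamiltonian(x, v):
        return x**2 / (2 * sigma**2) + v**2 / 2

    h0 = hamiltonian(q, p)
    p -= 0.5 * step_size * grad_u(q)        # initial half step for momentum
    for i in range(num_leapfrog_steps):
        q += step_size * p                  # full step for position
        if i < num_leapfrog_steps - 1:
            p -= step_size * grad_u(q)      # full step for momentum
    p -= 0.5 * step_size * grad_u(q)        # final half step for momentum
    return abs(hamiltonian(q, p) - h0)


small = leapfrog_energy_error(step_size=0.1, num_leapfrog_steps=10)
large = leapfrog_energy_error(step_size=2.5, num_leapfrog_steps=10)
print(small, large)
```

With step_size=0.1 the error is a small fraction of the initial energy of 0.5. With step_size=2.5 each leapfrog step amplifies the unstable mode by a factor of 4, so after 10 steps the proposal is essentially always rejected.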
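Attributes | |
---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | |
num_leapfrog_steps | Returns the num_leapfrog_steps parameter. If store_parameters_in_results was set to True in the constructor, this only returns the value placed in the kernel results by bootstrap_results; the value actually used is governed by the previous_kernel_results argument to one_step. |
parameters | Returns a dict of __init__ arguments and their values. |
state_gradients_are_stopped | |
step_size | Returns the step_size parameter. If store_parameters_in_results was set to True in the constructor, this only returns the value placed in the kernel results by bootstrap_results; the value actually used is governed by the previous_kernel_results argument to one_step. |
target_log_prob_fn | |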
Methods
bootstrap_results

```python
bootstrap_results(
    init_state
)
```

Creates initial previous_kernel_results using a supplied state.
copy

```python
copy(
    **override_parameter_kwargs
)
```

Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python dictionary mapping initialization argument names (str) to new values that override the current ones. |
Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
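The override rule used by copy is ordinary dict merging. The stdlib illustration below uses hypothetical parameter values in a plain dict that stands in for self.parameters; no kernel is constructed.

```python
# Hypothetical stand-in for `kernel.parameters`.
parameters = {
    'step_size': 0.3,
    'num_leapfrog_steps': 10,
    'state_gradients_are_stopped': False,
}
override_parameter_kwargs = {'step_size': 0.1}

# Same merge rule as copy(): shared keys take the override's value, and the
# original dict is left unchanged.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters)
# {'step_size': 0.1, 'num_leapfrog_steps': 10, 'state_gradients_are_stopped': False}
```

With a real kernel the corresponding call would be kernel.copy(step_size=0.1).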
experimental_with_shard_axes

```python
experimental_with_shard_axes(
    shard_axis_names
)
```

Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step

```python
one_step(
    current_state, previous_kernel_results, seed=None
)
```

Runs one iteration of Hamiltonian Monte Carlo.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
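The rule that the first r = tf.rank(target_log_prob_fn(*current_state)) dimensions of current_state index independent chains can be illustrated with shapes alone. The helper below is hypothetical and works on plain shape tuples rather than Tensors.

```python
def split_chain_and_event_shape(state_shape, log_prob_shape):
    """Split a state shape into (chain shape, per-chain event shape).

    `log_prob_shape` is the shape of target_log_prob_fn(*current_state); its
    rank r is the number of leading state dimensions that index chains.
    """
    r = len(log_prob_shape)
    return tuple(state_shape[:r]), tuple(state_shape[r:])


# 4 independent chains, each with a 3-dimensional state: log-prob shape [4].
print(split_chain_and_event_shape((4, 3), (4,)))  # ((4,), (3,))

# A single chain with a 3-dimensional state: log-prob is a scalar (rank 0).
print(split_chain_and_event_shape((3,), ()))      # ((), (3,))
```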
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
Raises | |
---|---|
ValueError | If step_size is neither a single value nor a list with the same length as current_state. |