Runs one step of the slice sampler using a hit-and-run approach.
Inherits From: TransitionKernel
tfp.mcmc.SliceSampler(
    target_log_prob_fn,
    step_size,
    max_doublings,
    experimental_shard_axis_names=None,
    name=None
)
Slice Sampling is a Markov Chain Monte Carlo (MCMC) algorithm based, as stated
by [Neal (2003)][1], on the observation that "...one can sample from a
distribution by sampling uniformly from the region under the plot of its
density function. A Markov chain that converges to this uniform distribution
can be constructed by alternately uniform sampling in the vertical direction
with uniform sampling from the horizontal slice
defined by the current
vertical position, or more generally, with some update that leaves the uniform
distribution over this slice invariant". Mathematical details and derivations
can be found in [Neal (2003)][1]. The one dimensional slice sampler is
extended to n-dimensions through use of a hit-and-run approach: choose a
random direction in n-dimensional space and take a step, as determined by the
one-dimensional slice sampling algorithm, along that direction
[Belisle et al. (1993)][2].
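To make the procedure concrete, the following is a minimal NumPy sketch of a single hit-and-run slice-sampling update. It is an illustrative simplification rather than TFP's implementation: it omits the acceptance test that [Neal (2003)][1] pairs with the doubling procedure, and the function name and arguments are hypothetical.

import numpy as np

def hit_and_run_slice_step(x, log_prob, step_size=1.0, max_doublings=5, rng=None):
  # One simplified hit-and-run slice-sampling update of state `x`.
  rng = np.random.default_rng() if rng is None else rng
  # Auxiliary variable: a log-height drawn uniformly under the density at x.
  log_y = log_prob(x) + np.log(rng.uniform())
  # Hit-and-run: pick a uniformly random direction in n-dimensional space.
  direction = rng.normal(size=np.shape(x))
  direction /= np.linalg.norm(direction)
  # Randomly position an interval of width `step_size` around x along the
  # direction, then double its width up to `max_doublings` times until both
  # endpoints fall outside the slice {t : log_prob(x + t * direction) >= log_y}.
  lo = -step_size * rng.uniform()
  hi = lo + step_size
  for _ in range(max_doublings):
    if log_prob(x + lo * direction) < log_y and log_prob(x + hi * direction) < log_y:
      break
    if rng.uniform() < 0.5:
      lo -= (hi - lo)
    else:
      hi += (hi - lo)
  # Sample uniformly from the interval, shrinking it toward 0 (the current
  # state) whenever the draw lands outside the slice.
  while True:
    t = rng.uniform(lo, hi)
    proposal = x + t * direction
    if log_prob(proposal) >= log_y:
      return proposal
    if t < 0:
      lo = t
    else:
      hi = t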
The `one_step` function can update multiple chains in parallel. It assumes that all leftmost dimensions of `current_state` index independent chain states (and are therefore updated independently). The output of `target_log_prob_fn(*current_state)` should sum log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, `current_state[0, :]` could have a different target distribution from `current_state[1, :]`. These semantics are governed by `target_log_prob_fn(*current_state)`. (The number of independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

Note that the sampler only supports states where all components have a common dtype.
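As an illustration of these batch semantics (the shapes and the target below are assumptions chosen for the sketch, not part of the API), a state of shape [4, 3] whose log-prob function reduces over the last (event) dimension yields four independent chains:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_chains, event_size = 4, 3
# Each row of `current_state` is the state of one independent chain.
current_state = tf.zeros([num_chains, event_size])

# `log_prob` sums over the event dimension, so its output has shape [num_chains].
target = tfd.MultivariateNormalDiag(loc=tf.zeros(event_size))
target_log_prob_fn = target.log_prob

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn, step_size=1.0, max_doublings=5)

# Number of independent chains updated in parallel by `one_step`:
print(tf.size(target_log_prob_fn(current_state)).numpy())  # ==> 4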
Examples:
Simple chain with warm-up.
In this example we sample from a standard univariate normal distribution using slice sampling.
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions

dtype = np.float32
target = tfd.Normal(loc=dtype(0), scale=dtype(1))

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.SliceSampler(
        target.log_prob,
        step_size=1.0,
        max_doublings=5),
    num_burnin_steps=500,
    trace_fn=None,
    seed=1234)

sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))
print('Sample mean: ', sample_mean.numpy())
print('Sample Std: ', sample_std.numpy())
Sample from a Two Dimensional Normal.
In the following example we sample from a two dimensional Normal distribution using slice sampling.
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions

dtype = np.float32
true_mean = dtype([0, 0])
true_cov = dtype([[1, 0.5], [0.5, 1]])
num_results = 500
num_chains = 50

# Target distribution is defined through the Cholesky decomposition.
chol = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

# Initial state of the chain.
init_state = np.ones([num_chains, 2], dtype=dtype)

# Run the Slice Sampler for `num_results` iterations for `num_chains`
# independent chains:
@tf.function
def run_mcmc():
  states = tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=init_state,
      kernel=tfp.mcmc.SliceSampler(
          target_log_prob_fn=target.log_prob,
          step_size=1.0,
          max_doublings=5),
      num_burnin_steps=200,
      num_steps_between_results=1,
      trace_fn=None,
      seed=47)
  return states

states = run_mcmc()

sample_mean = tf.reduce_mean(states, axis=[0, 1])
z = (states - sample_mean)[..., tf.newaxis]
sample_cov = tf.reduce_mean(
    tf.matmul(z, tf.transpose(z, [0, 1, 3, 2])), [0, 1])
print('sample mean', sample_mean.numpy())
print('sample covariance matrix', sample_cov.numpy())
References

[1]: Radford M. Neal. Slice Sampling. The Annals of Statistics, 2003, Vol. 31, No. 3, 705-767. https://projecteuclid.org/download/pdf_1/euclid.aos/1056562461

[2]: C. J. P. Belisle, H. E. Romeijn, R. L. Smith. Hit-and-run algorithms for generating multivariate distributions. Mathematics of Operations Research, 18 (1993), 225-266. https://www.jstor.org/stable/3690278?seq=1#page_scan_tab_contents
| Args | |
|---|---|
| `target_log_prob_fn` | Python callable which takes an argument like `current_state` (or `*current_state` if it is a list) and returns its (possibly unnormalized) log-density under the target distribution. |
| `step_size` | Scalar or `tf.Tensor` with same dtype as and shape compatible with `x_initial`. The size of the initial interval. |
| `max_doublings` | Scalar positive int32 `tf.Tensor`. The maximum number of doublings to consider. |
| `experimental_shard_axis_names` | A structure of string names indicating how members of the state are sharded. |
| `name` | Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'slice_sampler_kernel'). |
| Attributes | |
|---|---|
| `experimental_shard_axis_names` | The shard axis names for members of the state. |
| `is_calibrated` | Returns `True` if Markov chain converges to specified distribution. |
| `max_doublings` | |
| `name` | |
| `parameters` | Returns dict of `__init__` arguments and their values. |
| `step_size` | |
| `target_log_prob_fn` | |
Methods
bootstrap_results

bootstrap_results(
    init_state
)

Returns an object with the same type as returned by `one_step(...)[1]`.

| Args | |
|---|---|
| `init_state` | `Tensor` or Python `list` of `Tensor`s representing the initial state(s) of the Markov chain(s). |

| Returns | |
|---|---|
| `kernel_results` | A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within this function. |
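For instance, `bootstrap_results` pairs with `one_step` when driving the kernel by hand rather than through `tfp.mcmc.sample_chain`; the target, number of steps, and seed below are illustrative assumptions:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=tfd.Normal(loc=0., scale=1.).log_prob,
    step_size=1.0,
    max_doublings=5)

state = tf.constant(1.0)
# Initialize the kernel results from the initial state, then advance manually.
kernel_results = kernel.bootstrap_results(state)
for seed in tfp.random.split_seed([42, 0], n=10):
  state, kernel_results = kernel.one_step(state, kernel_results, seed=seed)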
copy

copy(
    **override_parameter_kwargs
)

Non-destructively creates a deep copy of the kernel.

| Args | |
|---|---|
| `**override_parameter_kwargs` | Python String/value dictionary of initialization arguments to override with new values. |

| Returns | |
|---|---|
| `new_kernel` | `TransitionKernel` object of same type as `self`, initialized with the union of `self.parameters` and `override_parameter_kwargs`, with any shared keys overridden by the value of `override_parameter_kwargs`, i.e., `dict(self.parameters, **override_parameters_kwargs)`. |
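As a small illustrative use (argument values assumed), `copy` can produce a variant of a configured kernel that overrides only the named argument and reuses the rest of `self.parameters`:

import tensorflow_probability as tfp

tfd = tfp.distributions

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=tfd.Normal(loc=0., scale=1.).log_prob,
    step_size=1.0,
    max_doublings=5)

# Same target and max_doublings, but a wider initial interval.
wide_kernel = kernel.copy(step_size=2.0)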
experimental_with_shard_axes

experimental_with_shard_axes(
    shard_axis_names
)

Returns a copy of the kernel with the provided shard axis names.

| Args | |
|---|---|
| `shard_axis_names` | A structure of strings indicating the shard axis names for each component of this kernel's state. |

| Returns | |
|---|---|
| A copy of the current kernel with the shard axis information. | |
one_step

one_step(
    current_state, previous_kernel_results, seed=None
)

Runs one iteration of Slice Sampler.

| Args | |
|---|---|
| `current_state` | `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). The first `r` dimensions index independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. |
| `previous_kernel_results` | `collections.namedtuple` containing `Tensor`s representing values from previous calls to this function (or from the `bootstrap_results` function). |
| `seed` | PRNG seed; see `tfp.random.sanitize_seed` for details. |

| Returns | |
|---|---|
| `next_state` | `Tensor` or Python `list` of `Tensor`s representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as `current_state`. |
| `kernel_results` | `collections.namedtuple` of internal calculations used to advance the chain. |

| Raises | |
|---|---|
| `ValueError` | If there isn't one `step_size` or a list with same length as `current_state`. |
| `TypeError` | If `not target_log_prob.dtype.is_floating`. |