tfp.experimental.mcmc.sample_chain

Runs a Markov chain defined by the given TransitionKernel.

This function is intended as a more convenient frontend to the low-level TransitionKernel-based MCMC API. It supports several main features:

  • Running a batch of multiple independent chains using SIMD parallelism
  • Tracing the history of the chains, or not tracing it to save memory
  • Computing reductions over chain history, whether it is also traced or not
  • Warm (re-)start, including auxiliary state

This function samples from a Markov chain, starting at current_state, whose stationary distribution is governed by the supplied TransitionKernel instance (kernel).

The current_state can be a single Tensor or a list of Tensors that collectively represent the current state.
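To make the single-Tensor versus list-of-Tensors distinction concrete, here is a minimal pure-Python sketch (not TFP code). It assumes the convention that a list-valued state is passed to the target log-density as one positional argument per state part; the helpers `as_state_parts` and `call_target_log_prob`, the model, and the data are all hypothetical and chosen only for illustration.

```python
import math

def as_state_parts(state):
    """Hypothetical helper: normalize a chain state to a list of parts.

    A bare value is a one-part state; a list or tuple holds several parts
    that together make up a single chain state.
    """
    if isinstance(state, (list, tuple)):
        return list(state)
    return [state]

def call_target_log_prob(target_log_prob_fn, state):
    # Assumed convention: pass each state part as its own positional argument.
    return target_log_prob_fn(*as_state_parts(state))

def standard_normal_log_prob(x):
    # One-part state: a single scalar.
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)

DATA = [1.2, 0.7, 1.9, 1.1]  # made-up observations, for illustration only

def normal_model_log_prob(loc, log_scale):
    # Two-part state: (loc, log_scale) jointly parameterize
    # Normal(loc, exp(log_scale)). With a flat prior this is just the
    # log-likelihood of DATA.
    scale = math.exp(log_scale)
    return sum(
        -0.5 * ((x - loc) / scale) ** 2 - log_scale - 0.5 * math.log(2 * math.pi)
        for x in DATA
    )

print(call_target_log_prob(standard_normal_log_prob, 0.0))
print(call_target_log_prob(normal_model_log_prob, [1.0, 0.0]))
```

The two calls differ only in the shape of the state: the second state is a list of two parts that are unpacked into `loc` and `log_scale`.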

This function can sample from multiple chains in parallel. Whether or not there are multiple chains is dictated by how the kernel treats its inputs. Typically, the shape of the batch of independent chains is the shape of the result of the target_log_prob_fn used by the kernel when applied to the given current_state.
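"SIMD parallelism" here means vectorization: one kernel step updates a whole batch of chains at once. The following NumPy sketch (not the TFP implementation) uses a hand-written random-walk Metropolis step whose `target_log_prob_fn` returns one log-density per chain, so the result shape `[num_chains]` is what makes the chains independent. The names `rwm_step` and `num_chains` are hypothetical.

```python
import numpy as np

def target_log_prob_fn(x):
    # Independent standard normal per chain (up to a constant). x has shape
    # [num_chains], so the result also has shape [num_chains].
    return -0.5 * x ** 2

def rwm_step(x, rng, scale=1.0):
    """One random-walk Metropolis step applied to every chain at once."""
    proposal = x + scale * rng.standard_normal(x.shape)
    log_accept_ratio = target_log_prob_fn(proposal) - target_log_prob_fn(x)
    # Each chain draws its own uniform and accepts or rejects independently.
    accept = np.log(rng.uniform(size=x.shape)) < log_accept_ratio
    return np.where(accept, proposal, x)

rng = np.random.default_rng(0)
num_chains = 8
state = np.zeros(num_chains)  # a batch of 8 independent chains
samples = []
for _ in range(2000):
    state = rwm_step(state, rng)
    samples.append(state)
samples = np.stack(samples)   # shape [num_results, num_chains]
print(samples.shape)          # (2000, 8)
```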

This function can compute reductions over the samples in tandem with sampling, for example to return summary statistics without materializing all the samples. To request reductions, pass a Reducer object, or a nested structure of Reducer objects, as the reducer= argument.
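A streaming reduction keeps only a small running state instead of all samples. The sketch below is a hypothetical reducer in plain Python that computes a mean and variance with Welford's algorithm. Its `initialize` / `one_step` / `finalize` method names loosely echo the Reducer pattern but are assumptions here; consult the Reducer documentation for the real signatures. In real use the samples would arrive from the chain one at a time; the list `draws` is only a stand-in so the result can be checked.

```python
import random
import statistics
from typing import NamedTuple

class MeanVarState(NamedTuple):
    count: int
    mean: float
    m2: float  # running sum of squared deviations from the mean

class StreamingMeanVariance:
    """Hypothetical reducer computing mean and variance with Welford's algorithm."""

    def initialize(self):
        return MeanVarState(count=0, mean=0.0, m2=0.0)

    def one_step(self, sample, state):
        count = state.count + 1
        delta = sample - state.mean
        mean = state.mean + delta / count
        m2 = state.m2 + delta * (sample - mean)
        return MeanVarState(count, mean, m2)

    def finalize(self, state):
        # Population variance (divide by count), matching statistics.pvariance.
        return state.mean, state.m2 / state.count

random.seed(1)
draws = [random.gauss(0.0, 2.0) for _ in range(1000)]  # stand-in for chain samples

reducer = StreamingMeanVariance()
reducer_state = reducer.initialize()
for x in draws:
    # Only the three-field reducer state is carried between steps.
    reducer_state = reducer.one_step(x, reducer_state)
mean, var = reducer.finalize(reducer_state)

print(mean, statistics.mean(draws))
print(var, statistics.pvariance(draws))
```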

In addition to the chain state, this function supports tracing auxiliary variables used by the kernel, as well as intermediate values of any supplied reductions. The traced values are selected by specifying trace_fn, a callable accepting the chain state, the kernel_results of the kernel, and, when reducers are supplied, the current results of the reductions. The return value of trace_fn (a Tensor or a nested structure of Tensors) is accumulated so that each Tensor gains a new outermost dimension representing time in the chain history.
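The accumulation rule can be pictured with a small NumPy driver. This is a sketch under stated assumptions, not TFP code: `run_and_trace` and `rwm_one_step` are hypothetical, `kernel_results` is a plain dict with a made-up `is_accepted` field, and no reducers are used, so `trace_fn` takes only two arguments. Each traced field is stacked along a new leading time axis.

```python
import numpy as np

def run_and_trace(one_step, state, num_results, trace_fn):
    """Hypothetical driver: run num_results steps and stack trace_fn outputs.

    one_step(state) -> (new_state, kernel_results); trace_fn receives the
    state and kernel_results (no reducers in this sketch) and returns a dict.
    """
    traced = []
    for _ in range(num_results):
        state, kernel_results = one_step(state)
        traced.append(trace_fn(state, kernel_results))
    # Stack each traced field so it gains a new outermost (time) dimension.
    return {k: np.stack([t[k] for t in traced]) for k in traced[0]}, state

rng = np.random.default_rng(0)

def rwm_one_step(x):
    # Random-walk Metropolis targeting a standard normal, one chain per element.
    proposal = x + rng.standard_normal(x.shape)
    log_ratio = -0.5 * proposal ** 2 + 0.5 * x ** 2
    accepted = np.log(rng.uniform(size=x.shape)) < log_ratio
    return np.where(accepted, proposal, x), {"is_accepted": accepted}

def trace_fn(state, kernel_results):
    # Trace the state and one auxiliary field of the (hypothetical) kernel results.
    return {"state": state, "is_accepted": kernel_results["is_accepted"]}

trace, final_state = run_and_trace(rwm_one_step, np.zeros([4, 3]), 100, trace_fn)
print(trace["state"].shape)         # (100, 4, 3)
print(trace["is_accepted"].mean())  # empirical acceptance rate
```

A chain state of shape `[4, 3]` traced for 100 steps becomes a Tensor of shape `[100, 4, 3]`; returning fewer fields from `trace_fn` is how tracing is trimmed to save memory.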

Because successive MCMC states are correlated, it is sometimes desirable to produce additional intermediate states and discard them, ending up with a set of states with decreased autocorrelation. See [Owen (2017)][1]. Such 'thinning' is made possible by setting num_steps_between_results > 0. The chain then takes num_steps_between_results extra steps between the steps that make it into the results or are passed to any supplied reductions. The extra steps are never materialized, and thus do not increase memory requirements.
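The step accounting for thinning can be checked with a toy "kernel" that simply counts. This hypothetical `run_thinned` driver follows the wording above literally (extra steps only between kept results); TFP's own accounting, for example around burn-in, may differ. Only the kept states are stored, so memory grows with `num_results`, not with the total number of steps.

```python
def run_thinned(one_step, state, num_results, num_steps_between_results):
    """Hypothetical driver: return num_results kept states and the step count."""
    kept = []
    total_steps = 0
    for i in range(num_results):
        if i > 0:
            # Extra steps between kept results: computed, then dropped.
            for _ in range(num_steps_between_results):
                state = one_step(state)
                total_steps += 1
        state = one_step(state)  # this step's state is kept
        total_steps += 1
        kept.append(state)
    return kept, total_steps

# A counting "kernel" makes it easy to see which steps are kept.
kept, total_steps = run_thinned(lambda s: s + 1, 0, num_results=5, num_steps_between_results=3)
print(kept)         # [1, 5, 9, 13, 17]
print(total_steps)  # 17
```

With 5 results and 3 extra steps between each pair, the chain takes 5 + 4 × 3 = 17 steps but stores only 5 states.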

Args
  kernel: An instance of tfp.mcmc.TransitionKernel which implements one step of the Markov chain.
  num_results: Integer number of (non-discarded) Markov chain draws to compute.
  current_state: Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).
  previous_kernel_results: A Tensor or a nested collection of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
  reducer: A (possibly nested) structure of Reducers to be evaluated on the kernel's samples. If no reducers are given (reducer=None), their states will not be passed to any supplied trace_fn.
  previous_reducer_state: A (possibly nested) structure of running states corresponding to the structure in reducer. For resuming streaming reduction computations begun in a previous run.
  trace_fn: A callable that takes in the current chain state, the current auxiliary kernel state, and the current result of any reducers, and returns a Tensor or a nested collection of Tensors that is then traced. If None, nothing is traced.
  parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See tf.while_loop for more details.
  seed: Optional, a seed for reproducible sampling.
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mcmc_sample_chain').

Returns
  result: A SampleChainResults instance containing information about the sampling run. Main fields are trace, the history of outputs of trace_fn, and reduction_results, the final outputs of all supplied Reducers. See SampleChainResults for contents of other fields.
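Warm restart means a second run picks up exactly where the first stopped: the final chain state becomes the new current_state, and the final reducer state becomes previous_reducer_state. The plain-Python sketch below illustrates this with a toy AR(1) Markov chain and a running-mean reduction. Everything here is hypothetical: `sample_chain_sketch` is not the TFP function, and `SketchResults` only loosely mirrors the trace and reduction_results fields described above (the other two field names are made up). The toy kernel has no auxiliary state, so there is nothing to pass as previous_kernel_results; with a real kernel that would be carried over as well.

```python
import random
from typing import List, NamedTuple

class RunningMean(NamedTuple):
    count: int
    total: float

class SketchResults(NamedTuple):
    # Hypothetical container, loosely modeled on the fields described above.
    trace: List[float]
    reduction_results: float
    final_state: float
    final_reducer_state: RunningMean

def sample_chain_sketch(rng, num_results, current_state, previous_reducer_state=None):
    """Toy driver: AR(1) chain plus a running-mean reduction, resumable."""
    if previous_reducer_state is None:
        reducer_state = RunningMean(count=0, total=0.0)
    else:
        reducer_state = previous_reducer_state
    state = current_state
    trace = []
    for _ in range(num_results):
        state = 0.9 * state + rng.gauss(0.0, 1.0)  # toy AR(1) transition
        trace.append(state)
        reducer_state = RunningMean(reducer_state.count + 1, reducer_state.total + state)
    return SketchResults(
        trace=trace,
        reduction_results=reducer_state.total / reducer_state.count,
        final_state=state,
        final_reducer_state=reducer_state,
    )

rng = random.Random(42)
first = sample_chain_sketch(rng, 100, current_state=0.0)
# Warm restart: continue from the previous final state and reducer state.
second = sample_chain_sketch(
    rng,
    100,
    current_state=first.final_state,
    previous_reducer_state=first.final_reducer_state,
)
all_draws = first.trace + second.trace
print(second.final_reducer_state.count)  # 200
print(second.reduction_results == sum(all_draws) / len(all_draws))  # True
```

Because the random stream and both pieces of state are carried over, the two 100-step runs here reproduce a single 200-step run from the same seed, and the resumed reduction covers all 200 draws.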