tfp.experimental.mcmc.windowed_adaptive_hmc

<!-- Stable -->
<table class="tfo-notebook-buttons tfo-api nocontent" align="left">
<td>
<a target="_blank" href="https://github.com/tensorflow/probability/blob/v0.14.1/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py#L728-L834">
<img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
View source on GitHub
</a>
</td>
</table>

Adapt and sample from a joint distribution, conditioned on pins.

<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link">
<code>tfp.experimental.mcmc.windowed_adaptive_hmc(
    n_draws,
    joint_dist,
    *,
    num_leapfrog_steps,
    n_chains=64,
    num_adaptation_steps=500,
    current_state=None,
    init_step_size=None,
    dual_averaging_kwargs=None,
    trace_fn=_default_hmc_trace_fn,
    return_final_kernel_results=False,
    discard_tuning=True,
    chain_axis_names=None,
    seed=None,
    **pins
)
</code></pre>

<!-- Placeholder for "Used in" -->

This uses Hamiltonian Monte Carlo (HMC) for sampling. The step size is tuned with dual-averaging adaptation, and the kernel is preconditioned with a diagonal mass matrix that is estimated over expanding windows.

<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr>
<tr> <td> `n_draws` </td> <td> `int`. Number of draws to return after adaptation. </td> </tr><tr>
<td> `joint_dist` </td> <td> `tfd.JointDistribution`. A joint distribution to sample from. </td> </tr><tr>
<td> `num_leapfrog_steps` </td> <td> `int`. Number of leapfrog steps to use for each Hamiltonian Monte Carlo step. </td> </tr><tr>
<td> `n_chains` </td> <td> `int` or list of `int`s. Number of independent chains to run MCMC with. </td> </tr><tr>
<td> `num_adaptation_steps` </td> <td> `int`. Number of draws used to adapt the step size and mass matrix. </td> </tr><tr>
<td> `current_state` </td> <td> Optional. Structure of tensors at which to initialize sampling.
Should have the same shape and structure as `joint_dist.experimental_pin(**pins).sample(n_chains)`. </td> </tr><tr>
<td> `init_step_size` </td> <td> Optional. Where to initialize the step size for the leapfrog integrator. The structure should broadcast with `current_state`. For example, if the initial state is `{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`, then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work. Defaults to a heuristic proportional to the dimension of the log density to the -0.25 power. </td> </tr><tr>
<td> `dual_averaging_kwargs` </td> <td> Optional `dict`. Keyword arguments to pass to <a href="../../../tfp/mcmc/DualAveragingStepSizeAdaptation"><code>tfp.mcmc.DualAveragingStepSizeAdaptation</code></a>. By default, `target_accept_prob` is set to 0.75, acceptance probabilities across chains are reduced with a harmonic mean, and the class defaults are used for everything else. </td> </tr><tr>
<td> `trace_fn` </td> <td> Optional callable. The trace function should accept the arguments `(state, bijector, is_adapting, phmc_kernel_results)`, where `state` is an unconstrained, flattened float tensor, `bijector` is the `tfb.Bijector` used for unconstraining and flattening, `is_adapting` is a boolean marking whether the draw is from an adaptation step, and `phmc_kernel_results` is the `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the `PreconditionedHamiltonianMonteCarlo` kernel. Note that `bijector.inverse(state)` provides access to the current draw in the untransformed space, using the structure of the provided `joint_dist`. </td> </tr><tr>
<td> `return_final_kernel_results` </td> <td> If `True`, the final kernel results are returned alongside the chain state and the trace specified by `trace_fn`. </td> </tr><tr>
<td> `discard_tuning` </td> <td> `bool`. Whether to discard the draws and traces produced during the adaptation (tuning) steps. </td> </tr><tr>
<td> `chain_axis_names` </td> <td> A `str` or list of `str`s indicating the named axes by which multiple chains are sharded. See <a href="../../../tfp/experimental/mcmc/Sharded"><code>tfp.experimental.mcmc.Sharded</code></a> for more context. </td> </tr><tr>
<td> `seed` </td> <td> PRNG seed; see <a href="../../../tfp/random/sanitize_seed"><code>tfp.random.sanitize_seed</code></a> for details. </td> </tr><tr>
<td> `pins` </td> <td> Values used to condition the provided joint distribution. They are passed directly to `joint_dist.experimental_pin(**pins)`. </td> </tr>
</table>

<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Returns</h2></th></tr>
<tr> <td colspan="2"> A single structure of draws is returned if `trace_fn` is `None` and `return_final_kernel_results` is `False`. If there is a trace function, the return value is a tuple with the trace second. If `return_final_kernel_results` is `True`, the return value is a tuple of length 3 with the final kernel results last. If `discard_tuning` is `True`, the tensors in the draws and trace have a leading dimension of size `n_draws`; otherwise, their leading dimension has size `n_draws + num_adaptation_steps`. </td> </tr>
</table>
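The dual-averaging step-size tuning can be illustrated in plain Python. This is a sketch of the Nesterov dual-averaging scheme from Hoffman and Gelman's NUTS paper, not TFP's implementation. The constants `gamma=0.05`, `t0=10`, and `kappa=0.75` are that paper's suggested values and are assumptions here. `toy_accept_prob` and `CHAIN_SCALES` are hypothetical stand-ins for real HMC transitions. Following the `dual_averaging_kwargs` description, per-chain acceptance probabilities are combined with a harmonic mean before each update, and the target acceptance is 0.75.

```python
import math


def harmonic_mean(values):
    """Harmonic mean of positive numbers (guarded against exact zeros)."""
    return len(values) / sum(1.0 / max(v, 1e-300) for v in values)


def dual_averaging_step_size(accept_prob_fn, init_step_size, n_steps,
                             target_accept_prob=0.75, gamma=0.05, t0=10.0,
                             kappa=0.75):
    """Tune a step size so accept_prob_fn(step_size) approaches the target.

    Sketch of dual averaging (Hoffman & Gelman, 2014). Returns the averaged
    step size, which is the value kept once tuning ends.
    """
    mu = math.log(10.0 * init_step_size)  # shrinkage target for log step size
    log_step = math.log(init_step_size)
    log_step_avg = 0.0
    h_bar = 0.0  # running average of (target - observed acceptance)
    for t in range(1, n_steps + 1):
        accept = accept_prob_fn(math.exp(log_step))
        eta = 1.0 / (t + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target_accept_prob - accept)
        # Acceptance too low -> h_bar grows -> the log step size shrinks.
        log_step = mu - math.sqrt(t) / gamma * h_bar
        # Average the iterates with weight t**(-kappa) on the newest one.
        weight = t ** (-kappa)
        log_step_avg = weight * log_step + (1.0 - weight) * log_step_avg
    return math.exp(log_step_avg)


# Hypothetical stand-in for HMC: three chains whose acceptance probability
# decays with the step size at different rates.
CHAIN_SCALES = [0.5, 1.0, 2.0]


def toy_accept_prob(step_size):
    per_chain = [min(1.0, math.exp(-s * step_size)) for s in CHAIN_SCALES]
    return harmonic_mean(per_chain)  # reduce across chains


tuned = dual_averaging_step_size(toy_accept_prob, init_step_size=1.0,
                                 n_steps=1000)
print(tuned, toy_accept_prob(tuned))
```

The averaged step size, rather than the last raw iterate, is what the sketch keeps, which smooths out the oscillations of early iterates. The harmonic mean is dominated by the chain with the lowest acceptance, so the tuned step size is set conservatively for the worst-behaved chain.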
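The expanding-window estimate of the diagonal mass matrix can be sketched in the same way. The schedule below borrows Stan's default warmup layout: an initial fast buffer of 75 steps, slow windows that start at 25 steps and double in size, and a final fast buffer of 50 steps. This layout is an assumption for illustration, and TFP's own window sizes may differ. In a Stan-style scheme, the per-dimension sample variance of the draws in each slow window becomes the new diagonal of the inverse mass matrix when that window closes. `expanding_windows` and `diagonal_variance` are hypothetical helpers, not TFP functions.

```python
def expanding_windows(num_adaptation_steps, init_buffer=75, term_buffer=50,
                      base_window=25):
    """Return [start, end) index pairs for the slow (mass-matrix) windows.

    Stan-style layout (an assumption, not TFP's schedule): each window
    doubles in size, and a window is stretched to reach the final fast
    buffer when the next doubled window would not fit.
    """
    windows = []
    start = init_buffer
    slow_end = num_adaptation_steps - term_buffer
    size = base_window
    while start < slow_end:
        end = start + size
        if end + 2 * size > slow_end:
            end = slow_end  # absorb the leftover steps into this window
        windows.append((start, end))
        start = end
        size *= 2
    return windows


def diagonal_variance(draws):
    """Per-dimension sample variance of a list of equal-length draw vectors."""
    n = len(draws)
    dims = len(draws[0])
    means = [sum(d[j] for d in draws) / n for j in range(dims)]
    return [sum((d[j] - means[j]) ** 2 for d in draws) / (n - 1)
            for j in range(dims)]


# 500 matches the num_adaptation_steps default shown in the signature.
print(expanding_windows(500))
print(expanding_windows(1000))
```

Short early windows let a poor initial estimate be replaced quickly, while later, longer windows estimate the variances from more draws.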