Adapt and sample from a joint distribution, conditioned on pins.
```python
tfp.experimental.mcmc.windowed_adaptive_hmc(
    n_draws,
    joint_dist,
    *,
    num_leapfrog_steps,
    n_chains=64,
    num_adaptation_steps=500,
    current_state=None,
    init_step_size=None,
    dual_averaging_kwargs=None,
    trace_fn=default_hmc_trace_fn,
    return_final_kernel_results=False,
    discard_tuning=True,
    chain_axis_names=None,
    seed=None,
    **pins
)
```
Sampling is done with Hamiltonian Monte Carlo. The step size is tuned via
dual-averaging adaptation, and the kernel is preconditioned with a diagonal
mass matrix estimated over expanding windows.
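As a minimal sketch of typical usage (the toy model, variable names, and
pinned value below are illustrative assumptions, not part of the API):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# A toy two-variable model: a scale `x` and an observation `y` given `x`.
jd = tfd.JointDistributionNamed({
    'x': tfd.HalfNormal(1.),
    'y': lambda x: tfd.Normal(0., x),
})

# Condition on an observed `y` via the `**pins` kwargs and sample `x`.
draws, trace = tfp.experimental.mcmc.windowed_adaptive_hmc(
    n_draws=500,
    joint_dist=jd,
    num_leapfrog_steps=10,
    seed=[4, 8],
    y=tf.constant(1.))

# With the defaults, `draws['x']` has shape [500, 64]:
# `n_draws` draws for each of the default `n_chains=64` chains.
```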
## Args

* `n_draws`: int. Number of draws after adaptation.
* `joint_dist`: `tfd.JointDistribution`. A joint distribution to sample from.
* `num_leapfrog_steps`: int. Number of leapfrog steps to use for the
  Hamiltonian Monte Carlo step.
* `n_chains`: int or list of ints. Number of independent chains to run MCMC
  with.
* `num_adaptation_steps`: int. Number of draws used to adapt the step size and
  mass matrix.
* `current_state`: Optional. Structure of tensors at which to initialize
  sampling. Should have the same shape and structure as
  `model.experimental_pin(**pins).sample(n_chains)`.
* `init_step_size`: Optional. Where to initialize the step size for the
  leapfrog integrator. The structure should broadcast with `current_state`.
  For example, if the initial state is
  `{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`, then any
  of `1.`, `{'a': 1., 'b': 1.}`, or
  `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work.
  Defaults to the dimension of the log density to the 0.25 power.
* `dual_averaging_kwargs`: Optional dict. Keyword arguments to pass to
  [`tfp.mcmc.DualAveragingStepSizeAdaptation`](../../../tfp/mcmc/DualAveragingStepSizeAdaptation).
  By default, a `target_accept_prob` of 0.75 is set, acceptance probabilities
  across chains are reduced using a harmonic mean, and the class defaults are
  used otherwise.
* `trace_fn`: Optional callable. The trace function should accept the
  arguments `(state, bijector, is_adapting, phmc_kernel_results)`, where
  `state` is an unconstrained, flattened float tensor, `bijector` is the
  `tfb.Bijector` used for unconstraining and flattening, `is_adapting` is a
  boolean marking whether the draw comes from an adaptation step, and
  `phmc_kernel_results` is the
  `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the
  `PreconditionedHamiltonianMonteCarlo` kernel. Note that
  `bijector.inverse(state)` gives access to the current draw in the
  untransformed space, using the structure of the provided `joint_dist`. See
  the sketch after this list.
* `return_final_kernel_results`: If `True`, the final kernel results are
  returned alongside the chain state and the trace specified by `trace_fn`.
* `discard_tuning`: bool. Whether to discard the adaptation-phase draws and
  traces; if `True`, only the `n_draws` post-adaptation draws are returned.
* `chain_axis_names`: A `str` or list of `str`s indicating the named axes by
  which multiple chains are sharded. See
  [`tfp.experimental.mcmc.Sharded`](../../../tfp/experimental/mcmc/Sharded)
  for more context.
* `seed`: PRNG seed; see
  [`tfp.random.sanitize_seed`](../../../tfp/random/sanitize_seed) for details.
* `**pins`: Used to condition the provided joint distribution; passed directly
  to `joint_dist.experimental_pin(**pins)`.
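A sketch of a custom `trace_fn`, reusing the illustrative `jd` model from the
example above. The `target_log_prob` and `step_size` fields read off the
kernel results are assumptions about what the uncalibrated PHMC kernel results
expose; verify the field names against your TFP version. The
`dual_averaging_kwargs` override shown is likewise just one plausible setting:

```python
def my_trace_fn(state, bijector, is_adapting, phmc_kernel_results):
  return {
      # Map the flat, unconstrained draw back to the structure of `joint_dist`.
      'draw': bijector.inverse(state),
      # Assumed fields on the uncalibrated PHMC kernel results.
      'target_log_prob': phmc_kernel_results.target_log_prob,
      'step_size': phmc_kernel_results.step_size,
      'is_adapting': is_adapting,
  }

draws, trace = tfp.experimental.mcmc.windowed_adaptive_hmc(
    n_draws=500,
    joint_dist=jd,
    num_leapfrog_steps=10,
    # Target a higher acceptance probability than the 0.75 default.
    dual_averaging_kwargs={'target_accept_prob': 0.9},
    trace_fn=my_trace_fn,
    y=tf.constant(1.))
```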
## Returns

A single structure of draws is returned if `trace_fn` is `None` and
`return_final_kernel_results` is `False`. If there is a trace function, the
return value is a tuple, with the trace second. If
`return_final_kernel_results` is `True`, the return value is a tuple of
length 3, with the final kernel results last. If `discard_tuning` is `True`,
the tensors in `draws` and `trace` have length `n_draws`; otherwise, they have
length `n_draws + num_adaptation_steps`.
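Concretely, the three unpacking patterns look like this (a sketch reusing the
illustrative `jd` model from above):

```python
# Default trace_fn: a (draws, trace) tuple.
draws, trace = tfp.experimental.mcmc.windowed_adaptive_hmc(
    500, jd, num_leapfrog_steps=10, y=tf.constant(1.))

# trace_fn=None: a single structure of draws.
draws = tfp.experimental.mcmc.windowed_adaptive_hmc(
    500, jd, num_leapfrog_steps=10, trace_fn=None, y=tf.constant(1.))

# return_final_kernel_results=True: the final kernel results come last.
draws, trace, final_kernel_results = tfp.experimental.mcmc.windowed_adaptive_hmc(
    500, jd, num_leapfrog_steps=10, return_final_kernel_results=True,
    y=tf.constant(1.))

# With discard_tuning=False, the leading dimension of `draws` and `trace`
# would instead be n_draws + num_adaptation_steps (500 + 500 here).
```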