`init_step_size`<a id="init_step_size"></a>
</td>
<td>
Optional.
Where to initialize the step size for the leapfrog integrator. The
structure should broadcast with `current_state`. For example, if the
initial state is

```
{'a': tf.zeros(n_chains),
 'b': tf.zeros([n_chains, n_features])}
```
then any of `1.`, `{'a': 1., 'b': 1.}`, or
`{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will
work. Defaults to the dimension of the log density to the 0.25 power.
</td>
</tr><tr>
<td>
`dual_averaging_kwargs`<a id="dual_averaging_kwargs"></a>
</td>
<td>
Optional dict.
Keyword arguments to pass to <a href="../../../tfp/mcmc/DualAveragingStepSizeAdaptation"><code>tfp.mcmc.DualAveragingStepSizeAdaptation</code></a>.
By default, a `target_accept_prob` of 0.85 is set, acceptance
probabilities across chains are reduced using a harmonic mean, and the
class defaults are used otherwise.
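For example, to target a higher acceptance probability while keeping the
other defaults, a dict like the one below can be passed (a minimal
sketch; the sampler name `tfp.experimental.mcmc.windowed_adaptive_nuts`
and the `joint_dist` argument are assumptions about the function this
page documents):

```
# Hypothetical call: only `target_accept_prob` is overridden; all other
# DualAveragingStepSizeAdaptation settings keep their class defaults.
draws, trace = tfp.experimental.mcmc.windowed_adaptive_nuts(
    1000, joint_dist,
    dual_averaging_kwargs={'target_accept_prob': 0.9})
```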
</td>
</tr><tr>
<td>
`max_tree_depth`<a id="max_tree_depth"></a>
</td>
<td>
Maximum depth of the tree implicitly built by NUTS. The
maximum number of leapfrog steps is bounded by `2**max_tree_depth`,
i.e. the number of leaf nodes in a binary tree of depth
`max_tree_depth`. The default setting of 10 takes up to 1024 leapfrog
steps.
</td>
</tr><tr>
<td>
`max_energy_diff`<a id="max_energy_diff"></a>
</td>
<td>
Scalar threshold of the energy difference at each leapfrog step;
leapfrog steps whose energy difference exceeds this threshold are
marked as divergent samples. Defaults to 1000.
</td>
</tr><tr>
<td>
`unrolled_leapfrog_steps`<a id="unrolled_leapfrog_steps"></a>
</td>
<td>
The number of leapfrog steps to unroll per tree
expansion step. Applies a direct linear multiplier to the maximum
trajectory length implied by `max_tree_depth`. Defaults to 1.
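Taken together, the two settings bound the leapfrog steps per draw as
sketched below:

```
# Arithmetic implied by the descriptions above: the per-draw cap on
# leapfrog steps is 2**max_tree_depth times the unroll factor.
max_tree_depth = 10
unrolled_leapfrog_steps = 2
max_leapfrog_steps = 2**max_tree_depth * unrolled_leapfrog_steps  # 2048
```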
</td>
</tr><tr>
<td>
`parallel_iterations`<a id="parallel_iterations"></a>
</td>
<td>
The number of iterations allowed to run in parallel.
It must be a positive integer. See <a href="https://www.tensorflow.org/api_docs/python/tf/while_loop"><code>tf.while_loop</code></a> for more details.
</td>
</tr><tr>
<td>
`trace_fn`<a id="trace_fn"></a>
</td>
<td>
Optional callable.
The trace function should accept the arguments
`(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
is an unconstrained, flattened float tensor, `bijector` is the
`tfb.Bijector` that is used for unconstraining and flattening,
`is_adapting` is a boolean to mark whether the draw is from an adaptation
step, and `phmc_kernel_results` is the
`UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the
`PreconditionedHamiltonianMonteCarlo` kernel. Note that
`bijector.inverse(state)` will provide access to the current draw in the
untransformed space, using the structure of the provided `joint_dist`.
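A minimal sketch of a trace function with this signature (reading
`target_log_prob` off the kernel results is an assumption; check the
kernel results structure of your TFP version for the exact field names):

```
# Sketch of a trace_fn matching the documented signature.
def trace_fn(state, bijector, is_adapting, phmc_kernel_results):
  return {
      # Whether this draw came from an adaptation step.
      'is_adapting': is_adapting,
      # Assumed field name on the kernel results.
      'target_log_prob': phmc_kernel_results.target_log_prob,
      # The current draw in the untransformed space, per the note above.
      'draw': bijector.inverse(state),
  }
```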
</td>
</tr><tr>
<td>
`return_final_kernel_results`<a id="return_final_kernel_results"></a>
</td>
<td>
If `True`, then the final kernel results are
returned alongside the chain state and the trace specified by the
`trace_fn`.
</td>
</tr><tr>
<td>
`discard_tuning`<a id="discard_tuning"></a>
</td>
<td>
bool
Whether to discard the tuning (adaptation) draws and traces from the
returned results.
</td>
</tr><tr>
<td>
`chain_axis_names`<a id="chain_axis_names"></a>
</td>
<td>
A `str` or list of `str`s indicating the named axes
by which multiple chains are sharded. See <a href="../../../tfp/experimental/mcmc/Sharded"><code>tfp.experimental.mcmc.Sharded</code></a>
for more context.
</td>
</tr><tr>
<td>
`seed`<a id="seed"></a>
</td>
<td>
PRNG seed; see <a href="../../../tfp/random/sanitize_seed"><code>tfp.random.sanitize_seed</code></a> for details.
</td>
</tr><tr>
<td>
`**pins`<a id="**pins"></a>
</td>
<td>
These are used to condition the provided joint distribution, and are
passed directly to `joint_dist.experimental_pin(**pins)`.
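A minimal sketch of pinning observed data by name (the model, the
variable names, and the sampler name are all assumptions for
illustration):

```
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Hypothetical model: `obs` is the observed variable to condition on.
joint_dist = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    obs=lambda scale: tfd.Normal(0., scale)))
y = tf.constant(0.5)

# Passing `obs=y` conditions the model via
# joint_dist.experimental_pin(obs=y).
draws, trace = tfp.experimental.mcmc.windowed_adaptive_nuts(
    1000, joint_dist, obs=y)
```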
</td>
</tr>
</table>
<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Returns</h2></th></tr>
<tr class="alt">
<td colspan="2">
A single structure of draws is returned if `trace_fn` is `None` and
`return_final_kernel_results` is `False`. If there is a trace function,
the return value is a tuple, with the trace second. If
`return_final_kernel_results` is `True`, the return value is a tuple of
length 3, with the final kernel results last. If `discard_tuning` is
`True`, the tensors in `draws` and `trace` have length `n_draws`;
otherwise, they have length `n_draws + num_adaptation_steps`.
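Concretely, the unpacking under each setting looks like the sketch below
(the sampler and model names continue the assumptions used in the
examples above):

```
# No trace function and no final kernel results: a single structure of draws.
draws = tfp.experimental.mcmc.windowed_adaptive_nuts(
    100, joint_dist, trace_fn=None, obs=y)

# With a trace function: a (draws, trace) tuple, trace second.
draws, trace = tfp.experimental.mcmc.windowed_adaptive_nuts(
    100, joint_dist, obs=y)

# return_final_kernel_results=True: a length-3 tuple, kernel results last.
draws, trace, final_kernel_results = tfp.experimental.mcmc.windowed_adaptive_nuts(
    100, joint_dist, return_final_kernel_results=True, obs=y)
```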
</td>
</tr>
</table>