retry_init tries an MCMC initialization proposal until it gets a valid state.
Here, "valid" means that the value of target_fn is
finite. This corresponds to an MCMC workflow where target_fn
computes the log-probability one wants to sample from, in which case
"finite target_fn" means "finite and positive probability state".
If target_fn returns a Tensor of size greater than 1, the results
are assumed to be independent of each other, so that different batch
members can be accepted individually.
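To make the acceptance rule concrete, the following is a minimal pure-Python sketch of the per-member validity check. It is not TFP code: the helper name `is_valid` and the example values are hypothetical, and the real function works on Tensors rather than lists.

```python
import math


def is_valid(log_probs):
    """Return one flag per batch member: True where the value is finite."""
    return [math.isfinite(lp) for lp in log_probs]


# Hypothetical target_fn output for a batch of four proposed states.
log_probs = [-1.3, float("-inf"), float("nan"), -0.2]
mask = is_valid(log_probs)
print(mask)
```

Members one and four would be accepted here, while the `-inf` and `nan` members would be proposed again.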
The method is bounded rejection sampling. The bound serves to avoid
wasting computation on hopeless initialization procedures. In
interactive MCMC, one would presumably rather come up with a better
initialization proposal than wait for an unbounded number of
attempts with a bad one. If unbounded re-trials are desired,
set max_trials to None.
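The bounded retry loop described above might look roughly like the sketch below. This is a pure-Python illustration under stated assumptions, not the TFP implementation: `retry_init_sketch`, `propose`, and `log_density` are hypothetical names, the batch is a plain list, and randomness comes from Python's `random` module rather than a TFP seed.

```python
import math
import random


def retry_init_sketch(proposal_fn, target_fn, max_trials=50, seed=0):
    """Bounded rejection sampling that accepts batch members individually."""
    rng = random.Random(seed)
    states = proposal_fn(rng)
    done = [math.isfinite(v) for v in target_fn(states)]
    trials = 1  # counts calls to proposal_fn
    while not all(done):
        if max_trials is not None and trials >= max_trials:
            raise RuntimeError(f"no valid state after {trials} trials")
        proposal = proposal_fn(rng)
        ok = [math.isfinite(v) for v in target_fn(proposal)]
        # Keep members that are already accepted; overwrite the rest.
        states = [s if d else p for s, d, p in zip(states, done, proposal)]
        done = [d or o for d, o in zip(done, ok)]
        trials += 1
    return states


def propose(rng):
    # Hypothetical proposal: five independent draws from Uniform(-1, 1).
    return [rng.uniform(-1.0, 1.0) for _ in range(5)]


def log_density(xs):
    # log(x) is only defined for x > 0; other values count as invalid.
    return [math.log(x) if x > 0 else float("-inf") for x in xs]


init = retry_init_sketch(propose, log_density, max_trials=50, seed=42)
```

Passing `max_trials=None` to this sketch removes the bound, so the loop runs until every batch member has a finite target value.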
Args
proposal_fn
A function accepting a seed keyword argument and no other
required arguments which generates proposed initial states.
target_fn
A function accepting the return value of proposal_fn
and returning a floating-point Tensor.
*args
Additional arguments passed to proposal_fn.
max_trials
Size-1 integer Tensor or None. Maximum number of
calls to proposal_fn to attempt. If acceptable states are not
found in this many trials, retry_init signals an error. If
None, there is no limit, and retry_init skips the control
flow cost of checking for success. Default value: 50.
seed
PRNG seed; see tfp.random.sanitize_seed for details.
name
Python str name prefixed to Ops created by this function.
Default value: None (i.e., 'mcmc_sample_chain').
**kwargs
Additional keyword arguments passed to proposal_fn.
Returns
states
An acceptable result from proposal_fn.
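The argument descriptions above imply that extra positional and keyword arguments, along with the seed, are forwarded to proposal_fn. A minimal sketch of that forwarding, assuming a call of the form `proposal_fn(*args, seed=seed, **kwargs)`, is shown below. The names `call_proposal` and `sample` are hypothetical, and the sketch ignores how the real function derives a separate seed for each trial.

```python
import random


def call_proposal(proposal_fn, *args, seed=None, **kwargs):
    """Forward extra arguments and the seed to the proposal function."""
    return proposal_fn(*args, seed=seed, **kwargs)


def sample(sample_shape=(), seed=None):
    # Hypothetical stand-in for a distribution's sample method.
    rng = random.Random(seed)
    n = sample_shape[0] if sample_shape else 1
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


# sample_shape is not consumed by call_proposal; it reaches sample unchanged.
draws = call_proposal(sample, sample_shape=[3], seed=7)
```

This is why a keyword such as sample_shape can be passed directly to retry_init when proposal_fn is a distribution's sample method.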
Example
One popular MCMC initialization scheme is to start the chains near 0
in unconstrained space. There are models where the unconstraining
transformation cannot exactly capture the space of valid states,
such that this initialization has some material but not overwhelming
chance of failure. In this case, we can use retry_init to compensate.
    import tensorflow_probability as tfp

    @tfp.distributions.JointDistributionCoroutine
    def model():
      ...

    raw_init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
    init_states = tfp.experimental.mcmc.retry_init(
        proposal_fn=raw_init_dist.sample,
        target_fn=model.log_prob,
        sample_shape=[100],
        seed=[4, 8])
    states = tfp.mcmc.sample_chain(
        current_state=init_states,
        ...)

Note: XLA and @jax.jit do not support assertions, so this function can return invalid states on those platforms without raising an error (unless max_trials is set to None).
Signature: tfp.experimental.mcmc.retry_init(proposal_fn, target_fn, *args, max_trials=50, seed=None, name=None, **kwargs)
View source on GitHub: https://github.com/tensorflow/probability/blob/v0.23.0/tensorflow_probability/python/experimental/mcmc/initialization.py#L212-L299
Used in the tutorials: TFP Release Notes notebook (0.13.0), https://www.tensorflow.org/probability/examples/TFP_Release_Notebook_0_13_0
Last updated 2023-11-21 UTC.