Tries an MCMC initialization proposal until it gets a valid state.
tfp.experimental.mcmc.retry_init(
    proposal_fn,
    target_fn,
    *args,
    max_trials=50,
    seed=None,
    name=None,
    **kwargs
)
Here, a state is "valid" if the value of target_fn on it is finite. This corresponds to an MCMC workflow where target_fn computes the log-probability one wants to sample from, in which case "finite target_fn" means "finite and positive probability state".
If target_fn returns a Tensor of size greater than 1, the results are assumed to be independent of each other, so that different batch members can be accepted individually.
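A minimal pure-Python illustration of what "valid" means, assuming a hypothetical target: a half-normal log-density that returns -inf (the log of zero probability) for negative states. Each batch member is judged on its own entry, so one rejected member does not invalidate the others. This is not TFP code; `log_prob_half_normal` and `proposals` are made up for illustration.

```python
import math


def log_prob_half_normal(x: float) -> float:
    """Hypothetical target: standard half-normal log-density up to an additive constant."""
    if x < 0.0:
        return -math.inf  # log(0): this state has zero probability.
    return -0.5 * x * x


# A batch of four proposed initial states.
proposals = [0.3, -1.2, 1.7, -0.1]
log_probs = [log_prob_half_normal(x) for x in proposals]

# Validity is decided per batch member: a state is acceptable iff its
# log-probability is finite.
valid = [math.isfinite(lp) for lp in log_probs]
print(valid)  # [True, False, True, False]
```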
The method is bounded rejection sampling: proposals are redrawn until an acceptable state is found or max_trials is reached. The bound serves to avoid wasting computation on hopeless initialization procedures. In interactive MCMC, one would presumably rather come up with a better initialization proposal than wait for an unbounded number of attempts with a bad one. If unbounded re-trials are desired, set max_trials to None.
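The sketch below is a hypothetical pure-Python analogue of this bounded, per-member rejection loop, assuming scalar states stored in a Python list. `bounded_retry_init`, `proposal`, and `half_normal_log_prob` are illustrative names, not part of TFP, and the real function operates on (possibly nested) Tensors. Each trial draws a fresh batch, keeps every member that was already accepted, and fills in rejected members from the new draw; passing max_trials=None removes the bound.

```python
import math
import random
from typing import Callable, List, Optional


def bounded_retry_init(
    proposal_fn: Callable[..., List[float]],
    target_fn: Callable[[List[float]], List[float]],
    *args,
    max_trials: Optional[int] = 50,
    seed: Optional[int] = None,
    **kwargs,
) -> List[float]:
    """Hypothetical analogue of retry_init for a batch of scalar states.

    Each trial calls proposal_fn(*args, seed=..., **kwargs) and keeps, per
    batch member, the first proposal whose target value is finite.
    """
    seed_stream = random.Random(seed)  # Yields one integer seed per trial.
    states = proposal_fn(*args, seed=seed_stream.getrandbits(32), **kwargs)
    ok = [math.isfinite(t) for t in target_fn(states)]
    trials = 1
    while not all(ok):
        if max_trials is not None and trials >= max_trials:
            raise ValueError(f"No acceptable state after {trials} trials.")
        new_states = proposal_fn(*args, seed=seed_stream.getrandbits(32), **kwargs)
        new_ok = [math.isfinite(t) for t in target_fn(new_states)]
        # Already-accepted members keep their state; rejected members take
        # the new proposal if it is acceptable.
        for i in range(len(states)):
            if not ok[i] and new_ok[i]:
                states[i] = new_states[i]
                ok[i] = True
        trials += 1
    return states


def proposal(seed=None, batch_size=1):
    """Hypothetical proposal: uniform draws on [-2, 2], half of them invalid."""
    rng = random.Random(seed)
    return [rng.uniform(-2.0, 2.0) for _ in range(batch_size)]


def half_normal_log_prob(xs):
    """Hypothetical target: -inf (zero probability) outside the support x >= 0."""
    return [-0.5 * x * x if x >= 0.0 else -math.inf for x in xs]


# batch_size is forwarded to proposal via **kwargs, like sample_shape in the
# example below; seed is consumed by bounded_retry_init itself.
init = bounded_retry_init(proposal, half_normal_log_prob, batch_size=8, seed=42)
```

Replacing only the rejected members is what lets batch members be accepted independently: a member accepted on the first trial keeps that state no matter how many extra trials the others need.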
Args | |
---|---|
proposal_fn | A function accepting a seed keyword argument and no other required arguments, which generates proposed initial states.
target_fn | A function accepting the return value of proposal_fn and returning a floating-point Tensor.
*args | Additional positional arguments passed to proposal_fn.
max_trials | Size-1 integer Tensor or None. Maximum number of calls to proposal_fn to attempt. If acceptable states are not found in this many trials, retry_init signals an error. If None, there is no limit, and retry_init skips the control-flow cost of checking for success.
seed | PRNG seed; see tfp.random.sanitize_seed for details.
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mcmc_sample_chain').
**kwargs | Additional keyword arguments passed to proposal_fn.
Returns | |
---|---|
states | An acceptable result from proposal_fn.
Example
One popular MCMC initialization scheme is to start the chains near 0 in unconstrained space. In some models, the unconstraining transformation cannot exactly capture the space of valid states, so this initialization has a material but not overwhelming chance of failure. In such cases, retry_init can compensate.
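import tensorflow_probability as tfp

@tfp.distributions.JointDistributionCoroutine
def model():
  ...  # Model definition elided.

# Propose initial states near zero in unconstrained space.
raw_init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
# Redraw any batch member whose log-probability under the model is not finite.
init_states = tfp.experimental.mcmc.retry_init(
    proposal_fn=raw_init_dist.sample,
    target_fn=model.log_prob,
    sample_shape=[100],  # Forwarded to raw_init_dist.sample via **kwargs.
    seed=[4, 8])
states = tfp.mcmc.sample_chain(
    current_state=init_states,
    ...)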