tfp.experimental.mcmc.retry_init

Tries an MCMC initialization proposal until it gets a valid state.

In this case, "valid" means that the value of target_fn is finite. This corresponds to an MCMC workflow where target_fn computes the log-probability of the distribution one wants to sample from. There, "finite target_fn" means "a state with finite, positive probability". If target_fn returns a Tensor with more than one element, the elements are assumed to be independent of each other, so different batch members can be accepted individually.
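The per-member validity rule can be sketched in plain Python as follows. This is an illustration, not TFP code: the helper name `valid_mask` is hypothetical, and a list of floats stands in for the Tensor that target_fn returns. A log-probability of `-inf` corresponds to a zero-probability state, and `NaN` usually signals a failed evaluation. Both are rejected.

```python
import math


def valid_mask(log_probs):
    """Hypothetical helper: a batch member is acceptable iff its log-prob is finite."""
    return [math.isfinite(lp) for lp in log_probs]


# log(0) = -inf marks a zero-probability state; NaN marks a failed evaluation.
mask = valid_mask([-1.3, float("-inf"), float("nan"), 0.0])
```

Because each entry is judged on its own, the two valid members can be kept while only the other two are re-proposed.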

The method is bounded rejection sampling. The bound avoids wasting computation on hopeless initialization procedures. In interactive MCMC, one would presumably rather come up with a better initialization proposal than wait through an unbounded number of attempts with a bad one. If unbounded retries are desired, set max_trials to None.
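Bounded rejection sampling over a batch can be sketched in pure Python. This is a hypothetical stand-in, not TFP's implementation. `retry_init_sketch`, `propose_near_zero` and `half_normal_log_prob` are illustrative names, and plain lists replace Tensors. Each trial makes one batched proposal call. Members that are already valid are kept, invalid ones are replaced from the fresh batch, and an error is raised once `max_trials` calls have been spent. With `max_trials=None` the loop is unbounded.

```python
import math
import random


def retry_init_sketch(propose_batch, log_prob, batch_size, max_trials=None, seed=None):
    """Hypothetical sketch of bounded rejection sampling with per-member acceptance.

    propose_batch(rng, n) returns n candidate states; log_prob(state) returns a
    float. Each trial is one call to propose_batch.
    """
    rng = random.Random(seed)
    states = propose_batch(rng, batch_size)
    trials = 1
    while True:
        valid = [math.isfinite(log_prob(s)) for s in states]
        if all(valid):
            return states
        # With max_trials=None there is no bound, so this check is skipped.
        if max_trials is not None and trials >= max_trials:
            raise ValueError(
                f"{valid.count(False)} of {batch_size} states still invalid "
                f"after {trials} trials")
        fresh = propose_batch(rng, batch_size)  # one more proposal call
        # Keep accepted members; replace only the rejected ones.
        states = [s if ok else f for s, ok, f in zip(states, valid, fresh)]
        trials += 1


def propose_near_zero(rng, n):
    """Hypothetical proposal: uniform draws around 0."""
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


def half_normal_log_prob(x):
    """Unnormalized half-normal log density; zero probability for x <= 0."""
    return -0.5 * x * x if x > 0 else float("-inf")


init = retry_init_sketch(propose_near_zero, half_normal_log_prob,
                         batch_size=100, max_trials=50, seed=0)
```

Keeping valid members across trials means each batch member's search proceeds independently. A few unlucky chains therefore do not force the whole batch to be re-drawn.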

Args

proposal_fn: A function accepting a seed keyword argument and no other required arguments, which generates proposed initial states.
target_fn: A function accepting the return value of proposal_fn and returning a floating-point Tensor.
*args: Additional arguments passed to proposal_fn.
max_trials: Size-1 integer Tensor or None. Maximum number of calls to proposal_fn to attempt. If acceptable states are not found in this many trials, retry_init signals an error. If None, there is no limit, and retry_init skips the control-flow cost of checking for success.
seed: Optional PRNG seed for reproducible sampling.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mcmc_sample_chain').
**kwargs: Additional keyword arguments passed to proposal_fn.
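The forwarding of `*args`, `**kwargs` and `seed` to proposal_fn can be sketched with a pure-Python stand-in. The names `fake_sample` and `bind_proposal` are hypothetical, and `fake_sample` only imitates a distribution's `sample(sample_shape, seed=...)` method. Extra arguments such as `sample_shape` are bound once, and each trial then supplies only a seed, passed by keyword.

```python
import random


def fake_sample(sample_shape=(), seed=None):
    """Hypothetical stand-in for a distribution's sample method."""
    rng = random.Random(seed)
    n = 1
    for dim in sample_shape:
        n *= dim
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def bind_proposal(proposal_fn, *args, **kwargs):
    """Forward extra positional/keyword arguments so a trial only needs a seed."""
    def propose(seed):
        # The seed is passed by keyword, as proposal_fn is required to accept.
        return proposal_fn(*args, seed=seed, **kwargs)
    return propose


propose = bind_proposal(fake_sample, sample_shape=[100])
first = propose(seed=4)
```

Reusing a seed reproduces the same proposal, while a different seed gives a different one.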

Returns

states: An acceptable result from proposal_fn.

Example

One popular MCMC initialization scheme is to start the chains near 0 in unconstrained space. In some models the unconstraining transformation cannot exactly capture the space of valid states, so this initialization has a material, but not overwhelming, chance of failure. In such cases, retry_init can compensate.
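The failure mode described above can be made concrete with a toy model in plain Python. It is entirely hypothetical: two positive scales are mapped from unconstrained space with `exp`, but the model also requires `scale_0 < scale_1`, an ordering that `exp` does not enforce. The uniform(-2, 2) unconstrained draw is an assumption modeled on common "near zero" schemes. Roughly half of the proposals land outside the support, which is the kind of material but not overwhelming failure rate retry_init is meant to absorb.

```python
import math
import random


def propose_near_unconstrained_zero(rng):
    """Assumed scheme: uniform(-2, 2) in unconstrained space, mapped to positive scales."""
    return [math.exp(rng.uniform(-2.0, 2.0)) for _ in range(2)]


def log_prob(scales):
    """Hypothetical model requiring scale_0 < scale_1, which exp() cannot enforce."""
    a, b = scales
    if a >= b:
        return float("-inf")  # outside the support: zero probability
    return -(a + b)  # unnormalized log density on the valid region


rng = random.Random(0)
proposals = [propose_near_unconstrained_zero(rng) for _ in range(10_000)]
failure_rate = sum(not math.isfinite(log_prob(s)) for s in proposals) / len(proposals)
```

Since the two draws are independent and identically distributed, about 50% of proposals violate the ordering. Each retry keeps the valid half and re-draws the rest.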

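import tensorflow_probability as tfp

@tfp.distributions.JointDistributionCoroutine
def model():
  ...  # Model definition elided.

# Propose initial states near zero in unconstrained space.
raw_init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
# Re-propose until model.log_prob is finite for every chain.
# sample_shape=[100] is forwarded to raw_init_dist.sample via **kwargs.
init_states = tfp.experimental.mcmc.retry_init(
  proposal_fn=raw_init_dist.sample,
  target_fn=model.log_prob,
  sample_shape=[100],
  seed=[4, 8])
states = tfp.mcmc.sample_chain(
  current_state=init_states,
  ...)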