tfp.substrates.jax.mcmc.random_walk_normal_fn

Returns a callable that adds a random normal perturbation to the input.

This function returns a callable that accepts a Python list of Tensors of any shapes and dtypes, representing the state parts of current_state, together with a random seed. The supplied argument scale must be a Tensor or a Python list of Tensors giving the scale of the generated proposal, and it must broadcast with the state parts of current_state: a single Tensor (or a one-element list) is reused for every state part, while a longer list must contain exactly one entry per state part. The callable adds a sample from a zero-mean normal distribution with the supplied scales to each state part and returns a list of Tensors with the same shapes and dtypes as the state parts of current_state.
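The mechanism described above can be sketched in plain JAX. This is not the TFP source: random_walk_normal_sketch is a hypothetical name, and the sketch assumes that each scale broadcasts to the shape of its state part and that the seed is a jax.random.PRNGKey, which is split into one independent key per state part.

```python
import jax
import jax.numpy as jnp


def random_walk_normal_sketch(scale=1.0):
    """Hypothetical re-implementation of a random-walk normal proposal (not TFP code)."""

    def _fn(state_parts, seed):
        # A single scale (or one-element list) is reused for every state part;
        # otherwise the list must have one scale per state part.
        scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
        if len(scales) == 1:
            scales = scales * len(state_parts)
        if len(scales) != len(state_parts):
            raise ValueError("`scale` must broadcast with `state_parts`.")
        # Assumption: `seed` is a jax.random.PRNGKey; split it into one key per part.
        keys = jax.random.split(seed, len(state_parts))
        proposal = []
        for part, part_scale, key in zip(state_parts, scales, keys):
            # Zero-mean standard normal noise with the part's own shape and dtype.
            noise = jax.random.normal(key, shape=jnp.shape(part), dtype=part.dtype)
            # Scale the noise elementwise and add it to the current state part.
            proposal.append(part + jnp.asarray(part_scale, dtype=part.dtype) * noise)
        return proposal

    return _fn


# Usage: two state parts with different shapes and one scale per part.
state = [jnp.zeros(3), jnp.ones((2, 2))]
proposal = random_walk_normal_sketch(scale=[0.1, 2.0])(state, jax.random.PRNGKey(0))
```

Splitting the key once per state part keeps the perturbations of different parts independent while the whole proposal stays reproducible for a given seed.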

Args:
scale: A Tensor or Python list of Tensors of any shapes and dtypes controlling the scale of the normal proposal distribution.
name: Python str name prefixed to Ops created by this function. Default value: 'random_walk_normal_fn'.

Returns:
random_walk_normal_fn: A callable that accepts a Python list of Tensors representing the state parts of current_state and an int representing the random seed used to generate the proposal. The callable returns a list of Tensors of the same type as its input, representing the proposal for the random walk Metropolis (RWM) algorithm.