tfp.substrates.jax.mcmc.random_walk_uniform_fn

Returns a callable that adds a random uniform perturbation to the input.

For more details, see random_walk_normal_fn. scale may be a Tensor or a list of Tensors that broadcast with the state parts of current_state. For each state part, the perturbation is drawn uniformly from the rectangle [-scale, scale] and added to that part.
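The behaviour described above can be sketched in plain JAX. This is an illustrative stand-in, not the library's implementation: the name `uniform_perturbation_fn` is hypothetical, and the sketch assumes the seed is a JAX PRNG key and that a single scale is reused for every state part.

```python
import jax
import jax.numpy as jnp


def uniform_perturbation_fn(scale=1.0):
    """Hypothetical sketch of a uniform random-walk proposal (not TFP code)."""

    def _fn(state_parts, key):
        # Accept either one scale or one scale per state part.
        scales = list(scale) if isinstance(scale, (list, tuple)) else [scale]
        if len(scales) == 1:
            scales = scales * len(state_parts)
        if len(scales) != len(state_parts):
            raise ValueError("scale must broadcast with state_parts")
        # Use an independent key for each state part.
        keys = jax.random.split(key, len(state_parts))
        return [
            # Draw u in [-1, 1), rescale to [-s, s), and add to the state part.
            part + s * jax.random.uniform(
                k, part.shape, part.dtype, minval=-1.0, maxval=1.0)
            for part, s, k in zip(state_parts, scales, keys)
        ]

    return _fn


state = [jnp.zeros(3), jnp.ones((2, 2))]
out = uniform_perturbation_fn(0.5)(state, jax.random.PRNGKey(0))
```

Each element of `out` has the same shape as the corresponding state part and lies within 0.5 of it in every coordinate.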

Args:
  scale: A Tensor or Python list of Tensors of any shapes and dtypes controlling the upper and lower bounds of the uniform proposal distribution.
  name: Python str name prefixed to Ops created by this function. Default value: 'random_walk_uniform_fn'.

Returns:
  random_walk_uniform_fn: A callable accepting a Python list of Tensors representing the state parts of the current_state and an int representing the random seed used to generate the proposal. The callable returns a list of Tensors of the same type as the input, representing the proposal for the random walk Metropolis (RWM) algorithm.
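To illustrate how a callable with this contract fits into RWM, the following is a minimal sketch of one Metropolis step in plain JAX. Everything here is an assumption made for illustration: `uniform_proposal` and `rwm_step` are hypothetical helpers, the seed is taken to be a JAX PRNG key, and the target is a standard normal. In TFP itself, the returned callable is normally passed as the `new_state_fn` argument of `tfp.mcmc.RandomWalkMetropolis` rather than used in a hand-written step like this.

```python
import jax
import jax.numpy as jnp


def uniform_proposal(state_parts, key, scale=0.5):
    # Hypothetical stand-in for the callable returned by random_walk_uniform_fn.
    keys = jax.random.split(key, len(state_parts))
    return [
        p + scale * jax.random.uniform(k, p.shape, p.dtype, minval=-1.0, maxval=1.0)
        for p, k in zip(state_parts, keys)
    ]


def rwm_step(state_parts, key, target_log_prob_fn, proposal_fn):
    """One random walk Metropolis step (illustrative sketch, not TFP code)."""
    proposal_key, accept_key = jax.random.split(key)
    proposed = proposal_fn(state_parts, proposal_key)
    # The uniform proposal is symmetric, so the acceptance ratio reduces to
    # the ratio of target densities.
    log_ratio = target_log_prob_fn(*proposed) - target_log_prob_fn(*state_parts)
    log_u = jnp.log(jax.random.uniform(accept_key))
    accept = log_u < log_ratio
    # Keep the proposal if accepted, otherwise stay at the current state.
    new_parts = [jnp.where(accept, new, old)
                 for new, old in zip(proposed, state_parts)]
    return new_parts, accept


def standard_normal_log_prob(x):
    # Unnormalized log density of a standard normal target.
    return -0.5 * jnp.sum(x ** 2)


current_state = [jnp.zeros(2)]
new_state, accepted = rwm_step(
    current_state, jax.random.PRNGKey(42), standard_normal_log_prob, uniform_proposal)
```

Whether or not the move is accepted, `new_state` has the same list structure and shapes as `current_state`.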