
Shards a transition kernel across a named axis.

Inherits From: TransitionKernel

Ordinarily, one can produce independent Markov chains from a single kernel by providing a batch of states. When using named axes inside a map (for example, with JAX's pmap, vmap, or xmap), however, the kernel is given a state without batch dimensions. To sample independently across the named axis, the PRNG seed must differ across that axis. This can be accomplished by folding the named axis index into the random seed. A Sharded kernel does exactly this, creating independent chains across a named axis.
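The following is a minimal sketch of the seed-folding idea in plain JAX, not the Sharded implementation itself. It assumes a single mapped axis named "chains" (an illustrative name) and uses jax.random.fold_in together with jax.lax.axis_index. Without the fold-in, every mapped instance receives the same key and draws the same value. With it, each position along the named axis gets its own key.

```python
import jax
import jax.numpy as jnp

AXIS = "chains"  # illustrative name for the mapped axis


def draw_without_fold_in(key):
    # Every mapped instance sees the same key, so every draw is identical.
    return jax.random.normal(key)


def draw_with_fold_in(key):
    # Fold this instance's position along the named axis into the key,
    # giving each instance its own independent key.
    key = jax.random.fold_in(key, jax.lax.axis_index(AXIS))
    return jax.random.normal(key)


key = jax.random.PRNGKey(0)
# in_axes=None passes the same unbatched key to all 4 mapped instances,
# mimicking a kernel that receives state and seed without batch dimensions.
same = jax.vmap(draw_without_fold_in, in_axes=None, axis_size=4, axis_name=AXIS)(key)
distinct = jax.vmap(draw_with_fold_in, in_axes=None, axis_size=4, axis_name=AXIS)(key)
```

Here `same` holds four copies of one value, while `distinct` holds four different draws.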

inner_kernel A TransitionKernel to be sharded.
chain_axis_names A str or list of strs that determines the named axes that independent Markov chains will be sharded across.
validate_args Python bool. When True kernel parameters are checked for validity. When False invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this class.


experimental_shard_axis_names The shard axis names for members of the state.

is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
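As a rough illustration of what calibration means, here is a minimal random-walk Metropolis step written in plain JAX (not TFP's MetropolisHastings). The proposal function on its own is "uncalibrated": it does not leave the target invariant. The accept/reject step, which uses the log ratio of target densities (the Hastings correction cancels because the proposal is symmetric), is what makes the composed step converge to the target. The standard-normal target, the step scale, and all function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def target_log_prob(x):
    # Standard normal target, up to an additive constant (illustrative).
    return -0.5 * x**2


def uncalibrated_step(key, x, scale=1.0):
    # Symmetric Gaussian random-walk proposal; alone, it drifts without bound.
    return x + scale * jax.random.normal(key)


def metropolis_hastings_step(key, x):
    proposal_key, accept_key = jax.random.split(key)
    proposed = uncalibrated_step(proposal_key, x)
    # Symmetric proposal, so the acceptance ratio is just the density ratio.
    log_accept_ratio = target_log_prob(proposed) - target_log_prob(x)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_accept_ratio
    return jnp.where(accept, proposed, x)


def run_chain(key, x0, num_steps):
    def body(x, k):
        x = metropolis_hastings_step(k, x)
        return x, x

    _, xs = jax.lax.scan(body, x0, jax.random.split(key, num_steps))
    return xs


samples = run_chain(jax.random.PRNGKey(0), jnp.float32(0.0), 5000)
```

The empirical mean and variance of `samples` should be close to 0 and 1, the moments of the target.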




bootstrap_results(init_state)

Returns an object with the same type as returned by one_step(...)[1].

init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
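The requirement that bootstrap_results return the same type as one_step(...)[1] can be illustrated with a hypothetical toy kernel whose results are a NamedTuple. ToyKernel and ToyResults are illustrative names, not TFP classes, and the toy step deliberately omits any accept/reject correction.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyResults(NamedTuple):
    # Hypothetical kernel results: the target log-density at the current state.
    target_log_prob: jax.Array


class ToyKernel:
    """Hypothetical kernel illustrating the bootstrap_results/one_step contract."""

    def __init__(self, target_log_prob_fn):
        self.target_log_prob_fn = target_log_prob_fn

    def bootstrap_results(self, init_state):
        # Same type as one_step(...)[1], computed from the initial state.
        return ToyResults(target_log_prob=self.target_log_prob_fn(init_state))

    def one_step(self, current_state, previous_kernel_results, seed):
        # Uncalibrated random-walk move (no accept/reject), for illustration only.
        noise = jax.random.normal(seed, jnp.shape(current_state))
        next_state = current_state + 0.1 * noise
        return next_state, ToyResults(self.target_log_prob_fn(next_state))


kernel = ToyKernel(lambda x: -0.5 * x**2)
init_results = kernel.bootstrap_results(jnp.float32(1.0))
state, results = kernel.one_step(jnp.float32(1.0), init_results, jax.random.PRNGKey(0))
```

Keeping the two types identical lets the results be threaded through a loop such as jax.lax.scan, whose carry must keep the same structure on every iteration.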


copy(**override_parameter_kwargs)

Non-destructively creates a deep copy of the kernel.

**override_parameter_kwargs Python str/value dictionary of initialization arguments to override with new values.

new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
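The merge rule above is ordinary Python dict semantics. The sketch below applies it to a hypothetical parameters dictionary (the "rwm" string stands in for a real inner kernel object) to show that overridden keys take the new value, untouched keys are carried over, and the original dictionary is left unchanged.

```python
def merged_parameters(parameters, **override_parameter_kwargs):
    """Hypothetical helper applying the documented merge rule."""
    # Later keyword arguments win over keys already in `parameters`;
    # a new dict is built, so `parameters` itself is not mutated.
    return dict(parameters, **override_parameter_kwargs)


params = {
    "inner_kernel": "rwm",  # placeholder for an inner TransitionKernel
    "chain_axis_names": ["chains"],
    "validate_args": False,
    "name": None,
}
new_params = merged_parameters(params, chain_axis_names=["outer", "inner"])
```

The copied kernel would then be constructed from `new_params`, which is why copy is non-destructive.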


experimental_with_shard_axes(shard_axis_names)

Returns a copy of the kernel with the provided shard axis names.

shard_axis_names A structure of strings indicating the shard axis names for each component of this kernel's state.

A copy of the current kernel with the shard axis information.


one_step(current_state, previous_kernel_results, seed=None)

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.

next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
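Putting the pieces together, the sketch below is a hypothetical re-implementation of the idea described at the top of this page, not TFP's source. It is a wrapper whose bootstrap_results delegates to the inner kernel and whose one_step folds the index along each of its chain_axis_names into the seed before delegating. ShardedSketch, RandomWalk, and run are illustrative names, and the inner kernel is a plain Gaussian random walk with no accept/reject step. Each mapped instance receives an unbatched state and the same seed, yet the six chains created by two nested vmap calls end in different states.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWResults(NamedTuple):
    step: jax.Array  # number of steps taken (illustrative kernel result)


class RandomWalk:
    """Hypothetical inner kernel: Gaussian random walk, no accept/reject."""

    def bootstrap_results(self, init_state):
        return RWResults(step=jnp.zeros([], jnp.int32))

    def one_step(self, current_state, previous_kernel_results, seed):
        noise = jax.random.normal(seed, jnp.shape(current_state))
        return current_state + noise, RWResults(previous_kernel_results.step + 1)


class ShardedSketch:
    """Hypothetical wrapper illustrating the idea behind Sharded."""

    def __init__(self, inner_kernel, chain_axis_names):
        # Accept a single axis name or a list of names.
        if isinstance(chain_axis_names, str):
            chain_axis_names = [chain_axis_names]
        self.inner_kernel = inner_kernel
        self.chain_axis_names = list(chain_axis_names)

    def bootstrap_results(self, init_state):
        return self.inner_kernel.bootstrap_results(init_state)

    def one_step(self, current_state, previous_kernel_results, seed):
        # Make the seed depend on this instance's position along each named axis.
        for name in self.chain_axis_names:
            seed = jax.random.fold_in(seed, jax.lax.axis_index(name))
        return self.inner_kernel.one_step(current_state, previous_kernel_results, seed)


def run(seed, kernel, num_steps=3):
    state = jnp.zeros([])  # unbatched state: one chain per mapped instance
    results = kernel.bootstrap_results(state)
    for step_seed in jax.random.split(seed, num_steps):
        state, results = kernel.one_step(state, results, step_seed)
    return state


kernel = ShardedSketch(RandomWalk(), ["outer", "inner"])
per_inner = jax.vmap(lambda s: run(s, kernel), in_axes=None, axis_size=3, axis_name="inner")
final_states = jax.vmap(per_inner, in_axes=None, axis_size=2, axis_name="outer")(
    jax.random.PRNGKey(0)
)
```

Folding in every named axis, rather than only one, is what keeps chains independent when they are mapped over several axes at once.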