Shards a transition kernel across a named axis.
Inherits From: TransitionKernel
tfp.experimental.mcmc.Sharded(
inner_kernel, chain_axis_names, validate_args=False, name=None
)
Ordinarily, one can produce independent Markov chains from a single kernel by
providing a batch of states. However, when using named axes inside of a map
(say in the case of using JAX's pmap, vmap, or xmap), the kernel is provided
with a state that has no batch dimensions. To sample independently across the
named axis, the PRNG seed must differ across that axis. This can be
accomplished by folding the named axis index into the random seed. A Sharded
kernel does exactly this, creating independent chains across a named axis.
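The seed-folding idea can be illustrated without TFP or JAX. The following is a conceptual sketch in plain Python, not TFP's implementation: a loop stands in for the members of a named axis, and the hypothetical helper fold_in (a SHA-256 hash of the seed and the axis index, chosen only for illustration) plays the role of folding the axis index into the seed. Without folding, every member receives the same seed and draws the same value; with folding, each member draws its own.

```python
import hashlib
import random


def fold_in(seed: int, axis_index: int) -> int:
    # Hypothetical helper: derive a new 64-bit seed from (seed, axis_index).
    # Illustrative only; this is not the hashing used by TFP or JAX.
    digest = hashlib.sha256(f"{seed}:{axis_index}".encode()).digest()
    return int.from_bytes(digest[:8], "little")


def draw(seed: int) -> float:
    # One deterministic standard-normal draw, standing in for the randomness
    # consumed by a single kernel step.
    return random.Random(seed).gauss(0.0, 1.0)


def simulate_named_axis(seed: int, axis_size: int, fold: bool) -> list:
    # Each iteration plays the role of one member of the named axis
    # (e.g. one device under pmap). All members are handed the same seed.
    samples = []
    for axis_index in range(axis_size):
        member_seed = fold_in(seed, axis_index) if fold else seed
        samples.append(draw(member_seed))
    return samples


unsharded = simulate_named_axis(seed=42, axis_size=4, fold=False)
sharded = simulate_named_axis(seed=42, axis_size=4, fold=True)
print(unsharded)  # the same value repeated four times
print(sharded)    # four different values
```

In JAX itself, the analogous building blocks are jax.lax.axis_index (the index of the current member along a named axis) and jax.random.fold_in (mixing an integer into a PRNG key).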
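Attributes | |
---|---|
chain_axis_names | The name(s) of the named axis across which the chains are sharded. |
experimental_shard_axis_names | The shard axis names for members of the state. |
inner_kernel | The TransitionKernel that this kernel wraps. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
parameters | Dictionary of the initialization arguments used to construct this kernel. |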
Methods
bootstrap_results
bootstrap_results(
init_state
)
Returns an object with the same type as returned by one_step(...)[1].
Args | |
---|---|
init_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |
Returns | |
---|---|
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
copy
copy(
**override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python string/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
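The override rule for copy is ordinary Python dictionary merging. The sketch below uses a hypothetical parameters dictionary with placeholder strings in place of real kernel objects, and a hypothetical helper merged_parameters, to show which values a copy would be initialized with; on an actual kernel one would simply call kernel.copy(chain_axis_names=...).

```python
# Hypothetical parameters of a Sharded kernel; the values are placeholders,
# not real kernel objects.
parameters = {
    "inner_kernel": "hmc_kernel",
    "chain_axis_names": "chains",
    "validate_args": False,
    "name": None,
}


def merged_parameters(parameters: dict, **override_parameter_kwargs) -> dict:
    # Keys present in both take the override's value; all other keys are kept.
    # A new dict is built, so the original parameters are left untouched.
    return dict(parameters, **override_parameter_kwargs)


new_params = merged_parameters(parameters, chain_axis_names="devices")
print(new_params["chain_axis_names"])  # devices
print(new_params["inner_kernel"])      # hmc_kernel
print(parameters["chain_axis_names"])  # chains
```

Because the merge builds a new dictionary, the copy is non-destructive: the original kernel's parameters are unchanged.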
experimental_with_shard_axes
experimental_with_shard_axes(
shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step
one_step(
current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Must be overridden by subclasses.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). |
previous_kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s). |
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
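The contract between bootstrap_results and one_step can be sketched with a hypothetical toy kernel in plain Python. ToyRandomWalkKernel, its results type ToyResults, and the driver run_chain are illustrations only, not TFP APIs, and seeds here are plain integers rather than TFP seeds; in TFP a driver such as tfp.mcmc.sample_chain plays the role of run_chain.

```python
import random
from typing import NamedTuple


class ToyResults(NamedTuple):
    # Hypothetical kernel results: just a step counter, standing in for the
    # internal calculations a real kernel would record.
    step: int


class ToyRandomWalkKernel:
    """Hypothetical kernel that mimics the bootstrap_results/one_step protocol."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def bootstrap_results(self, init_state: float) -> ToyResults:
        # Same type as one_step(...)[1].
        return ToyResults(step=0)

    def one_step(self, current_state, previous_kernel_results, seed=None):
        # All randomness for this step comes from the explicit seed.
        rng = random.Random(seed)
        next_state = current_state + self.scale * rng.gauss(0.0, 1.0)
        next_results = ToyResults(step=previous_kernel_results.step + 1)
        return next_state, next_results


def run_chain(kernel, init_state, num_steps, seed):
    # Bootstrap once, then feed each step's results into the next step.
    seed_stream = random.Random(seed)  # derives a fresh seed per step
    state = init_state
    results = kernel.bootstrap_results(init_state)
    trace = []
    for _ in range(num_steps):
        step_seed = seed_stream.getrandbits(64)
        state, results = kernel.one_step(state, results, seed=step_seed)
        trace.append(state)
    return trace, results


trace, final_results = run_chain(ToyRandomWalkKernel(scale=0.5), 0.0, num_steps=10, seed=1)
print(len(trace), final_results.step)  # 10 10
```

A Sharded kernel is driven through the same loop; per the description above, what changes is that each member of the named axis ends up using a seed with its axis index folded in, so the chains diverge even though every member receives the same input seed.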