States and auxiliary trace of an MCMC chain.
tfp.substrates.jax.mcmc.CheckpointableStatesAndTrace(
all_states, trace, final_kernel_results
)
The first dimension of all the Tensors in the all_states and trace attributes is the same and represents the chain length.
Attributes:
  all_states: A Tensor or a nested collection of Tensors representing the MCMC chain state.
  trace: A Tensor or a nested collection of Tensors representing the auxiliary values traced alongside the chain.
  final_kernel_results: A Tensor or a nested collection of Tensors representing the final value of the auxiliary state of the TransitionKernel that generated this chain.