TransformedTransitionKernel applies a bijector to the MCMC's state space.
Inherits From: TransitionKernel
tfp.substrates.jax.mcmc.TransformedTransitionKernel(
    inner_kernel, bijector, name=None
)
The TransformedTransitionKernel TransitionKernel enables fitting
a tfp.bijectors.Bijector which serves to decorrelate the Markov chain Monte
Carlo (MCMC) event dimensions, thus making the chain mix faster. This is
particularly useful when the geometry of the target distribution is
unfavorable; in such cases it may take many evaluations of
target_log_prob_fn for the chain to mix between faraway states.
The idea of training an affine function to decorrelate chain event dims was
presented in [Parno and Marzouk (2014)][1]. Used in conjunction with the
HamiltonianMonteCarlo TransitionKernel, the [Parno and Marzouk (2014)][1]
idea is an instance of Riemannian manifold HMC [(Girolami and Calderhead,
2011)][2].
The TransformedTransitionKernel enables arbitrary bijective transformations
of arbitrary TransitionKernels, e.g., one could use bijectors
tfp.bijectors.ScaleMatvecTriL, tfp.bijectors.RealNVP, etc. with transition
kernels tfp.mcmc.HamiltonianMonteCarlo, tfp.mcmc.RandomWalkMetropolis,
etc.
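For instance, here is a minimal sketch pairing tfb.ScaleMatvecTriL with
tfp.mcmc.RandomWalkMetropolis; the target distribution and scale_tril
values below are illustrative choices, not taken from the library docs:
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors
# Hypothetical correlated target; any log-prob function works here.
target = tfd.MultivariateNormalTriL(
    loc=[0., 0.], scale_tril=[[1., 0.], [0.5, 2.]])
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    # Proposals are made in decorrelated space; the bijector maps them
    # back to the target's correlated coordinates.
    bijector=tfb.ScaleMatvecTriL(scale_tril=[[1., 0.], [0.5, 2.]]))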
Transforming nested kernels
TransformedTransitionKernel can operate on multiply nested kernels, as in
the following example:
tfp.mcmc.TransformedTransitionKernel(
  inner_kernel=tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
      ... # doesn't matter
    ),
    num_adaptation_steps=9),
  bijector=tfb.Identity())
Upon construction, TransformedTransitionKernel searches the given
inner_kernel and the "stack" of nested kernels in any inner_kernel
fields thereof until it finds one with a field called target_log_prob_fn,
and replaces this with the transformed function. If no
inner_kernel has such a target log prob, a ValueError is raised.
Mathematical Details
TransformedTransitionKernel enables Markov chains which operate in
"unconstrained space." Since we interpret the bijector as mapping
"unconstrained space" to "user space", this means that the MCMC transformed
target_log_prob is:
target_log_prob(bij.forward(x)) + bij.forward_log_det_jacobian(x)
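As a minimal sketch (the helper name and the event_ndims=1 choice are our
assumptions, not the library's internals), the transformed density could be
built like this:
def make_transformed_log_prob(target_log_prob_fn, bij):
  # x lives in unconstrained space; bij.forward(x) lives in user space.
  def transformed_log_prob_fn(x):
    return (target_log_prob_fn(bij.forward(x))
            + bij.forward_log_det_jacobian(x, event_ndims=1))
  return transformed_log_prob_fn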
Recall that tfp.distributions.TransformedDistribution uses the inverse to
compute its log_prob. Despite this difference, the use of forward in
TransformedTransitionKernel is perfectly consistent with
TransformedDistribution, following the TFP convention that "sampling" is
what defines semantics. The apparent difference arises because
TransformedDistribution.log_prob is derived from a user-provided
distribution, while in TransformedTransitionKernel samples are derived from
target_log_prob_fn. That is, in TransformedDistribution we do:
x ~ NoiseDistribution()
y = bij.forward(x)
log_prob_y = NoiseDistribution().log_prob(bij.inverse(y))
             + bij.inverse_log_det_jacobian(y)
yet in TransformedTransitionKernel we do:
x ~ MCMC()
y = bij.forward(x)
log_prob_y = log_prob(y) + bij.forward_log_det_jacobian(x)
In other words (and in general), tfp.mcmc is derived from a log_prob,
which is what induces the apparent change in direction convention. Aside
from TFP convention, having Bijectors adhere to "sample first" semantics is
important because it mitigates the pervasive need for tfp.bijectors.Invert
in user code.
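To see the TransformedDistribution direction concretely (the Normal/Exp
pairing below is an illustrative choice):
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors
bij = tfb.Exp()
dist = tfd.TransformedDistribution(tfd.Normal(0., 1.), bij)
y = 2.
# The same quantity computed by hand, pulling y back through the inverse.
by_hand = (tfd.Normal(0., 1.).log_prob(bij.inverse(y))
           + bij.inverse_log_det_jacobian(y, event_ndims=0))
assert abs(dist.log_prob(y) - by_hand) < 1e-5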
Examples
RealNVP + HamiltonianMonteCarlo
- A 1-layer RealNVP is a fairly weak density model, since it cannot change the density of the masked dimensions.
- We are not actually training the bijector to do anything useful here.
import numpy as np
import jax
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors
def make_likelihood(true_variances):
  return tfd.MultivariateNormalDiag(
      scale_diag=tf.sqrt(true_variances))
dims = 10
dtype = np.float32
true_variances = tf.linspace(dtype(1), dtype(3), dims)
likelihood = make_likelihood(true_variances)
realnvp_hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=likelihood.log_prob,
      step_size=0.5,
      num_leapfrog_steps=2),
    bijector=tfb.RealNVP(
      num_masked=2,
      shift_and_log_scale_fn=tfb.real_nvp_default_template(
          hidden_layers=[512, 512])))
states, kernel_results = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=tf.zeros(dims),
    kernel=realnvp_hmc,
    num_burnin_steps=500,
    seed=jax.random.PRNGKey(0))  # The JAX substrate needs an explicit seed.
# Compute sample stats.
sample_mean = tf.reduce_mean(states, axis=0)
sample_var = tf.reduce_mean(
    tf.math.squared_difference(states, sample_mean),
    axis=0)
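If the chain is mixing well, sample_mean should be near zero and sample_var
near true_variances; large discrepancies usually indicate too few burn-in
steps or a poorly scaled bijector.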
References
[1]: Matthew Parno and Youssef Marzouk. Transport map accelerated Markov chain Monte Carlo. arXiv preprint arXiv:1412.5492, 2014. https://arxiv.org/abs/1412.5492
[2]: Mark Girolami and Ben Calderhead. Riemann Manifold Langevin and Hamiltonian Monte Carlo Methods. Journal of the Royal Statistical Society: Series B, 2011. https://doi.org/10.1111/j.1467-9868.2010.00765.x
| Attributes | |
|---|---|
| bijector | |
| experimental_shard_axis_names | The shard axis names for members of the state. |
| inner_kernel | |
| is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
| name | |
| parameters | Returns a dict of __init__ arguments and their values. |
Methods
bootstrap_results
bootstrap_results(
    init_state=None, transformed_init_state=None
)
Returns an object with the same type as returned by one_step.
Unlike other TransitionKernels,
TransformedTransitionKernel.bootstrap_results has the option of
initializing the TransformedTransitionKernelResults either from an initial
state, e.g., requiring computing bijector.inverse(init_state), or
directly from transformed_init_state, i.e., a Tensor or list
of Tensors interpreted as the bijector.inverse-transformed state.
| Args | |
|---|---|
| init_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s). Must specify init_state or transformed_init_state but not both. |
| transformed_init_state | Tensor or Python list of Tensors representing the transformed state(s) of the Markov chain(s). Must specify init_state or transformed_init_state but not both. |
| Returns | |
|---|---|
| kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
| Raises | |
|---|---|
| ValueError | if none of the nested inner_kernel results contain the member "target_log_prob". |
Examples
To use transformed_init_state in context of
tfp.mcmc.sample_chain, you need to explicitly pass the
previous_kernel_results, e.g.,
transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
init_state = ...        # Doesn't matter.
transformed_init_state = ... # Does matter.
results = tfp.mcmc.sample_chain(
    num_results=...,
    current_state=init_state,
    previous_kernel_results=transformed_kernel.bootstrap_results(
        transformed_init_state=transformed_init_state),
    trace_fn=None,
    kernel=transformed_kernel)
copy
copy(
    **override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
| Args | |
|---|---|
| **override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |
| Returns | |
|---|---|
| new_kernel | TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
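For example, reusing realnvp_hmc from the RealNVP example above and
swapping only the bijector (an illustrative override; any constructor
argument can be replaced this way):
identity_hmc = realnvp_hmc.copy(bijector=tfb.Identity())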
experimental_with_shard_axes
experimental_with_shard_axes(
    shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
| Args | |
|---|---|
| shard_axis_names | a structure of strings indicating the shard axis names for each component of this kernel's state. | 
| Returns | |
|---|---|
| A copy of the current kernel with the shard axis information. | 
one_step
one_step(
    current_state, previous_kernel_results, seed=None
)
Runs one iteration of the Transformed Kernel.
| Args | |
|---|---|
| current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s), after application of bijector.forward. The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). Note that inner_kernel.one_step does not actually use current_state; rather, it takes as input previous_kernel_results.transformed_state (because TransformedTransitionKernel creates a copy of the input inner_kernel with a modified target_log_prob_fn which internally applies bijector.forward). |
| previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
| seed | PRNG seed; see tfp.random.sanitize_seed for details. |
| Returns | |
|---|---|
| next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state. |
| kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
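As a minimal sketch of driving the kernel by hand (the scalar Normal
target and the seed are illustrative choices):
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors
target = tfd.Normal(loc=0., scale=1.)
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    bijector=tfb.Identity())
state = jnp.zeros([])
results = kernel.bootstrap_results(state)
state, results = kernel.one_step(
    state, results, seed=jax.random.PRNGKey(0))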