tfp.experimental.distributions.JointDistributionPinned

A wrapper class for JointDistribution which pins, e.g., the evidence.

This object is experimental; the API may change without warning.

Think of this object as functools.partial for joint distributions. Sampling fixes the pinned values (as if they had been passed via jd.sample(value=pins) to the underlying distribution) and trims them from the result. Log-density evaluates the joint probability of the given event and the pinned values.

This object represents an unnormalized probability density, and as such is not a tfp.distributions.Distribution, and lacks sample and log_prob methods. In their place, it provides:

  • unnormalized_log_prob, unnormalized_log_prob_parts
  • sample_unpinned, sample_and_log_weight

Mathematically speaking, the object represents a joint probability density p(x, y), where the x are pinned and the y are unpinned. Since p(x, y) = p(y | x) p(x), it is also proportional to p(y | x), with a (generally) intractable normalizing constant p(x).
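For instance, in the following minimal sketch (the two-part model is purely illustrative), the unnormalized log-density at y equals the joint log-density log p(x, y):

import tensorflow_probability as tfp
tfd = tfp.distributions

# x ~ Normal(0, 1); y | x ~ Normal(x, 1). Pin x = 0.5.
jd = tfd.JointDistributionNamed(dict(
    x=tfd.Normal(0., 1.),
    y=lambda x: tfd.Normal(x, 1.),
))
pinned = jd.experimental_pin(x=0.5)

# The unnormalized log-density at y is log p(x=0.5, y), i.e.
# log p(y | x) + log p(x):
pinned.unnormalized_log_prob(y=1.)  # == jd.log_prob(x=0.5, y=1.)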

A common use-case in probabilistic inference is writing out a generative model to explain some observed data:

tfd = tfp.distributions

jd = tfd.JointDistributionNamed(dict(
  loc=tfd.Normal(0., 1.),
  scale=tfd.Gamma(1., 1.),
  obs=lambda loc, scale: tfd.Normal(loc, scale),
))

Later, when we want to infer 'typical' values of loc and scale conditioned on some given data, we will often write:

def target_log_prob_fn(loc, scale):
  return jd.log_prob(loc=loc, scale=scale, obs=data)

This class enables one to write instead:

tfde = tfp.experimental.distributions

partial = tfde.JointDistributionPinned(jd, obs=data)
target_log_prob_fn = partial.unnormalized_log_prob

Or, even more concisely, partial = jd.experimental_pin(obs=data).

This is nice, but it wasn't too hard to write out the target_log_prob_fn function explicitly.

Now, let's consider that for many inference and optimization methods, we may want to use a smooth change of variables to perform inference in the unconstrained space of real numbers. In some cases this transformation can be parameter-dependent. For example, if we want to unconstrain the support of tfp.distributions.Uniform(-3., 2.) to the real line, we might use tfp.bijectors.Sigmoid(low=-3., high=2.). In support of such use cases, most distributions (including the JointDistribution* classes) provide an experimental_default_event_space_bijector() method.
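For example, in this small sketch of the scalar case just mentioned (the values in comments are approximate):

import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Uniform(-3., 2.)
bij = dist.experimental_default_event_space_bijector()
# `bij` acts like `tfp.bijectors.Sigmoid(low=-3., high=2.)`: it maps
# the whole real line onto the support (-3, 2).
bij.forward(0.)    # ==> -0.5, the midpoint of the support.
bij.forward(-10.)  # ==> close to -3.
bij.forward(10.)   # ==> close to 2.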

When these transformations depend on ancestral parts of a joint distribution, and some of those parts may be pinned, it is helpful to have a utility class that bridges the gap and provides the multi-part bijective transform. This is the raison d'être of this class.

The model below is somewhat contrived, but demonstrates the use-case.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distributions

n = 75
dim = 3
joint = tfd.JointDistributionNamed(dict(
  upper = tfd.Uniform(.4, 1.5),
  concentration = tfd.Gamma(1., .5),
  corr = lambda concentration: tfd.CholeskyLKJ(
      dim, concentration=concentration),
  stddev = lambda upper: tfd.Sample(tfd.Uniform(.2, upper), dim),
  obs = lambda corr, stddev: tfd.Sample(
      tfd.MultivariateNormalTriL(
          loc=tf.zeros([dim]), scale_tril=corr * stddev[..., tf.newaxis]),
      n)
))
fixed_upper = 1.3
data = joint.sample(upper=fixed_upper)['obs']

pinned = tfde.JointDistributionPinned(joint, upper=fixed_upper, obs=data)
bij = pinned.experimental_default_event_space_bijector()
pulled_back_shape = bij.inverse_event_shape(pinned.event_shape)

# Fit an ensemble using SGD.
batch = 16
uniform_init = tf.nest.map_structure(
    lambda s: tf.random.uniform(tf.concat([[batch], s], axis=0), -2., 2.),
    pulled_back_shape)
vars = tf.nest.map_structure(tf.Variable, uniform_init)

opt = tf.optimizers.Adam(.01)

@tf.function(autograph=False)
def one_step():
  with tf.GradientTape() as tape:
    # Maximize the log-density by minimizing its negation.
    loss = -pinned.unnormalized_log_prob(bij.forward(vars))
  gradients = tape.gradient(loss, vars)
  opt.apply_gradients(zip(gradients.values(), vars.values()))

for _ in range(100):
  one_step()

# Alternatively, sample using MCMC (currently aspirational):
initial_state = bij.forward(uniform_init)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=pinned.unnormalized_log_prob,
    step_size=.5, num_leapfrog_steps=4)
# **This line is currently aspirational**, to demonstrate the use-case.
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bij)
tfp.mcmc.sample_chain(10, kernel=kernel, current_state=initial_state)

Args

distribution A tfp.distributions.JointDistribution.
*pins A single object like the value argument that may be passed into JointDistribution.sample (some parts may be None), or a sequence of objects similar to the sequence that might be passed to JointDistribution.log_prob, except that some parts may be None (log_prob would require all parts be specified). More precisely, the user may pass (A) a single argument specifying pins for one or more of the parts of the underlying distribution, either by name (i.e. a dict or namedtuple) or by sequence ordering (a tuple or list), or (B) a sequence of arguments that align with the model of the underlying distribution (which must be ordered). It is an error to pin parts by sequence ordering when the model itself is unordered, e.g. a tfp.distributions.JointDistributionNamed constructed with a dict model (collections.OrderedDict is allowed). See the sketch below.
name Python str name for this distribution. If None, defaults to 'Pinned{distribution.name}'. Default value: None.
**named_pins Named elements to pin. The names given must align with the part names defined by distribution._flat_resolve_names(), i.e. either the explicitly named parts of tfp.distributions.JointDistributionNamed or the name parameters passed to distributions constructed by the model given to JointDistribution*.
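As a minimal sketch of these rules (the two-part model is illustrative), a dict-model JointDistributionNamed is unordered, so pins must be given by name rather than by position:

import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions

# An unordered (dict-model) joint distribution:
jd = tfd.JointDistributionNamed(dict(
    loc=tfd.Normal(0., 1.),
    obs=lambda loc: tfd.Normal(loc, 1.),
))

# Pinning by name, or by a name-keyed structure, is permitted:
tfde.JointDistributionPinned(jd, obs=1.5)
tfde.JointDistributionPinned(jd, dict(obs=1.5))

# Pinning by sequence ordering is an error here, since a `dict` model
# has no canonical part ordering:
# tfde.JointDistributionPinned(jd, [None, 1.5])  # Raises.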

Attributes

batch_shape Statically resolvable batch shapes of unpinned parts.
distribution The underlying distribution being partially pinned.
dtype DType of unpinned parts.
event_shape Statically resolvable event shapes of unpinned parts.
name Name of this pinned distribution.
pins Dictionary of pins resolved to names.
use_vectorized_map Whether the underlying distribution relies on automatic vectorization.
validate_args Whether the underlying distribution validates arguments at runtime.

Methods

batch_shape_tensor

Dynamic/graph Tensor batch shapes of unpinned parts.

event_shape_tensor

Dynamic/graph Tensor event shapes of unpinned parts.

experimental_default_event_space_bijector

A bijector to pull back unpinned values to unconstrained reals.
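A minimal usage sketch (the model and part names below are illustrative, not part of the API):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

jd = tfd.JointDistributionNamed(dict(
    scale=tfd.Gamma(1., 1.),
    obs=lambda scale: tfd.Normal(0., scale),
))
pinned = jd.experimental_pin(obs=0.7)

bij = pinned.experimental_default_event_space_bijector()
# Pull an unconstrained real back to the positive support of `scale`;
# the result is a valid input for `unnormalized_log_prob`.
constrained = bij.forward({'scale': tf.constant(-1.)})
lp = pinned.unnormalized_log_prob(constrained)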

experimental_pin

Logical equivalent of JointDistribution.experimental_pin.

For example:

@tfd.JointDistributionCoroutine
def model():
  x = yield tfd.Normal(0., 1., name='x')
  y = yield tfd.Normal(0., 1., name='y')
  yield tfd.Normal(0., 1., name='z')

model.experimental_pin(z=1.).experimental_pin(y=.5).event_shape
# ==> StructTuple(x=[])

Args
*args Positional arguments: a value structure or component values.
**kwargs Keyword arguments: a value structure or component values. May also include name, specifying a Python string name for ops generated by this method.

Returns
pinned a tfp.experimental.distributions.JointDistributionPinned with the given values pinned in addition to those pins already specified on self.

log_weight

Computes the log relative weight of the given sample.

This function computes the log-probability of the pinned parts at the given location, ignoring the probability of the unpinned parts.
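For instance (a minimal sketch reusing the model from the example below), the log-weight at a sample equals the pinned part's log-probability under its conditional distribution:

import tensorflow_probability as tfp
tfd = tfp.distributions

# z ~ N(0, 1); y ~ N(0, 1); x | y, z ~ N(y + z, 1). Pin x = 2.
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x'),
])
pinned = jd.experimental_pin(x=2.)

z, y = pinned.sample_unpinned()
lw = pinned.log_weight(z, y)
# `lw` equals tfd.Normal(y + z, 1.).log_prob(2.); the unpinned parts
# z and y contribute nothing.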

The methods of JointDistributionPinned (unnormalized_log_prob, sample_and_log_weight, etc.) can be called by passing a single structure of tensors, by passing a sequence of tensor arguments, or by using named args for each part. For example:

import collections
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distributions

# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)

# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)

pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()

# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.log_weight(sample) ==
        pinned.log_weight(z, y) ==
        pinned.log_weight(z=z, y=y) ==
        pinned.log_weight(PartialZY(z=z, y=y)))

# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.log_weight(*sample)
# and similarly, if `sample` is a map, one can use `**` expansion:
pinned.log_weight(**sample)

Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided---either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function---the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.

In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:

  • StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
  • collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
  • dict for JointDistributionNamed with dict model type.
  • tuple or list for JointDistributionSequential.

pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype  # ==> [tf.float32]
pinned.log_weight([4.])
# ==> Tensor with shape `[]`.
log_wt = pinned.log_weight(4.)
# ==> Tensor with shape `[]`.

Notice that in the first call, [4.] is interpreted as a list of one scalar while in the second call the input is a scalar. Hence both inputs result in identical scalar outputs. If we wanted to pass an explicit vector to the Exponential component---creating a vector-shaped batch of log_weights---we could instead write pinned.log_weight(np.array([4.])).

Args
*args Positional arguments: a value structure or component values (see above).
**kwargs Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method.

Returns
log_weights log-weight of the given point, i.e. the log-probability of the pinned parts at that point.

sample_and_log_weight

Draws the unpinned parts and their log-weights via ancestral sampling.

Since this object represents an unnormalized density, we are unable to directly sample the distribution. However, we can evaluate the relative density of different samples. This function returns the relative log-weight alongside the sample. This log-weight is the log-probability of the pinned parts at the sampled location (it differs from unnormalized_log_prob by the log-probability of the unpinned parts).
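One use of this, sketched below with an illustrative conjugate model, is self-normalized importance sampling: draw the unpinned parts from the prior and weight them by the pinned-part likelihood. The estimator here is an illustration built on this API, not part of it:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Conjugate model: loc ~ N(0, 1); obs | loc ~ N(loc, 1). Pin obs = 1.5.
jd = tfd.JointDistributionNamed(dict(
    loc=tfd.Normal(0., 1.),
    obs=lambda loc: tfd.Normal(loc, 1.),
))
pinned = jd.experimental_pin(obs=1.5)

samples, log_weights = pinned.sample_and_log_weight(10000)
# Self-normalized importance-sampling estimate of E[loc | obs = 1.5]:
weights = tf.nn.softmax(log_weights)
posterior_mean = tf.reduce_sum(weights * samples['loc'])
# For this conjugate model the exact posterior mean is 1.5 / 2 = 0.75.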

Args
sample_shape Shape prefix to use when sampling.
seed Optional seed for reproducible sampling.

Returns
samples unpinned parts drawn from the pinned distribution.
log_weights log-weight of the sample. (Log-probability of the pinned parts at the sampled location.)

sample_unpinned

Draws the unpinned parts via ancestral sampling.

Conceptually, this is comparable to calling underlying.sample(value=pins), then stripping away the pinned parts.
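A minimal sketch of that behavior (the model is illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

jd = tfd.JointDistributionNamed(dict(
    loc=tfd.Normal(0., 1.),
    obs=lambda loc: tfd.Normal(loc, 1.),
))
pinned = jd.experimental_pin(obs=1.5)

draw = pinned.sample_unpinned(5)
# ==> {'loc': Tensor of shape [5]}; the pinned `obs` was fixed during
#     ancestral sampling and stripped from the output.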

Args
sample_shape Shape prefix to use when sampling.
seed Optional seed for reproducible sampling.

Returns
samples unpinned parts sampled from the underlying distribution.

unnormalized_log_prob

Computes the unnormalized log-probability.

The methods of JointDistributionPinned (unnormalized_log_prob, sample_and_log_weight, etc.) can be called by passing a single structure of tensors, by passing a sequence of tensor arguments, or by using named args for each part. For example:

import collections
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distributions

# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)

# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)

pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()

# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.unnormalized_log_prob(sample) ==
        pinned.unnormalized_log_prob(z, y) ==
        pinned.unnormalized_log_prob(z=z, y=y) ==
        pinned.unnormalized_log_prob(PartialZY(z=z, y=y)))

# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.unnormalized_log_prob(*sample)
# and similarly, if `sample` is a map, one can use `**` expansion:
pinned.unnormalized_log_prob(**sample)

Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided---either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function---the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.

In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:

  • StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
  • collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
  • dict for JointDistributionNamed with dict model type.
  • tuple or list for JointDistributionSequential.

pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype  # ==> [tf.float32]
pinned.unnormalized_log_prob([4.])
# ==> Tensor with shape `[]`.
lp = pinned.unnormalized_log_prob(4.)
# ==> Tensor with shape `[]`.

Notice that in the first call, [4.] is interpreted as a list of one scalar while in the second call the input is a scalar. Hence both inputs result in identical scalar outputs. If we wanted to pass an explicit vector to the Exponential component---creating a vector-shaped batch of unnormalized_log_probs---we could instead write pinned.unnormalized_log_prob(np.array([4.])).

Args
*args Positional arguments: a value structure or component values (see above).
**kwargs Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method.

Returns
unnormalized_log_prob The joint log-probability of the given *args or **kwargs together with the pinned parts. It is unnormalized with respect to the unpinned parts.

unnormalized_log_prob_parts

Computes the unnormalized log-probability of each part.

The methods of JointDistributionPinned (unnormalized_log_prob, sample_and_log_weight, etc.) can be called by passing a single structure of tensors, by passing a sequence of tensor arguments, or by using named args for each part. For example:

import collections
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distributions

# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)

# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)

pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()

# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.unnormalized_log_prob_parts(sample) ==
        pinned.unnormalized_log_prob_parts(z, y) ==
        pinned.unnormalized_log_prob_parts(z=z, y=y) ==
        pinned.unnormalized_log_prob_parts(PartialZY(z=z, y=y)))

# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.unnormalized_log_prob_parts(*sample)
# and similarly, if `sample` is a map, one can use `**` expansion:
pinned.unnormalized_log_prob_parts(**sample)

Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided---either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function---the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.

In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:

  • StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
  • collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
  • dict for JointDistributionNamed with dict model type.
  • tuple or list for JointDistributionSequential.

pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype  # ==> [tf.float32]
pinned.unnormalized_log_prob_parts([4.])
# ==> (pinned, unpinned) structures of Tensors, each with shape `[]`.
lp_parts = pinned.unnormalized_log_prob_parts(4.)
# ==> (pinned, unpinned) structures of Tensors, each with shape `[]`.

Notice that in the first call, [4.] is interpreted as a list of one scalar while in the second call the input is a scalar. Hence both inputs result in identical outputs. If we wanted to pass an explicit vector to the Exponential component---creating a vector-shaped batch of per-part log-probabilities---we could instead write pinned.unnormalized_log_prob_parts(np.array([4.])).

Args
*args Positional arguments: a value structure or component values (see above).
**kwargs Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method.

Returns
pinned partial log-prob of each pinned part.
unpinned partial log-prob of each unpinned part.
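As a sketch of how the two returned structures relate to the other methods (continuing the pinned example above, and assuming the result unpacks as the (pinned, unpinned) pair documented here):

import tensorflow as tf

pinned_parts, unpinned_parts = pinned.unnormalized_log_prob_parts(4.)

# Summing every part recovers the unnormalized joint log-probability:
total = (sum(tf.nest.flatten(pinned_parts)) +
         sum(tf.nest.flatten(unpinned_parts)))
# total == pinned.unnormalized_log_prob(4.)

# Summing only the pinned parts recovers the log-weight:
# sum(tf.nest.flatten(pinned_parts)) == pinned.log_weight(4.)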