A wrapper class for JointDistribution which pins, e.g., the evidence.
tfp.experimental.distributions.JointDistributionPinned(
distribution, *pins, name=None, **named_pins
)
This object is experimental; the API may change without warning.
Think of this object as functools.partial for joint distributions. Sampling trims off pinned values (after specifying them as jd.sample(value=pins) to the underlying distribution). Log-density evaluates the joint probability of the given event and the pinned values.
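The following pure-Python sketch (no TFP) makes the functools.partial analogy concrete. It writes a hypothetical two-part model, y ~ Normal(0, 1) and x | y ~ Normal(y, 1), as an ordinary joint log-density function. Fixing the observed x with functools.partial leaves a function of the unpinned y alone, and that is the quantity log_prob of a pinned object evaluates. The names normal_logpdf and joint_log_prob are illustrative only.

```python
import functools
import math


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


def joint_log_prob(y, x):
    """Hypothetical joint model: y ~ Normal(0, 1), x | y ~ Normal(y, 1)."""
    return normal_logpdf(y, 0., 1.) + normal_logpdf(x, y, 1.)


# "Pin" the observed value x = 1.5; what remains is a function of y only.
pinned_log_prob = functools.partial(joint_log_prob, x=1.5)

# Both calls evaluate log p(x=1.5, y=0.3).
print(pinned_log_prob(0.3), joint_log_prob(0.3, 1.5))
```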
Mathematically speaking, the object represents a joint probability density, $p(x, y)$, where the $x$ are pinned and the $y$ are unpinned. Accordingly, it is also proportional to $p(y \mid x)$, up to a (generally) intractable normalizing constant $p(x)$, i.e. $p(x, y) = p(y \mid x)\, p(x)$.
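The identity $p(x, y) = p(y \mid x)\, p(x)$ can be checked numerically on a hypothetical conjugate model in which every term has a closed form: y ~ Normal(0, 1) and x | y ~ Normal(y, 1), with x pinned. Standard Gaussian conditioning gives p(y | x) = Normal(x / 2, sqrt(1/2)) and p(x) = Normal(0, sqrt(2)) for this model. The sketch below is plain Python, not TFP.

```python
import math


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


# Hypothetical model: y ~ Normal(0, 1), x | y ~ Normal(y, 1); x is pinned.
def log_joint(y, x):
    return normal_logpdf(y, 0., 1.) + normal_logpdf(x, y, 1.)


def log_posterior(y, x):
    # Conjugate result for this model: p(y | x) = Normal(x / 2, sqrt(1 / 2)).
    return normal_logpdf(y, x / 2., math.sqrt(0.5))


def log_evidence(x):
    # Marginal of the pinned part: p(x) = Normal(0, sqrt(2)).
    return normal_logpdf(x, 0., math.sqrt(2.))


x = 1.5
for y in (-1.0, 0.3, 2.0):
    # log p(x, y) == log p(y | x) + log p(x), up to float rounding.
    assert math.isclose(log_joint(y, x), log_posterior(y, x) + log_evidence(x),
                        rel_tol=0., abs_tol=1e-12)
```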
The measure on the unpinned values $y$ that this object represents is no longer a probability measure, in that the integral of exp(log_prob) over the space of $y$ is $p(x)$, which in general is not 1. As such, this is not a tfp.distributions.Distribution, and lacks a sample method. In its place, it provides sample_unpinned and sample_and_log_weight.
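The following plain-Python sketch (not TFP) shows numerically that the measure is not normalized. It uses the same hypothetical model as before (y ~ Normal(0, 1), x | y ~ Normal(y, 1), x pinned) and integrates exp(log_prob) over y with the trapezoidal rule. The total mass comes out as the evidence p(x), roughly 0.16 here, rather than 1.

```python
import math


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


# Hypothetical model: y ~ Normal(0, 1), x | y ~ Normal(y, 1); x is pinned.
x = 1.5


def pinned_log_prob(y):
    return normal_logpdf(y, 0., 1.) + normal_logpdf(x, y, 1.)


# Trapezoidal rule for the integral of exp(log_prob) over y on [-10, 10];
# the integrand is negligible outside this range.
lo, hi, n = -10., 10., 20000
h = (hi - lo) / n
ys = [lo + i * h for i in range(n + 1)]
vals = [math.exp(pinned_log_prob(y)) for y in ys]
total_mass = h * (sum(vals) - 0.5 * (vals[0] + vals[-1]))

# The mass equals the evidence p(x) = Normal(0, sqrt(2)).prob(x), not 1.
evidence = math.exp(normal_logpdf(x, 0., math.sqrt(2.)))
print(total_mass, evidence)
```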
A common use case in probabilistic inference is writing out a generative model to explain some observed data:
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
jd = tfd.JointDistributionNamed(dict(
    loc=tfd.Normal(0., 1.),
    scale=tfd.Gamma(1., 1.),
    obs=lambda loc, scale: tfd.Normal(loc, scale),
))
Later, when we want to infer 'typical' values of loc and scale conditioned on some given data, we will often write:
def target_log_prob_fn(loc, scale):
  return jd.log_prob(loc=loc, scale=scale, obs=data)
This class enables one to write instead:
partial = tfde.JointDistributionPinned(jd, obs=data)
target_log_prob_fn = partial.log_prob
Or, even more concisely: partial = jd.experimental_pin(obs=data).
This is nice, but it wasn't too hard to write out the target_log_prob_fn function explicitly.
Now, let's consider that for many inference and optimization methods, we may want to use a smooth change of variables to perform inference in the unconstrained space of real numbers. In some cases this transformation can be parameter-dependent. For example, if we want to unconstrain the support of tfp.distributions.Uniform(-3., 2.) to the real line, we might use tfp.bijectors.Sigmoid(low=-3., high=2.). In support of such use cases, most distributions (including the JointDistribution* classes) provide an experimental_default_event_space_bijector() method.
When these transformations may be dependent on ancestral parts of a joint distribution, and some of those parameters may be pinned, it is helpful to have a utility class to bridge the gap and provide the multi-part bijective transform. This is the "raison d'être" of this class.
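Here is a rough illustration, in plain Python rather than TFP, of why the bijector can depend on pinned values. It mirrors the stddev ~ Uniform(.2, upper) part of the model below. Once upper is pinned, the map from unconstrained reals into (low, high) is a scaled sigmoid whose upper endpoint is the pinned value. In TFP, tfp.bijectors.Sigmoid(low=..., high=...) plays this role; the helper names here are hypothetical.

```python
import math


def sigmoid_forward(u, low, high):
    """Map an unconstrained real u into the open interval (low, high)."""
    return low + (high - low) / (1. + math.exp(-u))


def sigmoid_inverse(v, low, high):
    """Map v in (low, high) back to the real line (the logit of the rescaled v)."""
    p = (v - low) / (high - low)
    return math.log(p) - math.log1p(-p)


# Hypothetical pinned ancestor: upper = 1.3 fixes the support of stddev.
pinned_upper = 1.3
low, high = .2, pinned_upper

for u in (-5., -0.5, 0., 2., 5.):
    v = sigmoid_forward(u, low, high)
    assert low < v < high                                  # inside the support
    assert abs(sigmoid_inverse(v, low, high) - u) < 1e-9   # round-trips
```

If upper were not pinned, the bounds would themselves be random and the transform would have to be built per sample. Pinning fixes them, so the combined multi-part bijector can be constructed once.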
The model below is somewhat contrived, but demonstrates the use-case.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
n = 75
dim = 3
joint = tfd.JointDistributionNamed(dict(
    upper=tfd.Uniform(.4, 1.5),
    concentration=tfd.Gamma(1., .5),
    corr=lambda concentration: tfd.CholeskyLKJ(
        dim, concentration=concentration),
    stddev=lambda upper: tfd.Sample(tfd.Uniform(.2, upper), dim),
    obs=lambda corr, stddev: tfd.Sample(
        tfd.MultivariateNormalTriL(
            loc=tf.zeros([dim]), scale_tril=corr * stddev[..., tf.newaxis]),
        n)
))
fixed_upper = 1.3
data = joint.sample(upper=fixed_upper)['obs']
pinned = tfde.JointDistributionPinned(joint, upper=fixed_upper, obs=data)
bij = pinned.experimental_default_event_space_bijector()
pulled_back_shape = bij.inverse_event_shape(pinned.event_shape)
# Fit an ensemble using SGD (here, Adam) by maximizing the pinned log-density.
batch = 16
uniform_init = tf.nest.map_structure(
    lambda s: tf.random.uniform(tf.concat([[batch], s], axis=0), -2., 2.),
    pulled_back_shape)
params = tf.nest.map_structure(tf.Variable, uniform_init)
opt = tf.optimizers.Adam(.01)
@tf.function(autograph=False)
def one_step():
  with tf.GradientTape() as tape:
    # Optimizers minimize, so use the negative log-density summed over the batch.
    loss = -tf.reduce_sum(pinned.log_prob(bij.forward(params)))
  gradients = tape.gradient(loss, params)
  opt.apply_gradients(zip(gradients.values(), params.values()))
for _ in range(100):
  one_step()
# Alternatively, sample using MCMC (currently aspirational):
initial_state = bij.forward(uniform_init)
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=pinned.log_prob,
    step_size=.5, num_leapfrog_steps=4)
# **This line is currently aspirational**, to demonstrate the use-case.
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bij)
tfp.mcmc.sample_chain(10, kernel=kernel, current_state=initial_state)
Args | |
---|---|
distribution | A tfp.distributions.JointDistribution. |
*pins | A single object like the value argument that may be passed into JointDistribution.sample (some parts may be None), or a sequence of objects like those that might be passed to JointDistribution.log_prob, with the difference that some parts may be None (log_prob would require all parts to be specified). More precisely, the user may pass (A) a single argument specifying pins of one or more of the parts of the underlying distribution either by name (i.e. a dict or namedtuple) or by sequence ordering (tuple or list), or (B) a sequence of arguments which align with the model of the underlying distribution (which must be ordered). It is an error to use an unordered sequence of pins with an unordered model, e.g. a tfp.distributions.JointDistributionNamed constructed with a dict model (collections.OrderedDict is allowed). |
name | Python str name for this distribution. If None, defaults to 'Pinned{distribution.name}'. Default value: None. |
**named_pins | Named elements to pin. The names given must align with the part names defined by distribution._flat_resolve_names(), i.e. either the explicitly named parts of tfp.distributions.JointDistributionNamed or the name parameters passed to distributions constructed by the model given to JointDistribution*. |
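The sketch below shows, in pure Python, how the different *pins and **named_pins spellings can all normalize to the same dict of part name to pinned value. resolve_pins and its rules are assumptions for illustration, not the TFP implementation. The part_names tuple stands in for distribution._flat_resolve_names(), and the ordering check on unordered models is omitted.

```python
import collections


def resolve_pins(part_names, *pins, **named_pins):
    """Hypothetical helper: normalize pin specifications to {name: value}.

    part_names: ordered part names of the underlying model.
    Entries whose value is None are treated as "not pinned" and dropped.
    """
    if pins and named_pins:
        raise ValueError('Pass pins positionally or by name, not both.')
    if len(pins) == 1 and isinstance(pins[0], (dict, list, tuple)):
        pins = pins[0]                          # a single structure of pins
        if hasattr(pins, '_asdict'):            # namedtuple: by field name
            pins = pins._asdict()
    if isinstance(pins, dict):                  # dict / OrderedDict: by name
        named_pins = dict(pins)
    elif pins:                                  # sequence: by position
        if len(pins) > len(part_names):
            raise ValueError('More pins than parts.')
        named_pins = dict(zip(part_names, pins))
    unknown = set(named_pins) - set(part_names)
    if unknown:
        raise ValueError('Unknown part names: {}'.format(sorted(unknown)))
    return {k: v for k, v in named_pins.items() if v is not None}


names = ('z', 'y', 'x')  # stands in for jd._flat_resolve_names()
PartialX = collections.namedtuple('PartialX', 'x')
assert (resolve_pins(names, x=2.) ==
        resolve_pins(names, x=2., z=None) ==
        resolve_pins(names, dict(x=2., y=None)) ==
        resolve_pins(names, PartialX(x=2.)) ==
        resolve_pins(names, None, None, 2.) ==
        resolve_pins(names, [None, None, 2.]) ==
        {'x': 2.})
```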
Methods
batch_shape_tensor
batch_shape_tensor()
event_shape_tensor
event_shape_tensor()
Dynamic/graph Tensor event shapes of unpinned parts.
experimental_default_event_space_bijector
experimental_default_event_space_bijector(
*args, **kwargs
)
A bijector to pull back unpinned values to unconstrained reals.
experimental_pin
experimental_pin(
*args, **kwargs
)
Logical equivalent of JointDistribution.experimental_pin.
For example:
@tfd.JointDistributionCoroutine
def model():
  x = yield tfd.Normal(0., 1., name='x')
  y = yield tfd.Normal(0., 1., name='y')
  yield tfd.Normal(0., 1., name='z')
model.experimental_pin(z=1.).experimental_pin(y=.5).event_shape
# => StructTuple(x=[])
Args | |
---|---|
*args | Positional arguments: a value structure or component values. |
**kwargs | Keyword arguments: a value structure or component values. May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
pinned | A tfp.experimental.distributions.JointDistributionPinned with the given values pinned in addition to those pins already specified on self. |
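The pure-Python sketch below imitates the chaining in the Coroutine example above. Each hypothetical pin call merges new pins into the existing ones and reports which parts remain unpinned. This is not TFP code. Rejecting a part that is already pinned is a choice made in this sketch, not documented TFP behavior.

```python
def pin(part_names, existing_pins, **new_pins):
    """Hypothetical sketch: add new pins on top of existing ones.

    Returns (merged_pins, names_of_parts_still_unpinned).
    """
    unknown = set(new_pins) - set(part_names)
    if unknown:
        raise ValueError('Unknown part names: {}'.format(sorted(unknown)))
    already = set(new_pins) & set(existing_pins)
    if already:  # sketch-specific choice: no silent re-pinning
        raise ValueError('Already pinned: {}'.format(sorted(already)))
    merged = dict(existing_pins, **new_pins)
    unpinned = tuple(n for n in part_names if n not in merged)
    return merged, unpinned


names = ('x', 'y', 'z')               # parts of the Coroutine model above
pins, free = pin(names, {}, z=1.)     # like model.experimental_pin(z=1.)
pins, free = pin(names, pins, y=.5)   # like ... .experimental_pin(y=.5)
print(pins, free)  # only 'x' remains unpinned, matching StructTuple(x=[])
```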
experimental_sample_and_log_prob
experimental_sample_and_log_prob(
sample_shape=(), seed=None, name='sample_and_log_prob', **kwargs
)
log_prob
log_prob(
*args, **kwargs
)
Computes the log-density of this measure at the given point.
The measure on the space of unpinned values that is represented by this object is not a probability measure, so exp(log_prob) will in general not integrate to 1.
The value-taking methods of JointDistributionPinned (unnormalized_log_prob, log_weight, etc.) can be called by passing a single structure of tensors, a sequence of tensor arguments, or using named args for each part. For example:
import collections
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)
# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)
pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()
# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.log_prob(sample) ==
        pinned.log_prob(z, y) ==
        pinned.log_prob(z=z, y=y) ==
        pinned.log_prob(PartialZY(z=z, y=y)))
# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.log_prob(*sample)
# and similarly, if the sample is a map, one can use `**` expansion:
sample_map = dict(z=z, y=y)
pinned.log_prob(**sample_map)
Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided (either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function), the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.
In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:
- StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
- collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
- dict for JointDistributionNamed with dict model type.
- tuple or list for JointDistributionSequential.
pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype # => [tf.float32]
pinned.log_prob([4.])
# ==> Tensor with shape `[]`.
lp = pinned.log_prob(4.)
# ==> Tensor with shape `[]`.
Notice that in the first call, [4.] is interpreted as a list of one scalar, while in the second call the input is a scalar. Hence both inputs result in identical scalar outputs. If we wanted to pass an explicit vector to the Exponential component, creating a vector-shaped batch of log_probs, we could instead write pinned.log_prob(np.array([4.])).
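The calling conventions above amount to a small dispatch on the arguments. The hypothetical pure-Python helper below, to_parts, maps each documented style to a {part name: value} dict. It is not TFP's implementation, and the part name 's' for the Exponential component is made up for the example. It shows why [4.] (a one-element list, read as a structure) and 4. (a bare value) assign the same scalar to the single unpinned part.

```python
import collections


def to_parts(unpinned_names, *args, **kwargs):
    """Hypothetical sketch: map the calling styles above to {part_name: value}."""
    kwargs.pop('name', None)  # `name` names the ops; it is not a part value
    if kwargs:
        if args:
            raise ValueError('Pass parts positionally or by name, not both.')
        return dict(kwargs)                      # f(z=z, y=y) or f(**sample_map)
    if len(args) == 1:
        (value,) = args
        if hasattr(value, '_asdict'):            # namedtuple structure
            return dict(value._asdict())
        if isinstance(value, dict):              # dict / OrderedDict structure
            return dict(value)
        if isinstance(value, (list, tuple)):     # list/tuple structure, e.g. [4.]
            args = tuple(value)
    if len(args) != len(unpinned_names):
        raise ValueError('Expected {} part(s), got {}.'.format(
            len(unpinned_names), len(args)))
    return dict(zip(unpinned_names, args))       # f(z, y) or f(4.)


PartialZY = collections.namedtuple('PartialZY', 'z,y')
z, y = 0.1, -0.4
sample = [z, y]
assert (to_parts(('z', 'y'), sample) ==
        to_parts(('z', 'y'), z, y) ==
        to_parts(('z', 'y'), *sample) ==
        to_parts(('z', 'y'), z=z, y=y) ==
        to_parts(('z', 'y'), PartialZY(z=z, y=y)) ==
        {'z': 0.1, 'y': -0.4})
# One unpinned part (the Exponential above): [4.] and 4. resolve identically.
assert to_parts(('s',), [4.]) == to_parts(('s',), 4.) == {'s': 4.}
```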
Args | |
---|---|
*args | Positional arguments: a value structure or component values (see above). |
**kwargs | Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
log_prob | The joint log-probability of *xs or **kwargs with the pinned parts. It is unnormalized with respect to *xs or **kwargs. |
log_prob_parts
log_prob_parts(
*args, **kwargs
)
Computes the log-probability of each part.
The measure on the space of unpinned values that is represented by this object is not a probability measure, so the exponentiated sum of the parts produced by log_prob_parts will in general not integrate to 1.
The value-taking methods of JointDistributionPinned (unnormalized_log_prob, log_weight, etc.) can be called by passing a single structure of tensors, a sequence of tensor arguments, or using named args for each part. For example:
import collections
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)
# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)
pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()
# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.log_prob_parts(sample) ==
        pinned.log_prob_parts(z, y) ==
        pinned.log_prob_parts(z=z, y=y) ==
        pinned.log_prob_parts(PartialZY(z=z, y=y)))
# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.log_prob_parts(*sample)
# and similarly, if the sample is a map, one can use `**` expansion:
sample_map = dict(z=z, y=y)
pinned.log_prob_parts(**sample_map)
Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided (either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function), the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.
In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:
- StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
- collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
- dict for JointDistributionNamed with dict model type.
- tuple or list for JointDistributionSequential.
pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype  # => [tf.float32]
pinned.log_prob_parts([4.])
# ==> Per-part log-probs, each a Tensor with shape `[]`.
lp_parts = pinned.log_prob_parts(4.)
# ==> Per-part log-probs, each a Tensor with shape `[]`.
Notice that in the first call, [4.] is interpreted as a list of one scalar, while in the second call the input is a scalar. Hence both inputs result in identical per-part scalar outputs. If we wanted to pass an explicit vector to the Exponential component, creating a vector-shaped batch of per-part log-probs, we could instead write pinned.log_prob_parts(np.array([4.])).
Args | |
---|---|
*args | Positional arguments: a value structure or component values (see above). |
**kwargs | Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
pinned | Partial log-prob of each pinned part. |
unpinned | Partial log-prob of each unpinned part. |
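Take the three-part model from the examples above: z ~ Normal(0, 1), y ~ Normal(0, 1), x ~ Normal(y + z, 1), with x pinned to 2. For that model, the plain-Python sketch below shows how the pinned and unpinned parts relate to log_weight and log_prob. The dict with 'pinned' and 'unpinned' keys is an illustrative stand-in for the structure TFP returns, not its actual type.

```python
import math


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


def log_prob_parts(z, y, x=2.):
    """Hypothetical per-part log-probs for the model above with x pinned to 2."""
    return {
        'pinned': {'x': normal_logpdf(x, y + z, 1.)},
        'unpinned': {'z': normal_logpdf(z, 0., 1.),
                     'y': normal_logpdf(y, 0., 1.)},
    }


parts = log_prob_parts(z=0.3, y=-0.2)
# log_weight sums only the pinned parts ...
log_weight = sum(parts['pinned'].values())
# ... while log_prob sums every part, pinned and unpinned.
log_prob = log_weight + sum(parts['unpinned'].values())
```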
log_weight
log_weight(
*args, **kwargs
)
Computes the log relative weight of the given sample.
This function computes the log-probability of the pinned parts at the given location, ignoring the probability of the unpinned parts.
The value-taking methods of JointDistributionPinned (unnormalized_log_prob, log_weight, etc.) can be called by passing a single structure of tensors, a sequence of tensor arguments, or using named args for each part. For example:
import collections
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)
# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)
pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()
# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.log_weight(sample) ==
        pinned.log_weight(z, y) ==
        pinned.log_weight(z=z, y=y) ==
        pinned.log_weight(PartialZY(z=z, y=y)))
# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.log_weight(*sample)
# and similarly, if the sample is a map, one can use `**` expansion:
sample_map = dict(z=z, y=y)
pinned.log_weight(**sample_map)
Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided (either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function), the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.
In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:
- StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
- collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
- dict for JointDistributionNamed with dict model type.
- tuple or list for JointDistributionSequential.
pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype # => [tf.float32]
pinned.log_weight([4.])
# ==> Tensor with shape `[]`.
log_wt = pinned.log_weight(4.)
# ==> Tensor with shape `[]`.
Notice that in the first call, [4.] is interpreted as a list of one scalar, while in the second call the input is a scalar. Hence both inputs result in identical scalar outputs. If we wanted to pass an explicit vector to the Exponential component, creating a vector-shaped batch of log_weights, we could instead write pinned.log_weight(np.array([4.])).
Args | |
---|---|
*args | Positional arguments: a value structure or component values (see above). |
**kwargs | Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
log_weights | Log-weight of the given point, i.e. the log pinned evidence. |
sample_and_log_weight
sample_and_log_weight(
sample_shape=(), seed=None
)
Draws unnormalized samples and their log-weights with ancestral sampling.
This method provides the ancestral importance sampler for the non-probability measure on the unpinned space represented by this object. To wit, we draw a sample with the same distribution as sample_unpinned, and return it together with a weight given by the log-probability of the pinned parts.
Denoting the unpinned space by $Y$, the samples by $y$, their corresponding log-weights by $w$, and the measure defined by this object by $m$, the spec for this method is that, for any real-valued function $f : Y \to \mathbb{R}$,
$$\int_Y f(y)\, dm(y) = \mathbb{E}\left[f(y) \exp(w)\right].$$
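The identity above can be checked by Monte Carlo on a hypothetical conjugate model whose answers are known in closed form: y ~ Normal(0, 1) and x | y ~ Normal(y, 1), with x pinned. With f(y) = 1 the estimate should approach p(x) = Normal(0, sqrt(2)).prob(x). Dividing the f(y) = y estimate by that mass should approach the posterior mean x / 2. This is a plain-Python sketch using the random module, not TFP's sampler.

```python
import math
import random


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


def sample_and_log_weight(rng, x):
    """Hypothetical model: y ~ Normal(0, 1), x | y ~ Normal(y, 1), x pinned.

    y is drawn ancestrally (from its prior, ignoring the pinned x), and the
    log-weight is the log-probability of the pinned part at the draw.
    """
    y = rng.gauss(0., 1.)
    return y, normal_logpdf(x, y, 1.)


rng = random.Random(0)
x = 1.5
draws = [sample_and_log_weight(rng, x) for _ in range(100000)]
weights = [math.exp(w) for _, w in draws]

# f(y) = 1: E[exp(w)] estimates the total mass of m, which is p(x).
mass = sum(weights) / len(weights)
# f(y) = y, divided by the mass: estimates the posterior mean E[y | x] = x / 2.
posterior_mean = sum(y * wt for (y, _), wt in zip(draws, weights)) / sum(weights)
```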
Args | |
---|---|
sample_shape | Shape prefix to use when sampling. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
samples | Unpinned parts ancestrally sampled from the pinned distribution. |
log_weights | Log-weight of the sample (the log-probability of the pinned parts at the sampled location). |
sample_unpinned
sample_unpinned(
sample_shape=(), seed=None
)
Draws unnormalized samples using ancestral sampling.
This produces the same distribution on outputs as calling underlying.sample(value=pins), then stripping away the pinned parts.
In the probability literature, this operation corresponds to sampling from the do-calculus expression underlying(Unpinned | do(Pinned = pins)), where the assumed causal structure of underlying is given by its data dependence graph.
Note that this is not the same measure as is represented by the log_prob method of this class, which is why this method is not called sample.
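Consider a hypothetical chain a -> b -> c of unit-variance Gaussians with only b pinned to 3. It shows why sample_unpinned differs from conditioning. Under do(b = 3), the parent a keeps its prior Normal(0, 1), whereas conditioning on b = 3 would shift its mean to 1.5. The child c is centered at the pinned value. The plain-Python sketch below (not TFP) draws ancestrally in that way.

```python
import random


def sample_unpinned(rng, pins):
    """Hypothetical ancestral sampler for the chain a -> b -> c.

    Assumed model: a ~ Normal(0, 1), b ~ Normal(a, 1), c ~ Normal(b, 1).
    A pinned part is never drawn; its pinned value is simply fed to its
    children, i.e. sampling under do(Pinned = pins).
    """
    a = pins['a'] if 'a' in pins else rng.gauss(0., 1.)
    b = pins['b'] if 'b' in pins else rng.gauss(a, 1.)
    c = pins['c'] if 'c' in pins else rng.gauss(b, 1.)
    return {k: v for k, v in (('a', a), ('b', b), ('c', c)) if k not in pins}


rng = random.Random(1)
draws = [sample_unpinned(rng, {'b': 3.}) for _ in range(50000)]
mean_a = sum(d['a'] for d in draws) / len(draws)
mean_c = sum(d['c'] for d in draws) / len(draws)
# mean_a stays near the prior mean 0 (conditioning on b = 3 would give 1.5),
# while mean_c moves to about 3 because c is a child of the pinned b.
```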
Args | |
---|---|
sample_shape | Shape prefix to use when sampling. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
samples | Unpinned parts sampled from the underlying distribution. |
unnormalized_log_prob
unnormalized_log_prob(
*args, **kwargs
)
Computes the log-density of this measure at the given point.
The measure on the space of unpinned values that is represented by this object is not a probability measure, so exp(unnormalized_log_prob) will in general not integrate to 1.
This method currently computes the same values as log_prob, but may be modified in the future to omit normalization constants.
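The plain-Python sketch below shows what omitting normalization constants would mean on the hypothetical model y ~ Normal(0, 1), x | y ~ Normal(y, 1), with x pinned. Dropping terms that do not depend on the unpinned value y shifts the log-density by a constant (here -log(2π)). Differences between points, gradients, and MCMC acceptance ratios are therefore unchanged. Only terms that are constant in every unpinned value can be dropped. If a component's scale depended on an unpinned parent, its -log(scale) term would have to stay.

```python
import math


def normal_logpdf(v, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at v."""
    return (-0.5 * ((v - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2. * math.pi))


x = 1.5  # pinned value; hypothetical model y ~ N(0, 1), x | y ~ N(y, 1)


def log_prob(y):
    return normal_logpdf(y, 0., 1.) + normal_logpdf(x, y, 1.)


def unnormalized_log_prob(y):
    # Both scales are fixed at 1, so the dropped -log(scale) - log(2*pi)/2
    # terms do not depend on y and may be omitted.
    return -0.5 * y ** 2 - 0.5 * (x - y) ** 2


# The gap between the two is the same constant at every y.
offsets = [log_prob(y) - unnormalized_log_prob(y) for y in (-2., 0., 0.7, 3.)]
```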
The value-taking methods of JointDistributionPinned (unnormalized_log_prob, log_weight, etc.) can be called by passing a single structure of tensors, a sequence of tensor arguments, or using named args for each part. For example:
import collections
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)
# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)
pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()
# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.unnormalized_log_prob(sample) ==
        pinned.unnormalized_log_prob(z, y) ==
        pinned.unnormalized_log_prob(z=z, y=y) ==
        pinned.unnormalized_log_prob(PartialZY(z=z, y=y)))
# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.unnormalized_log_prob(*sample)
# and similarly, if the sample is a map, one can use `**` expansion:
sample_map = dict(z=z, y=y)
pinned.unnormalized_log_prob(**sample_map)
Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided (either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function), the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.
In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:
- StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
- collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
- dict for JointDistributionNamed with dict model type.
- tuple or list for JointDistributionSequential.
pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype # => [tf.float32]
pinned.unnormalized_log_prob([4.])
# ==> Tensor with shape `[]`.
lp = pinned.unnormalized_log_prob(4.)
# ==> Tensor with shape `[]`.
Notice that in the first call, [4.] is interpreted as a list of one scalar, while in the second call the input is a scalar. Hence both inputs result in identical scalar outputs. If we wanted to pass an explicit vector to the Exponential component, creating a vector-shaped batch of unnormalized_log_probs, we could instead write pinned.unnormalized_log_prob(np.array([4.])).
Args | |
---|---|
*args | Positional arguments: a value structure or component values (see above). |
**kwargs | Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
log_prob | The joint log-probability of *xs or **kwargs with the pinned parts. It is unnormalized with respect to *xs or **kwargs. |
unnormalized_log_prob_parts
unnormalized_log_prob_parts(
*args, **kwargs
)
Computes the log-probability of each part.
The measure on the space of unpinned values that is represented by this object is not a probability measure, so the exponentiated sum of the parts produced by unnormalized_log_prob_parts will in general not integrate to 1.
This method currently computes the same values as log_prob_parts, but may be modified in the future to omit normalization constants.
The value-taking methods of JointDistributionPinned (unnormalized_log_prob, log_weight, etc.) can be called by passing a single structure of tensors, a sequence of tensor arguments, or using named args for each part. For example:
import collections
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions
# Given the following joint distribution:
jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1., name='z'),
    tfd.Normal(0., 1., name='y'),
    lambda y, z: tfd.Normal(y + z, 1., name='x')
], validate_args=True)
# The following `__init__` styles are all permissible and produce
# `JointDistributionPinned` objects behaving identically.
PartialXY = collections.namedtuple('PartialXY', 'x,y')
PartialX = collections.namedtuple('PartialX', 'x')
OrderedDict = collections.OrderedDict
assert (tfde.JointDistributionPinned(jd, x=2.).pins ==
        tfde.JointDistributionPinned(jd, x=2., z=None).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, dict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, OrderedDict(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialXY(x=2., y=None)).pins ==
        tfde.JointDistributionPinned(jd, PartialX(x=2.)).pins ==
        tfde.JointDistributionPinned(jd, None, None, 2.).pins ==
        tfde.JointDistributionPinned(jd, [None, None, 2.]).pins)
# (Notice that the `pins` attribute is always resolved to a `dict`.)
pinned = tfde.JointDistributionPinned(jd, x=2.)
pinned.dtype
# ==> [tf.float32, tf.float32]
z, y = sample = pinned.sample_unpinned()
# The following calling styles are all permissible and produce exactly
# the same output.
PartialZY = collections.namedtuple('PartialZY', 'z,y')
assert (pinned.unnormalized_log_prob_parts(sample) ==
        pinned.unnormalized_log_prob_parts(z, y) ==
        pinned.unnormalized_log_prob_parts(z=z, y=y) ==
        pinned.unnormalized_log_prob_parts(PartialZY(z=z, y=y)))
# These calling possibilities also imply that one can use `*`
# expansion, if `sample` is a sequence:
pinned.unnormalized_log_prob_parts(*sample)
# and similarly, if the sample is a map, one can use `**` expansion:
sample_map = dict(z=z, y=y)
pinned.unnormalized_log_prob_parts(**sample_map)
Component distributions' names are resolved via jd._flat_resolve_names(), which is implemented by each JointDistribution subclass (see subclass documentation for details). Generally, for components where a name was provided (either explicitly as the name argument to a distribution or as a key in a dict-valued JointDistribution, or implicitly, e.g., by the argument name of a JointDistributionSequential distribution-making function), the provided name will be used. Otherwise the component will receive a dummy name; these may change without warning and should not be relied upon.
In general, return types of part-wise methods/properties are determined by those of the underlying JointDistribution's model type:
- StructTuple for JointDistributionCoroutine, and for JointDistributionNamed with namedtuple model type.
- collections.OrderedDict for JointDistributionNamed with OrderedDict model type.
- dict for JointDistributionNamed with dict model type.
- tuple or list for JointDistributionSequential.
pinned = tfde.JointDistributionPinned(
    tfd.JointDistributionSequential(
        [tfd.Exponential(1.), lambda s: tfd.Normal(0., s)]),
    None, 1.2)
pinned.dtype  # => [tf.float32]
pinned.unnormalized_log_prob_parts([4.])
# ==> Per-part log-probs, each a Tensor with shape `[]`.
lp_parts = pinned.unnormalized_log_prob_parts(4.)
# ==> Per-part log-probs, each a Tensor with shape `[]`.
Notice that in the first call, [4.] is interpreted as a list of one scalar, while in the second call the input is a scalar. Hence both inputs result in identical per-part scalar outputs. If we wanted to pass an explicit vector to the Exponential component, creating a vector-shaped batch of per-part log-probs, we could instead write pinned.unnormalized_log_prob_parts(np.array([4.])).
Args | |
---|---|
*args | Positional arguments: a value structure or component values (see above). |
**kwargs | Keyword arguments: a value structure or component values (see above). May also include name, specifying a Python string name for ops generated by this method. |
Returns | |
---|---|
pinned | Partial log-prob of each pinned part. |
unpinned | Partial log-prob of each unpinned part. |
unpin
unpin(
*pinned_part_names
)
Unpins selected parts, returning a new instance.
Args | |
---|---|
*pinned_part_names | One or more str names of parts to unpin. |
Returns | |
---|---|
jd | A joint distribution with the specified pins dropped. If all pins are dropped, the underlying joint distribution is returned. |
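The unpin contract can be sketched with a tiny immutable wrapper. PinnedSketch is a hypothetical stand-in written only for illustration; it is not the TFP class. Partial unpinning returns a new instance, and once no pins remain it hands back the underlying model. Rejecting names that are not currently pinned is a choice made in this sketch.

```python
class PinnedSketch:
    """Hypothetical stand-in for a pinned joint distribution (not TFP)."""

    def __init__(self, underlying, pins):
        self.underlying = underlying
        self.pins = dict(pins)  # copy, so instances never share pin state

    def unpin(self, *pinned_part_names):
        unknown = set(pinned_part_names) - set(self.pins)
        if unknown:  # sketch-specific choice: only pinned parts may be unpinned
            raise ValueError('Not pinned: {}'.format(sorted(unknown)))
        remaining = {k: v for k, v in self.pins.items()
                     if k not in pinned_part_names}
        if not remaining:
            return self.underlying  # every pin dropped: return the model itself
        return PinnedSketch(self.underlying, remaining)


model = object()  # placeholder for an underlying joint distribution
p = PinnedSketch(model, {'upper': 1.3, 'obs': [0.1, 0.2]})
partly = p.unpin('upper')   # still pinned on 'obs'
fully = partly.unpin('obs')  # no pins left, so this is `model`
```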