Wrapper to defer initialization of a `tf.Module` instance.
```python
tfp.experimental.util.DeferredModule(
    build_fn, *args, also_track=None, **kwargs
)
```
`DeferredModule` is a general-purpose mechanism for creating objects that are 'tape safe', meaning that computation occurs only when an instance method is called, not at construction. This ensures that method calls inside a `tf.GradientTape` context will produce gradients to any underlying `tf.Variable`s.
Examples
TFP's built-in Distributions and Bijectors are tape-safe by contract, but this does not extend to cases where computation is required to construct an object's parameters prior to initialization. For example, suppose we want to construct a Gamma distribution with a given mean and variance. In a naive implementation, we would convert these to the Gamma's native `concentration` and `rate` parameters when the distribution is constructed. Any future method calls would produce gradients to `concentration` and `rate`, but not to the underlying mean and variance:
```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
dist = tfd.Gamma(concentration=mean**2 / variance,
                 rate=mean / variance)
with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are `[None, None]` !! :-(
```
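The gradients are `None` because the parameter arithmetic ran before the tape started recording. For comparison, a hypothetical workaround (ours, not part of the API shown here) is to rebuild the distribution inside the tape, at the cost of repeating the construction at every use site:

```python
# Hypothetical workaround: construct the distribution inside the tape
# so the parameter arithmetic is recorded. `DeferredModule` automates
# exactly this re-construction.
with tf.GradientTape() as tape:
  dist = tfd.Gamma(concentration=mean**2 / variance,
                   rate=mean / variance)
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined, but only for code written inside the tape.
```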
To preserve the gradients, we can defer the parameter transformation using `DeferredModule`. The resulting object behaves just like a `tfd.Gamma` instance; however, instead of running the Gamma constructor just once, it internally applies the parameter transformation and constructs a new, temporary instance of `tfd.Gamma` on every method invocation. This ensures that all operations needed to compute a method's return value from any underlying variables are performed every time the method is invoked. A surrounding `GradientTape` context will therefore be able to trace the full computation.
```python
def gamma_from_mean_and_variance(mean, variance, **kwargs):
  rate = mean / variance
  return tfd.Gamma(concentration=mean * rate, rate=rate, **kwargs)

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
deferred_dist = tfp.experimental.util.DeferredModule(
    build_fn=gamma_from_mean_and_variance,
    mean=mean,  # May be passed by position or by name.
    variance=variance)

with tf.GradientTape() as tape:
  lp = deferred_dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined!
```
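Because the deferred distribution tracks its underlying variables in `trainable_variables`, it can be optimized directly. A minimal sketch (the optimizer, learning rate, and step count here are illustrative assumptions, not part of the original example):

```python
# Illustrative: fit the deferred Gamma's mean and variance by
# maximizing the log-probability of an observed value.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)
for _ in range(100):
  with tf.GradientTape() as tape:
    loss = -deferred_dist.log_prob(5.0)
  grads = tape.gradient(loss, deferred_dist.trainable_variables)
  optimizer.apply_gradients(zip(grads, deferred_dist.trainable_variables))
```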
Note that we could have achieved a similar effect by using `tfp.util.DeferredTensor` to individually defer the `concentration` and `rate` parameters. However, this would have been significantly more verbose, and would not share any computation between the two parameter transformations. In general, `DeferredTensor` is often idiomatic for simple transformations of a single value, while `DeferredModule` may be preferred for transformations that operate on multiple values and/or contain multiple steps.
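For concreteness, the `DeferredTensor` version might look like the following sketch (our illustration, assuming `DeferredTensor`'s `also_track` argument to track the closed-over `variance`; it is not from the original docs):

```python
# Each parameter is deferred separately; the division by `variance`
# is repeated in both transformations rather than shared.
concentration = tfp.util.DeferredTensor(
    mean, lambda m: m**2 / variance, also_track=variance)
rate = tfp.util.DeferredTensor(
    mean, lambda m: m / variance, also_track=variance)
dist = tfd.Gamma(concentration=concentration, rate=rate)

with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])  # ==> also defined.
```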
Caveats
Objects derived from a `DeferredModule` are no longer deferred, so they will not preserve gradients. For example, slicing into a deferred Distribution yields a new, concrete Distribution instance:
```python
def normal_from_log_scale(scaled_loc, log_scale):
  return tfd.Normal(loc=5 * scaled_loc, scale=tf.exp(log_scale))

dist = tfp.experimental.util.DeferredModule(
    build_fn=normal_from_log_scale,
    scaled_loc=tf.Variable([1., 2., 3.]),
    log_scale=tf.Variable([1., 1., 1.]))

dist.batch_shape  # ==> [3]
len(dist.trainable_variables)  # ==> 2

sliced = dist[:2]  # Instantiates a new, non-deferred Distribution.
sliced.batch_shape  # ==> [2]
len(sliced.trainable_variables)  # ==> 0 (!)

# If needed, we could defer the slice with another layer of wrapping.
deferred_slice = tfp.experimental.util.DeferredModule(
    build_fn=lambda d: d[:2],
    d=dist)
len(deferred_slice.trainable_variables)  # ==> 2
```
Args | |
---|---|
`build_fn` | Python callable specifying a deferred transformation of the provided arguments. This must have signature `module = build_fn(*args, **kwargs)`. The return value `module` is an instance of `tf.Module`.
`*args` | Optional positional arguments to `build_fn`.
`also_track` | Optional instance or structure of instances of `tf.Variable` and/or `tf.Module`, containing any additional trainable variables that `build_fn` may access beyond the given `args` and `kwargs`. This ensures that such variables will be correctly tracked in `self.trainable_variables`. Default value: `None`.
`**kwargs` | Optional keyword arguments to `build_fn`.
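The `also_track` argument matters when `build_fn` closes over variables that are not passed as arguments. A minimal sketch (the `offset` variable and `normal_with_offset` function are hypothetical names of our own):

```python
# `offset` is captured by closure rather than passed to `build_fn`, so
# it must be listed in `also_track` to appear in `trainable_variables`.
offset = tf.Variable(0.5)

def normal_with_offset(loc):
  return tfd.Normal(loc=loc + offset, scale=1.)

dist = tfp.experimental.util.DeferredModule(
    build_fn=normal_with_offset,
    loc=tf.Variable(1.),
    also_track=offset)
len(dist.trainable_variables)  # ==> 2
```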
Attributes | |
---|---|
`name` | Returns the name of this module as passed or determined in the ctor.
`name_scope` | Returns a `tf.name_scope` instance for this class.
`non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules.
`submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).
`trainable_variables` | Sequence of trainable variables owned by this module and its submodules.
`variables` | Sequence of variables owned by this module and its submodules.
Methods
with_name_scope

```python
@classmethod
with_name_scope(
    method
)
```

Decorator to automatically enter the module name scope.
```python
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
```
Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names included the module name:
```python
mod = MyModule()
mod(tf.ones([1, 2]))
# ==> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array(..., dtype=float32)>
mod.w
# ==> <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
#      numpy=array(..., dtype=float32)>
```
Args | |
---|---|
`method` | The method to wrap.

Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. |
`__abs__`
`__abs__()`
Return the absolute value of the argument.
`__add__`
`__add__(b, /)`
Same as `a + b`.
`__and__`
`__and__(b, /)`
Same as `a & b`.
`__bool__`
`__bool__()`
`bool(x) -> bool`
Returns True when the argument x is true, False otherwise. The builtins True and False are the only two instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.
`__call__`
`__call__(*args, **kwargs)`
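`DeferredModule` forwards these special methods to a freshly built instance, so a deferred object can be used wherever the underlying module would be, including direct calls. A sketch we supply for illustration (assuming `tfb = tfp.bijectors` and a `tfb.Scale` bijector; it is not from the original docs):

```python
tfb = tfp.bijectors

log_scale = tf.Variable(0.)
deferred_bij = tfp.experimental.util.DeferredModule(
    build_fn=lambda s: tfb.Scale(tf.exp(s)),
    s=log_scale)

with tf.GradientTape() as tape:
  y = deferred_bij(2.0)  # `__call__` forwards to a freshly built Scale bijector.
grads = tape.gradient(y, [log_scale])
# ==> defined, since `tf.exp` runs inside the tape on each call.
```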
`__contains__`
`__contains__(b, /)`
Same as `b in a` (note reversed operands).
`__enter__`
`__enter__()`
`__eq__`
`__eq__(b, /)`
Same as `a == b`.
`__exit__`
`__exit__(exc_type, exc_value, traceback)`
`__floordiv__`
`__floordiv__(b, /)`
Same as `a // b`.
`__ge__`
`__ge__(b, /)`
Same as `a >= b`.
`__getitem__`
`__getitem__(b, /)`
Same as `a[b]`.
`__gt__`
`__gt__(b, /)`
Same as `a > b`.
`__invert__`
`__invert__()`
Same as `~a`.
`__iter__`
`__iter__()`
`iter(iterable) -> iterator`
`iter(callable, sentinel) -> iterator`
Get an iterator from an object. In the first form, the argument must supply its own iterator, or be a sequence. In the second form, the callable is called until it returns the sentinel.
`__le__`
`__le__(b, /)`
Same as `a <= b`.
`__len__`
`__len__()`
Return the number of items in a container.
`__lshift__`
`__lshift__(b, /)`
Same as `a << b`.
`__lt__`
`__lt__(b, /)`
Same as `a < b`.
`__matmul__`
`__matmul__(b, /)`
Same as `a @ b`.
`__mod__`
`__mod__(b, /)`
Same as `a % b`.
`__mul__`
`__mul__(b, /)`
Same as `a * b`.
`__ne__`
`__ne__(b, /)`
Same as `a != b`.
`__neg__`
`__neg__()`
Same as `-a`.
`__or__`
`__or__(b, /)`
Same as `a | b`.
`__pos__`
`__pos__()`
Same as `+a`.
`__pow__`
`__pow__(exp, mod=None)`
Equivalent to `base**exp` with 2 arguments or `base**exp % mod` with 3 arguments.
Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.
`__radd__`
`__radd__(b, /)`
Same as `a + b`.
`__rand__`
`__rand__(b, /)`
Same as `a & b`.
`__rfloordiv__`
`__rfloordiv__(b, /)`
Same as `a // b`.
`__rlshift__`
`__rlshift__(b, /)`
Same as `a << b`.
`__rmatmul__`
`__rmatmul__(b, /)`
Same as `a @ b`.
`__rmod__`
`__rmod__(b, /)`
Same as `a % b`.
`__rmul__`
`__rmul__(b, /)`
Same as `a * b`.
`__ror__`
`__ror__(b, /)`
Same as `a | b`.
`__rpow__`
`__rpow__(exp, mod=None)`
Equivalent to `base**exp` with 2 arguments or `base**exp % mod` with 3 arguments.
Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.
`__rrshift__`
`__rrshift__(b, /)`
Same as `a >> b`.
`__rshift__`
`__rshift__(b, /)`
Same as `a >> b`.
`__rsub__`
`__rsub__(b, /)`
Same as `a - b`.
`__rtruediv__`
`__rtruediv__(b, /)`
Same as `a / b`.
`__rxor__`
`__rxor__(b, /)`
Same as `a ^ b`.
`__sub__`
`__sub__(b, /)`
Same as `a - b`.
`__truediv__`
`__truediv__(b, /)`
Same as `a / b`.
`__xor__`
`__xor__(b, /)`
Same as `a ^ b`.