Wrapper to defer initialization of a `tf.Module` instance.

```python
tfp.experimental.util.DeferredModule(
    build_fn, *args, also_track=None, **kwargs
)
```
`DeferredModule` is a general-purpose mechanism for creating objects that are
'tape safe', meaning that computation occurs only when an instance
method is called, not at construction. This ensures that method calls made
inside a `tf.GradientTape` context will produce gradients to any underlying
`tf.Variable`s.
TFP's built-in Distributions and Bijectors are tape-safe by contract, but
this does not extend to cases where computation is required
to construct an object's parameters prior to initialization.
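For contrast, a minimal sketch of that contract: when a variable is passed
directly as a distribution parameter, gradients flow without any wrapping.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

rate = tf.Variable(2.)
dist = tfd.Exponential(rate=rate)  # Tape-safe: `rate` is used directly.
with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
tape.gradient(lp, rate)  # ==> defined (not `None`).
```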
For example, suppose we want to construct a Gamma
distribution with a given mean and variance. In a naive implementation,
we would convert these to the Gamma's native `concentration` and
`rate` parameters when the distribution is constructed. Any future method
calls would produce gradients to `concentration` and `rate`, but not to the
underlying mean and variance:
```python
mean, variance = tf.Variable(3.2), tf.Variable(9.1)
dist = tfd.Gamma(concentration=mean**2 / variance,
                 rate=mean / variance)
with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are `[None, None]` !! :-(
```
To preserve the gradients, we can defer the parameter transformation using
`DeferredModule`. The resulting object behaves just like a
`tfd.Gamma` instance; however, instead of running the `Gamma` constructor just
once, it internally applies the parameter transformation and constructs a
new, temporary instance of `tfd.Gamma` on every method invocation.
This ensures that all operations needed to compute a method's return value
from any underlying variables are performed every time the method is invoked.
A `tf.GradientTape` context will therefore be able to trace the full
computation:
```python
def gamma_from_mean_and_variance(mean, variance, **kwargs):
  rate = mean / variance
  return tfd.Gamma(concentration=mean * rate, rate=rate, **kwargs)

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
deferred_dist = tfp.experimental.util.DeferredModule(
    build_fn=gamma_from_mean_and_variance,
    mean=mean,  # May be passed by position or by name.
    variance=variance)

with tf.GradientTape() as tape:
  lp = deferred_dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined!
```
Note that we could have achieved a similar effect by using
`tfp.util.DeferredTensor` to defer each parameter individually (see the
sketch below). However, this would have been significantly more verbose, and
would not share any computation between the two parameter transformations.
`DeferredTensor` is often idiomatic for simple transformations of
a single value, while `DeferredModule` may be preferred for transformations
that operate on multiple values and/or contain multiple steps.
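As a minimal sketch (reusing the `mean` and `variance` variables from the
Gamma example above), the `DeferredTensor` version might look like the
following. Each parameter is deferred separately, so the division by
`variance` is computed twice:

```python
# Sketch only: each parameter gets its own deferred transformation.
# `variance` is not the pretransformed input of either transformation,
# so it must be tracked explicitly via `also_track`.
concentration = tfp.util.DeferredTensor(
    mean, lambda m: m**2 / variance, also_track=variance)
rate = tfp.util.DeferredTensor(
    mean, lambda m: m / variance, also_track=variance)
dist = tfd.Gamma(concentration=concentration, rate=rate)
```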
Objects derived from a `DeferredModule` are no longer deferred, so
they will not preserve gradients. For example, slicing into a deferred
`Distribution` yields a new, concrete `Distribution` instance:
```python
def normal_from_log_scale(scaled_loc, log_scale):
  return tfd.Normal(loc=5 * scaled_loc, scale=tf.exp(log_scale))

dist = tfp.experimental.util.DeferredModule(
    build_fn=normal_from_log_scale,
    scaled_loc=tf.Variable([1., 2., 3.]),
    log_scale=tf.Variable([1., 1., 1.]))

dist.batch_shape  # ==> [3]
len(dist.trainable_variables)  # ==> 2

slice = dist[:2]  # Instantiates a new, non-deferred Distribution.
slice.batch_shape  # ==> [2]
len(slice.trainable_variables)  # ==> 0 (!)

# If needed, we could defer the slice with another layer of wrapping.
deferred_slice = tfp.experimental.util.DeferredModule(
    build_fn=lambda d: d[:2],
    d=dist)
len(deferred_slice.trainable_variables)  # ==> 2
```
| Args | |
|---|---|
| `build_fn` | Python callable specifying a deferred transformation of the provided arguments. This must have signature `module = build_fn(*args, **kwargs)`, where the return value `module` is an instance of `tf.Module`. |
| `*args` | Optional positional arguments to `build_fn`. |
| `also_track` | Optional instance or structure of instances of `tf.Variable` and/or `tf.Module`, containing any additional objects that `build_fn` depends on beyond the given `args` and `kwargs`, so that their variables are also tracked. Default value: `None`. |
| `**kwargs` | Optional keyword arguments to `build_fn`. |
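For illustration, a hedged sketch of `also_track`: if `build_fn` closes over
a variable rather than receiving it as an argument, passing that variable via
`also_track` keeps it among the module's tracked variables:

```python
# Sketch only: `shift` is closed over by `build_fn` rather than passed in,
# so we list it in `also_track` to keep its variable tracked.
shift = tf.Variable(1.)
deferred = tfp.experimental.util.DeferredModule(
    lambda scale: tfd.Normal(loc=shift, scale=scale),  # build_fn
    tf.Variable(2.),  # Positional argument: `scale`.
    also_track=shift)
len(deferred.trainable_variables)  # ==> 2
```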
| Attributes | |
|---|---|
| `name` | Returns the name of this module as passed or determined in the ctor. |
| `non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules. |
| `submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
| `trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
| `variables` | Sequence of variables owned by this module and its submodules. |
with_name_scope( method )
Decorator to automatically enter the module name scope.
```python
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      # `x.shape[1]` (not `x.shape`) gives the input feature size.
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
```

Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names include the module name:

```python
mod = MyModule()
mod(tf.ones([1, 2]))
# ==> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
# ==> <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
#      numpy=..., dtype=float32)>
```
| Args | |
|---|---|
| `method` | The method to wrap. |

| Returns |
|---|
| The original method wrapped such that it enters the module's name scope. |
__abs__( )
Return the absolute value of the argument.
__add__( b, / )
Same as a + b.
__and__( b, / )
Same as a & b.
__bool__( )
bool(x) -> bool
Returns True when the argument x is true, False otherwise. The builtins True and False are the only two instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.
__call__( *args, **kwargs )
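Calls (like the other operators listed here) forward to a freshly built
instance of the underlying module. A brief sketch, assuming
`tfb = tfp.bijectors`:

```python
# Sketch: `__call__` forwards to the bijector built from the current
# value of `scale`.
scale = tf.Variable(2.)
deferred_bij = tfp.experimental.util.DeferredModule(
    build_fn=tfb.Scale, scale=scale)
deferred_bij(3.)  # ==> 6.0, the same as `tfb.Scale(scale)(3.)`.
```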
__contains__( b, / )
Same as b in a (note reversed operands).
__eq__( b, / )
Same as a == b.
__exit__( exc_type, exc_value, traceback )
__floordiv__( b, / )
Same as a // b.
__ge__( b, / )
Same as a >= b.
__getitem__( b, / )
Same as a[b].
__gt__( b, / )
Same as a > b.
__invert__( )
Same as ~a.
__iter__( )
iter(iterable) -> iterator
iter(callable, sentinel) -> iterator
Get an iterator from an object. In the first form, the argument must supply its own iterator, or be a sequence. In the second form, the callable is called until it returns the sentinel.
__le__( b, / )
Same as a <= b.
__len__( )
Return the number of items in a container.
__lshift__( b, / )
Same as a << b.
__lt__( b, / )
Same as a < b.
__matmul__( b, / )
Same as a @ b.