

Wrapper to defer initialization of a tf.Module instance.

DeferredModule is a general-purpose mechanism for creating objects that are 'tape safe', meaning that computation occurs only when an instance method is called, not at construction. This ensures that method calls inside a tf.GradientTape context will produce gradients with respect to any underlying tf.Variables.


TFP's built-in Distributions and Bijectors are tape-safe by contract, but this does not extend to cases where computation is required to construct an object's parameters prior to initialization. For example, suppose we want to construct a Gamma distribution with a given mean and variance. In a naive implementation, we would convert these to the Gamma's native concentration and rate parameters when the distribution is constructed. Any future method calls would produce gradients with respect to concentration and rate, but not with respect to the underlying mean and variance:

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
dist = tfd.Gamma(concentration=mean**2 / variance,
                 rate=mean / variance)

with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are `[None, None]` !! :-(

To preserve the gradients, we can defer the parameter transformation using DeferredModule. The resulting object behaves just like a tfd.Gamma instance; however, instead of running the Gamma constructor just once, it internally applies the parameter transformation and constructs a new, temporary tfd.Gamma instance on every method invocation. This ensures that all operations needed to compute a method's return value from the underlying variables are performed each time the method is invoked, so a surrounding GradientTape context can trace the full computation.

def gamma_from_mean_and_variance(mean, variance, **kwargs):
  rate = mean / variance
  return tfd.Gamma(concentration=mean * rate, rate=rate, **kwargs)

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
deferred_dist = tfp.experimental.util.DeferredModule(
  build_fn=gamma_from_mean_and_variance,
  mean=mean,  # May be passed by position or by name.
  variance=variance)

with tf.GradientTape() as tape:
  lp = deferred_dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined!

Note that we could have achieved a similar effect by using tfp.util.DeferredTensor to individually defer the concentration and rate parameters. However, this would have been significantly more verbose, and would not share any computation between the two parameter transformations. In general, DeferredTensor is often idiomatic for simple transformations of a single value, while DeferredModule may be preferred for transformations that operate on multiple values and/or contain multiple steps.
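The core of the deferral trick can be sketched in plain Python (an illustration only; TinyDeferred and the toy Gamma below are hypothetical names, not TFP's actual implementation): attribute access rebuilds the wrapped object from build_fn on every lookup, so any computation inside build_fn is re-executed on each method call.

```python
class TinyDeferred:
    """Sketch of DeferredModule's deferral mechanism (illustration only)."""
    def __init__(self, build_fn, **kwargs):
        self._build_fn = build_fn
        self._kwargs = kwargs

    def __getattr__(self, name):
        # Called only for attributes not found on the wrapper itself:
        # rebuild the underlying object and forward the lookup to it.
        return getattr(self._build_fn(**self._kwargs), name)

class Gamma:
    """Toy distribution exposing only its mean; stands in for tfd.Gamma."""
    def __init__(self, concentration, rate):
        self.concentration = concentration
        self.rate = rate

    def mean(self):
        return self.concentration / self.rate

def gamma_from_mean_and_variance(mean, variance):
    rate = mean / variance
    return Gamma(concentration=mean * rate, rate=rate)

deferred = TinyDeferred(gamma_from_mean_and_variance, mean=3.2, variance=9.1)
deferred.mean()  # Rebuilds the toy Gamma and re-runs the transformation.
```

Because `__getattr__` re-invokes build_fn on every access, updating the stored parameters changes what subsequent method calls see, mirroring how DeferredModule re-reads its tf.Variables.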


Objects derived from a DeferredModule are no longer deferred, so they will not preserve gradients. For example, slicing into a deferred Distribution yields a new, concrete Distribution instance:

def normal_from_log_scale(scaled_loc, log_scale):
  return tfd.Normal(loc=5 * scaled_loc, scale=tf.exp(log_scale))

dist = tfp.experimental.util.DeferredModule(
  build_fn=normal_from_log_scale,
  scaled_loc=tf.Variable([1., 2., 3.]),
  log_scale=tf.Variable([1., 1., 1.]))
dist.batch_shape  # ==> [3]
len(dist.trainable_variables)  # ==> 2

sliced = dist[:2]  # Instantiates a new, non-deferred Distribution.
sliced.batch_shape  # ==> [2]
len(sliced.trainable_variables)  # ==> 0 (!)

# If needed, we could defer the slice with another layer of wrapping.
deferred_slice = tfp.experimental.util.DeferredModule(
  build_fn=lambda d: d[:2],
  d=dist)
len(deferred_slice.trainable_variables)  # ==> 2

Args:

build_fn: Python callable specifying a deferred transformation of the provided arguments. This must have signature module = build_fn(*args, **kwargs). The return value module is an instance of tf.Module.
*args: Optional positional arguments to build_fn.
also_track: Optional instance or structure of instances of tf.Variable and/or tf.Module, containing any additional trainable variables that build_fn may access beyond the given args and kwargs. This ensures that such variables will be correctly tracked in self.trainable_variables. Default value: None.
**kwargs: Optional keyword arguments to build_fn.
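The role of also_track can be illustrated with a small, self-contained Python sketch (Var and collect_tracked are hypothetical stand-ins, not TFP or TensorFlow APIs): variables that build_fn reaches through its closure, rather than through args/kwargs, cannot be discovered by inspecting the arguments, so the caller must declare them explicitly.

```python
class Var:
    """Minimal stand-in for a trainable tf.Variable."""
    def __init__(self, value):
        self.value = value

def collect_tracked(arg_values, also_track=()):
    # A wrapper can only discover variables passed explicitly as
    # arguments; `also_track` lets the caller declare the rest.
    tracked = [v for v in arg_values if isinstance(v, Var)]
    tracked.extend(also_track)
    return tracked

hidden = Var(2.0)     # captured by build_fn's closure, never passed in
explicit = Var(3.0)   # passed as an ordinary argument

build_fn = lambda x: x.value * hidden.value  # uses `hidden` implicitly

# Without `also_track`, `hidden` is invisible to the wrapper:
collect_tracked([explicit])                       # ==> [explicit]
collect_tracked([explicit], also_track=[hidden])  # ==> [explicit, hidden]
```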

Attributes:

name: Returns the name of this module as passed or determined in the ctor.

name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]
list(b.submodules) == [c]
list(c.submodules) == []
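The recursive discovery shown above can be reproduced in a few lines of plain Python (a sketch only; Node and submodules are illustrative names, not tf.Module internals): walk attribute values breadth-first, deduplicating by identity so shared or cyclic references are visited only once.

```python
class Node:
    """Minimal stand-in for a tf.Module that owns attributes."""
    pass

def submodules(root):
    # Breadth-first walk over attribute values; `seen` deduplicates by
    # object identity so shared or cyclic references are visited once.
    seen, found, queue = {id(root)}, [], [root]
    while queue:
        module = queue.pop(0)
        for value in vars(module).values():
            if isinstance(value, Node) and id(value) not in seen:
                seen.add(id(value))
                found.append(value)
                queue.append(value)
    return found

a, b, c = Node(), Node(), Node()
a.b = b
b.c = c
submodules(a)  # ==> [b, c]
submodules(c)  # ==> []
```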

trainable_variables: Sequence of trainable variables owned by this module and its submodules.

variables: Sequence of variables owned by this module and its submodules.



Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
# ==> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
# ==> <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args:

method: The method to wrap.

Returns:

The original method, wrapped such that it enters the module's name scope.
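The decorator pattern itself can be sketched without TensorFlow (with_scope, scope_log, and Scoped are illustrative names, not TF APIs): wrap the method so that a context is entered before dispatch and exited afterwards, which is what with_name_scope does with the module's tf.name_scope.

```python
import functools

scope_log = []  # records enter/exit events in place of a real name scope

def with_scope(method):
    """Sketch of a with_name_scope-style method decorator."""
    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        scope_log.append(f'enter:{self.name}')  # stand-in for tf.name_scope
        try:
            return method(self, *args, **kwargs)
        finally:
            scope_log.append(f'exit:{self.name}')
    return wrapped

class Scoped:
    def __init__(self, name):
        self.name = name

    @with_scope
    def __call__(self, x):
        return x * 2

mod = Scoped('my_module')
mod(3)     # ==> 6
scope_log  # ==> ['enter:my_module', 'exit:my_module']
```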


Operator methods forwarded by the wrapper:

__abs__: Return the absolute value of the argument.

__add__: Same as a + b.

__and__: Same as a & b.

__bool__: bool(x) -> bool. Returns True when the argument x is true, False otherwise. The builtins True and False are the only two instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.

__contains__: Same as b in a (note reversed operands).

__eq__: Same as a == b.

__floordiv__: Same as a // b.

__ge__: Same as a >= b.

__getitem__: Same as a[b].

__gt__: Same as a > b.

__invert__: Same as ~a.

__iter__: iter(iterable) -> iterator, or iter(callable, sentinel) -> iterator. Get an iterator from an object. In the first form, the argument must supply its own iterator, or be a sequence. In the second form, the callable is called until it returns the sentinel.

__le__: Same as a <= b.

__len__: Return the number of items in a container.

__lshift__: Same as a << b.

__lt__: Same as a < b.

__matmul__: Same as a @ b.