Wrapper to defer initialization of a tf.Module instance.

Inherits From: SpecialMethods

DeferredModule is a general-purpose mechanism for creating objects that are 'tape safe', meaning that computation occurs only when an instance method is called, not at construction. This ensures that method calls inside of a tf.GradientTape context will produce gradients to any underlying tf.Variables.


TFP's built-in Distributions and Bijectors are tape-safe by contract, but this does not extend to cases where computation is required to construct an object's parameters prior to initialization. For example, suppose we want to construct a Gamma distribution with a given mean and variance. In a naive implementation, we would convert these to the Gamma's native concentration and rate parameters when the distribution is constructed. Any future method calls would produce gradients to concentration and rate, but not to the underlying mean and variance:

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
dist = tfd.Gamma(concentration=mean**2 / variance,
                 rate=mean / variance)

with tf.GradientTape() as tape:
  lp = dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are `[None, None]` !! :-(

To preserve the gradients, we can defer the parameter transformation using DeferredModule. The resulting object behaves just like a tfd.Gamma instance; however, instead of running the Gamma constructor just once, it internally applies the parameter transformation and constructs a new, temporary instance of tfd.Gamma on every method invocation. This ensures that all operations needed to compute a method's return value from any underlying variables are performed every time the method is invoked. A surrounding GradientTape context will therefore be able to trace the full computation.

def gamma_params_from_mean_and_variance(mean, variance, **kwargs):
  rate = mean / variance
  return dict(concentration=mean * rate, rate=rate, **kwargs)

mean, variance = tf.Variable(3.2), tf.Variable(9.1)
deferred_dist = tfp.experimental.util.DeferredModule(
  base_class=tfd.Gamma,
  args_fn=gamma_params_from_mean_and_variance,
  mean=mean,  # May be passed by position or by name.
  variance=variance)

with tf.GradientTape() as tape:
  lp = deferred_dist.log_prob(5.0)
grads = tape.gradient(lp, [mean, variance])
# ==> `grads` are defined!

Note that we could have achieved a similar effect by using tfp.util.DeferredTensor to individually defer the concentration and rate parameters. However, this would have been significantly more verbose, and would not share any computation between the two parameter transformations. In general, DeferredTensor is often idiomatic for simple transformations of a single value, while DeferredModule may be preferred for transformations that operate on multiple values and/or contain multiple steps.


Objects derived from a DeferredModule are no longer deferred, so they will not preserve gradients. For example, slicing into a deferred Distribution yields a new, concrete Distribution instance:

dist = tfp.experimental.util.DeferredModule(
  base_class=tfd.Normal,
  args_fn=lambda scaled_loc, log_scale: (5 * scaled_loc, tf.exp(log_scale)),
  scaled_loc=tf.Variable([1., 2., 3.]),
  log_scale=tf.Variable([1., 1., 1.]))
dist.batch_shape  # ==> [3]
len(dist.trainable_variables)  # ==> 2

slice = dist[:2]  # Instantiates a new, non-deferred Distribution.
slice.batch_shape  # ==> [2]
len(slice.trainable_variables)  # ==> 0 (!)

# If needed, we could defer the slice with another layer of wrapping.
deferred_slice = tfp.experimental.util.DeferredModule(
  base_class=lambda d: d[:2],
  args_fn=lambda d: [d],
  d=dist)
len(deferred_slice.trainable_variables)  # ==> 2

Args:

base_class: Python type or callable such that base_class(**args_fn(...)) is an
  instance of tf.Module, for example a TFP Distribution or Bijector.
args_fn: Python callable specifying a deferred transformation of the provided
  arguments. This must have signature
  base_class_init_args = args_fn(*args, **kwargs). The return value
  base_class_init_args may be either a dictionary or an iterable (list/tuple),
  in which case the class will be initialized as
  base_class(**base_class_init_args) or base_class(*base_class_init_args),
  respectively.
*args: Optional positional arguments to args_fn.
**kwargs: Optional keyword arguments to args_fn.
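The base_class/args_fn contract above can be sketched in plain Python. The following is a hypothetical, greatly simplified stand-in for illustration only (it is not TFP's implementation: it ignores variable tracking, name scopes, and everything else tf.Module provides), but it shows how attribute access re-runs the parameter transformation and rebuilds a fresh instance each time:

```python
# Hypothetical, simplified sketch of the base_class/args_fn contract.
# NOT TFP's actual implementation.

class MiniDeferred:
  def __init__(self, base_class, args_fn, *args, **kwargs):
    self._base_class = base_class
    self._args_fn = args_fn
    self._args = args
    self._kwargs = kwargs

  def _build(self):
    # Re-run the parameter transformation on every access, so that all
    # computation happens inside any surrounding gradient (or other) context.
    init_args = self._args_fn(*self._args, **self._kwargs)
    if isinstance(init_args, dict):
      return self._base_class(**init_args)  # dict -> keyword arguments.
    return self._base_class(*init_args)     # list/tuple -> positional args.

  def __getattr__(self, name):
    # Forward attribute/method access to a freshly constructed instance.
    return getattr(self._build(), name)


class Box:  # Stand-in for a tf.Module subclass.
  def __init__(self, value):
    self.value = value

deferred = MiniDeferred(Box, lambda x: dict(value=x * 2), x=3)
print(deferred.value)  # -> 6; a new Box is built on every attribute access.
```

Because `_build` runs on every access, any change to the inputs is reflected the next time an attribute or method is used, which is exactly the property that makes the real DeferredModule tape-safe.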

Attributes:

name: Returns the name of this module as passed or determined in the
  constructor.
name_scope: Returns a tf.name_scope instance for this class.
submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]
list(b.submodules) == [c]
list(c.submodules) == []

trainable_variables: Sequence of trainable variables owned by this module and its submodules.

variables: Sequence of variables owned by this module and its submodules.



with_name_scope

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>

Args:

method: The method to wrap.

Returns:

The original method, wrapped such that it enters the module's name scope.


__abs__

Return the absolute value of the argument.


__add__

Same as a + b.


__and__

Same as a & b.


__bool__

bool(x) -> bool

Returns True when the argument x is true, False otherwise. The builtins True and False are the only two instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.




__contains__

Same as b in a (note reversed operands).




__eq__

Same as a == b.




__floordiv__

Same as a // b.


__ge__

Same as a >= b.


__getitem__

Same as a[b].


__gt__

Same as a > b.


__invert__

Same as ~a.


__iter__

iter(iterable) -> iterator
iter(callable, sentinel) -> iterator

Get an iterator from an object. In the first form, the argument must supply its own iterator, or be a sequence. In the second form, the callable is called until it returns the sentinel.
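As a small sketch of the second form (using only the standard library), a zero-argument callable is invoked repeatedly until it returns the sentinel value:

```python
from io import StringIO

# Read a stream in fixed-size chunks; iteration stops when read() returns
# the sentinel '' (end of stream).
stream = StringIO("abcdef")
chunks = list(iter(lambda: stream.read(2), ''))
print(chunks)  # -> ['ab', 'cd', 'ef']
```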


__le__

Same as a <= b.


__len__

Return the number of items in a container.


__lt__

Same as a < b.


__matmul__

Same as a @ b.


__mod__

Same as a % b.


__mul__

Same as a * b.


__ne__

Same as a != b.


__neg__

Same as -a.


__or__

Same as a | b.


__pos__

Same as +a.


__pow__

Equivalent to x**y (with two arguments) or x**y % z (with three arguments).

Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.
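For instance, the three-argument form of the builtin pow computes a modular power directly, without materializing the full intermediate result:

```python
# pow(x, y, z) uses fast modular exponentiation; it equals (x ** y) % z
# but avoids computing the potentially enormous intermediate x ** y.
print(pow(7, 128, 13))                      # -> 3
print(pow(7, 128, 13) == (7 ** 128) % 13)   # -> True
```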


__radd__

Same as a + b.


__rand__

Same as a & b.


__rfloordiv__

Same as a // b.


__rmatmul__

Same as a @ b.


__rmod__

Same as a % b.


__rmul__

Same as a * b.


__ror__

Same as a | b.


__rpow__

Equivalent to x**y (with two arguments) or x**y % z (with three arguments).

Some types, such as ints, are able to use a more efficient algorithm when invoked using the three argument form.


__rsub__

Same as a - b.


__rtruediv__

Same as a / b.


__rxor__

Same as a ^ b.


__sub__

Same as a - b.


__truediv__

Same as a / b.


__xor__

Same as a ^ b.