Base neural network module class.

A module is a named container for tf.Variables, other tf.Modules, and functions that apply to user input. For example, a dense layer in a neural network might be implemented as a tf.Module:

import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, input_dim, output_size, name=None):
    super().__init__(name=name)  # sets up the module's name and name scope
    self.w = tf.Variable(
      tf.random.normal([input_dim, output_size]), name='w')
    self.b = tf.Variable(tf.zeros([output_size]), name='b')
  def __call__(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

You can use the Dense layer as you would expect:

>>> d = Dense(input_dim=3, output_size=2)
>>> d(tf.ones([1, 3]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=...>

By subclassing tf.Module instead of object, any tf.Variable or tf.Module instances assigned to object attributes can be collected using the variables, trainable_variables or submodules properties:

>>> d.trainable_variables
(<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...>,
 <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=...>)
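This collection works by inspecting the attributes assigned to the object. The plain-Python sketch below imitates that idea without TensorFlow so the logic is easy to follow. ToyVariable, ToyModule and the other names are hypothetical, and the sketch deliberately skips things a real implementation needs (de-duplication of shared objects, cycle detection, TensorFlow's exact traversal order).

```python
# Hypothetical sketch: a toy stand-in for tf.Module's attribute tracking.
# It is NOT TensorFlow's implementation (no de-duplication, no cycle checks).


class ToyVariable:
    """Stand-in for tf.Variable."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


def _leaves(value):
    """Yield objects reachable through lists, tuples and dict values."""
    if isinstance(value, (list, tuple)):
        for item in value:
            yield from _leaves(item)
    elif isinstance(value, dict):
        for key in sorted(value):
            yield from _leaves(value[key])
    else:
        yield value


class ToyModule:
    """Stand-in for tf.Module: finds tracked objects by walking attributes."""

    def _children(self):
        # Instance attributes in name order, flattened through containers.
        for key in sorted(vars(self)):
            yield from _leaves(getattr(self, key))

    @property
    def submodules(self):
        found = []
        for child in self._children():
            if isinstance(child, ToyModule):
                found.append(child)
                found.extend(child.submodules)
        return tuple(found)

    @property
    def variables(self):
        found = []
        for child in self._children():
            if isinstance(child, ToyVariable):
                found.append(child)
            elif isinstance(child, ToyModule):
                found.extend(child.variables)  # recurse into nested modules
        return tuple(found)

    @property
    def trainable_variables(self):
        return tuple(v for v in self.variables if v.trainable)


class ToyDense(ToyModule):
    def __init__(self):
        self.w = ToyVariable("w")
        self.b = ToyVariable("b")
        self.step = ToyVariable("step", trainable=False)


class ToyNet(ToyModule):
    def __init__(self):
        self.layers = [ToyDense(), ToyDense()]  # modules inside a list are found too


net = ToyNet()
print([v.name for v in net.variables])  # ['b', 'step', 'w', 'b', 'step', 'w']
print(len(net.trainable_variables))     # 4
print(len(net.submodules))              # 2
```

Filtering on a flag (here `trainable`) after a single generic walk is what lets one traversal back several properties.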

Subclasses of tf.Module can also use the _flatten method to implement tracking of other types.
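The general idea is a recursive walk over attributes and containers that keeps only the leaves matching a predicate. The sketch below is plain Python and only loosely modelled on that idea: the `flatten` helper, the `Hyperparameter` type and the attribute-path format are all hypothetical, and tf.Module._flatten itself is a private method whose parameters and traversal order may differ.

```python
from dataclasses import dataclass


@dataclass
class Hyperparameter:
    """Hypothetical custom type we want to track."""
    name: str
    value: float


def flatten(obj, predicate, path=()):
    """Yield (path, leaf) for every reachable leaf that satisfies predicate.

    Hypothetical helper: walks dicts, lists/tuples and instance attributes;
    it does no cycle detection.
    """
    if predicate(obj):
        yield path, obj
    elif isinstance(obj, dict):
        for key in sorted(obj):
            yield from flatten(obj[key], predicate, path + (key,))
    elif isinstance(obj, (list, tuple)):
        for i, item in enumerate(obj):
            yield from flatten(item, predicate, path + (i,))
    elif hasattr(obj, "__dict__") and not isinstance(obj, type):
        for key in sorted(vars(obj)):
            yield from flatten(getattr(obj, key), predicate, path + (key,))


def is_hparam(x):
    return isinstance(x, Hyperparameter)


class Optimizer:
    def __init__(self):
        self.lr = Hyperparameter("lr", 0.01)
        self.schedule = [Hyperparameter("warmup", 100.0), "cosine"]


class Trainer:
    def __init__(self):
        self.opt = Optimizer()
        self.tag = "run-1"  # strings have no __dict__, so they are skipped


found = list(flatten(Trainer(), is_hparam))
print([path for path, _ in found])  # [('opt', 'lr'), ('opt', 'schedule', 0)]
```

Passing the predicate in, rather than hard-coding a type check, is what makes the same walk reusable for any type a subclass wants to track.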

All tf.Module classes have an associated tf.name_scope, which can be used to group operations in TensorBoard and to create hierarchical variable names that help with debugging. We suggest using the name scope when creating nested submodules or parameters, and in forward methods whose graph you might want to inspect in TensorBoard. You can enter the name scope explicitly with `with self.name_scope:`, or you can decorate methods (other than __init__) with @tf.Module.with_name_scope.

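class MLP(tf.Module):
  def __init__(self, input_size, sizes, name=None):
    super().__init__(name=name)
    self.layers = []
    with self.name_scope:  # variables created here are prefixed with 'mlp/'
      for size in sizes:
        self.layers.append(Dense(input_dim=input_size, output_size=size))
        input_size = size  # the next layer consumes this layer's output
  @tf.Module.with_name_scope
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x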

>>> module = MLP(input_size=5, sizes=[5, 5])
>>> module.variables
(<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=...>,
 <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...>,
 <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=...>,
 <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...>)
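In the MLP constructor, input_size is reassigned after each layer, so each Dense weight has shape (previous size, current size) and each bias has shape (current size,). The hypothetical helpers below reproduce only that bookkeeping in plain Python, which is a quick way to check the shapes in the listing above or to size a different configuration before building it.

```python
def mlp_param_shapes(input_size, sizes):
    """Return (weight_shape, bias_shape) for each Dense layer the loop builds."""
    shapes = []
    for size in sizes:
        shapes.append(((input_size, size), (size,)))
        input_size = size  # the next layer consumes this layer's output
    return shapes


def count_params(shapes):
    """Total number of scalar parameters across all weights and biases."""
    total = 0
    for w_shape, b_shape in shapes:
        total += w_shape[0] * w_shape[1] + b_shape[0]
    return total


shapes = mlp_param_shapes(5, [5, 5])
print(shapes)                # [((5, 5), (5,)), ((5, 5), (5,))]
print(count_params(shapes))  # 60
```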