Implements standard functionality on top of the AbstractTrainer API.

Inherits From: AbstractTrainer

This class structures the training "inner loop" roughly as follows:

train_loop_begin()
for _ in range(num_steps):
  train_step(iterator)
return train_loop_end()

Calls to train_loop_begin and train_loop_end are always made in eager mode, while the loop and train_step may be implemented using tf.while_loop and/or tf.function, as determined by the options passed to __init__.
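The hook order above can be modeled in plain Python. This is a framework-free sketch that assumes a hypothetical `ToyTrainer` class whose method names mirror the documented ones. It does not use TensorFlow or Orbit and ignores the eager versus tf.function distinction. In this sketch the iterator is created once and then reused across calls.

```python
from typing import Dict, Iterator, List, Optional


class ToyTrainer:
    """Hypothetical stand-in that mirrors the documented hook order."""

    def __init__(self, dataset: List[float]) -> None:
        self.dataset = dataset
        self.events: List[str] = []
        self.total = 0.0
        self._iterator: Optional[Iterator[float]] = None

    def train_loop_begin(self) -> None:
        self.events.append("begin")
        self.total = 0.0  # reset per-loop state

    def train_step(self, iterator: Iterator[float]) -> None:
        self.events.append("step")
        self.total += next(iterator)

    def train_loop_end(self) -> Dict[str, float]:
        self.events.append("end")
        return {"total": self.total}

    def train(self, num_steps: int) -> Dict[str, float]:
        self.train_loop_begin()  # called first, before the iterator is (lazily) created
        if self._iterator is None:
            # In this sketch the iterator is created once and reused across calls.
            self._iterator = iter(self.dataset)
        for _ in range(num_steps):
            self.train_step(self._iterator)
        return self.train_loop_end()  # returned as-is


trainer = ToyTrainer([1.0, 2.0, 3.0, 4.0])
first = trainer.train(num_steps=3)   # consumes 1.0, 2.0, 3.0
second = trainer.train(num_steps=1)  # continues with 4.0
print(first, second)  # {'total': 6.0} {'total': 4.0}
print(trainer.events)
```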

Args:
  train_dataset: A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
  options: An orbit.StandardTrainerOptions instance.

Attributes:
  name: Returns the name of this module as passed or determined in the constructor.
  name_scope: Returns a tf.name_scope instance for this class.
  non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
  submodules: Sequence of all submodules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

import tensorflow as tf

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []
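The recursive search for submodules can be modeled without TensorFlow. The sketch below assumes a hypothetical `Node` class and does a depth-first walk over instance attributes in sorted-name order. It is a simplified model, not tf.Module's implementation, which also handles more cases such as modules stored inside lists or dicts.

```python
from typing import List, Set


class Node:
    """Hypothetical minimal module: any attribute holding a Node is a submodule."""

    @property
    def submodules(self) -> List["Node"]:
        found: List["Node"] = []
        seen: Set[int] = {id(self)}  # never report the node itself

        def walk(node: "Node") -> None:
            # Visit attributes in sorted name order, descending into each child.
            for name in sorted(vars(node)):
                child = vars(node)[name]
                if isinstance(child, Node) and id(child) not in seen:
                    seen.add(id(child))
                    found.append(child)
                    walk(child)

        walk(self)
        return found


a, b, c = Node(), Node(), Node()
a.b = b
b.c = c
print(a.submodules == [b, c], b.submodules == [c], c.submodules == [])
```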

  train_dataset: The current training dataset.
  trainable_variables: Sequence of trainable variables owned by this module and its submodules.
  variables: Sequence of variables owned by this module and its submodules.


Creates a training loop from the current step function and options.

Returns:
  The train loop function, i.e. a wrapper that runs multiple train steps.
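A loop function of this kind can be thought of as a closure that repeats the step function. The sketch below assumes a hypothetical `make_loop_fn` factory. It models an option that changes how the step is executed with an optional `wrapper` callable, which stands in for compilation such as tf.function. Orbit's actual tf.function and tf.while_loop handling is not shown.

```python
from typing import Callable, Iterator, List, Optional

StepFn = Callable[[Iterator[int]], None]
LoopFn = Callable[[Iterator[int], int], None]


def make_loop_fn(step_fn: StepFn,
                 wrapper: Optional[Callable[[StepFn], StepFn]] = None) -> LoopFn:
    """Hypothetical factory returning a function that runs step_fn num_steps times."""
    # Stand-in for compiling the step when the options ask for it.
    step = wrapper(step_fn) if wrapper is not None else step_fn

    def loop_fn(iterator: Iterator[int], num_steps: int) -> None:
        for _ in range(num_steps):
            step(iterator)

    return loop_fn


consumed: List[int] = []
wrapped_calls: List[int] = []


def take_one(iterator: Iterator[int]) -> None:
    consumed.append(next(iterator))


def logging_wrapper(fn: StepFn) -> StepFn:
    # Example wrapper: records each call, then delegates to the real step.
    def wrapped(iterator: Iterator[int]) -> None:
        wrapped_calls.append(1)
        fn(iterator)
    return wrapped


loop_fn = make_loop_fn(take_one, wrapper=logging_wrapper)
loop_fn(iter(range(10)), 4)
print(consumed)            # [0, 1, 2, 3]
print(len(wrapped_calls))  # 4
```

The step function is wrapped once, when the loop function is built, instead of on every step. This mirrors the idea of creating the loop "from the current step function and options."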


train(num_steps)

Implements num_steps steps of training.

Args:
  num_steps: The number of training steps to run. This corresponds directly to the number of calls made to train_step.

Returns:
  The output of train_loop_end.


train_loop_begin()

Called once at the beginning of the training loop.

This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.

Note that this method is called before dataset iterator creation.
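Because train_loop_begin runs once per training loop, resetting accumulators there keeps each loop's results independent. The sketch below uses a hypothetical plain-Python `RunningMean` class in place of a metric object, and free functions named after the documented hooks. It is an illustration of the pattern, not Orbit code.

```python
from typing import Dict, List


class RunningMean:
    """Hypothetical accumulator standing in for a metric object."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def reset(self) -> None:
        self.total, self.count = 0.0, 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


loss_mean = RunningMean()


def train_loop_begin() -> None:
    # Reset once per loop so the result covers only this loop's steps.
    loss_mean.reset()


def train_step(loss: float) -> None:
    loss_mean.update(loss)


def train_loop_end() -> Dict[str, float]:
    return {"loss": loss_mean.result()}


results: List[Dict[str, float]] = []
# Two consecutive loops: the second is unaffected by the first.
for losses in ([4.0, 2.0], [1.0, 1.0, 1.0]):
    train_loop_begin()
    for value in losses:
        train_step(value)
    results.append(train_loop_end())
print(results)  # [{'loss': 3.0}, {'loss': 1.0}]
```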


train_loop_end()

Called once at the end of the training loop.

This method is always called in eager mode, and is a good place to get metric results. The value returned from this function will be returned as-is from the train method implementation provided by StandardTrainer.

The function may return a dictionary of Tensors, which will be written to the logs and as TensorBoard summaries. The returned dictionary can also be nested, which yields a hierarchy of summary directories.
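To see how a nested dictionary could map onto a hierarchy, the sketch below flattens a nested result into slash-separated paths. The `flatten_outputs` helper and the `/`-joined naming are assumptions made for illustration. Orbit's own summary writing may lay out directories differently.

```python
from typing import Any, Dict, Mapping


def flatten_outputs(outputs: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: {'a': {'b': 1}} -> {'a/b': 1}."""
    flat: Dict[str, Any] = {}
    for key, value in outputs.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, Mapping):
            # Nested dicts become deeper "directories" in the path.
            flat.update(flatten_outputs(value, path))
        else:
            flat[path] = value
    return flat


logs = {"loss": 0.25, "task_a": {"accuracy": 0.9}, "task_b": {"accuracy": 0.8}}
print(flatten_outputs(logs))
# {'loss': 0.25, 'task_a/accuracy': 0.9, 'task_b/accuracy': 0.8}
```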


train_step(iterator)

Implements one step of training.

What a "step" consists of is up to the implementer. When using distribution strategies, this method is called in the "cross-replica context" for generality. This allows, for example, several iterator dequeues and several calls to strategy.run within a single step.

Note that if use_tf_function=True, all code inside train_step must be compatible with tf.function tracing; in particular, avoid any state modifications involving self. In some cases, code that is not tf.function-compatible can be moved to train_loop_begin or train_loop_end, which always execute eagerly.

Args:
  iterator: A tf.nest-compatible structure of tf.data.Iterator or DistributedIterator. The structure of this input matches the structure of train_dataset as passed to __init__.
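When train_dataset is a structure, such as a dict of datasets, the iterator passed to train_step has the same structure. The sketch below mimics this with plain Python ranges. It uses a hypothetical `map_structure` helper, a tiny stand-in for tf.nest.map_structure that handles only dicts, lists, and tuples. It also shows one step dequeuing from several iterators.

```python
from typing import Any, Callable, Dict


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Hypothetical, simplified stand-in for tf.nest.map_structure."""
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, value) for value in structure)
    return fn(structure)  # leaf: apply fn


train_dataset = {"images": range(10, 13), "labels": range(0, 3)}
iterator = map_structure(iter, train_dataset)  # a dict of iterators with the same keys


def train_step(iterator: Dict[str, Any]) -> Dict[str, Any]:
    # A single step may dequeue from several iterators in the structure.
    # The batch is returned only so this sketch can show what was consumed.
    return {name: next(it) for name, it in iterator.items()}


first = train_step(iterator)
second = train_step(iterator)
print(first)   # {'images': 10, 'labels': 0}
print(second)  # {'images': 11, 'labels': 1}
```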


Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope  # enter the module's name scope on each call
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
# <tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
# <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args:
  method: The method to wrap.

Returns:
  The original method wrapped such that it enters the module's name scope.