orbit.StandardTrainer

Implements standard functionality on top of the AbstractTrainer API.

Inherits From: AbstractTrainer

This class structures the training "inner loop" roughly as follows:

train_loop_begin()
for _ in range(num_steps):
  train_step(train_iterator)
return train_loop_end()

Calls to train_loop_begin and train_loop_end are always made in eager mode, while the loop and train_step may be implemented using tf.while_loop and/or tf.function, as determined by the options passed to __init__.
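The control flow above can be mimicked in plain Python to make the call order concrete. This is a minimal sketch, not Orbit's implementation: LoopSketch, its calls log, and the steps_run key are hypothetical names, and a plain Python iterator stands in for the dataset iterator.

```python
class LoopSketch:
    """Hypothetical stand-in that mirrors the documented inner-loop shape."""

    def __init__(self, train_data):
        self.train_data = train_data  # any iterable; stands in for train_dataset
        self._iterator = None         # created lazily, after train_loop_begin
        self._steps = 0
        self.calls = []               # records the order of hook invocations

    def train_loop_begin(self):
        self._steps = 0
        self.calls.append("begin")

    def train_step(self, iterator):
        # Mutating self here is fine only because this sketch is plain Python;
        # see the tf.function caveat in the train_step documentation.
        batch = next(iterator)
        self._steps += 1
        self.calls.append(f"step({batch})")

    def train_loop_end(self):
        self.calls.append("end")
        return {"steps_run": self._steps}

    def train(self, num_steps):
        self.train_loop_begin()
        if self._iterator is None:
            self._iterator = iter(self.train_data)
        for _ in range(num_steps):
            self.train_step(self._iterator)
        return self.train_loop_end()


trainer = LoopSketch(train_data=[10, 20, 30])
result = trainer.train(num_steps=3)
print(trainer.calls)
print(result)
```

A real subclass would override the same three hooks and leave train to StandardTrainer.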

Args
train_dataset A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset.
options An orbit.StandardTrainerOptions instance.

Attributes
name Returns the name of this module as passed or determined in the constructor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

train_dataset The current training dataset.
trainable_variables Sequence of trainable variables owned by this module and its submodules.

variables Sequence of variables owned by this module and its submodules.

Methods

create_train_loop_fn

Creates a training loop from the current step function and options.

Returns
The train loop function, i.e., a wrapper that runs multiple train steps.
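To illustrate what a wrapper of multiple train steps looks like, here is a minimal pure-Python sketch. make_loop_fn is a hypothetical helper, not the Orbit implementation, which may additionally compile the loop depending on the options passed to __init__.

```python
def make_loop_fn(step_fn):
    """Returns a function that calls step_fn num_steps times (hypothetical helper)."""

    def loop_fn(iterator, num_steps):
        for _ in range(num_steps):
            step_fn(iterator)

    return loop_fn


seen = []
# The step function here just records the next element it pulls from the iterator.
loop_fn = make_loop_fn(lambda iterator: seen.append(next(iterator)))
loop_fn(iter(range(100)), num_steps=4)
print(seen)
```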

train

Implements num_steps steps of training.

Args
num_steps The number of training steps to run. This corresponds directly to the number of calls made to train_step.

Returns
The output of train_loop_end.

train_loop_begin

Called once at the beginning of the training loop.

This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.

Note that this method is called before dataset iterator creation.
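The reset pattern can be sketched in plain Python. This assumes a hypothetical RunningMean accumulator in place of a real stateful metric, and MetricLoopSketch is an illustrative stand-in rather than an actual StandardTrainer subclass.

```python
class RunningMean:
    """Hypothetical accumulator standing in for a stateful metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


class MetricLoopSketch:
    def __init__(self, losses):
        self._losses = iter(losses)  # stands in for the training iterator
        self.loss_metric = RunningMean()

    def train_loop_begin(self):
        # Reset so each train() call reports only its own steps.
        self.loss_metric.reset()

    def train_step(self, iterator):
        self.loss_metric.update(next(iterator))

    def train_loop_end(self):
        return {"loss": self.loss_metric.result()}

    def train(self, num_steps):
        self.train_loop_begin()
        for _ in range(num_steps):
            self.train_step(self._losses)
        return self.train_loop_end()


trainer = MetricLoopSketch(losses=[4.0, 2.0, 1.0, 1.0])
first = trainer.train(num_steps=2)   # mean of 4.0 and 2.0
second = trainer.train(num_steps=2)  # mean of 1.0 and 1.0 only
print(first, second)
```

Without the reset in train_loop_begin, the second call would report the mean over all four values instead of only its own two.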

train_loop_end

Called once at the end of the training loop.

This method is always called in eager mode, and is a good place to get metric results. The value returned from this function will be returned as-is from the train method implementation provided by StandardTrainer.

Returns
The function may return a dictionary of Tensors, which will be written to logs and as TensorBoard summaries. The return value may also be a nested dictionary, which yields a hierarchy of summary directories.
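One way to picture how a nested dictionary maps to a directory hierarchy is the sketch below. summary_dirs is a hypothetical helper, and the "." name for the top-level directory and the "/" separator are assumptions made for illustration; this page does not specify the exact layout Orbit writes.

```python
def summary_dirs(outputs, prefix=""):
    """Groups scalar outputs by the directory their nesting implies (hypothetical)."""
    dirs = {}
    for key, value in outputs.items():
        if isinstance(value, dict):
            # A nested dictionary becomes a subdirectory named after its key.
            subdir = f"{prefix}/{key}" if prefix else key
            dirs.update(summary_dirs(value, subdir))
        else:
            # Leaf values are summaries written in the current directory.
            dirs.setdefault(prefix or ".", {})[key] = value
    return dirs


outputs = {"loss": 0.25, "eval": {"accuracy": 0.9, "per_class": {"cat": 0.8}}}
layout = summary_dirs(outputs)
for directory, summaries in sorted(layout.items()):
    print(directory, summaries)
```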

train_step

Implements one step of training.

What a "step" consists of is up to the implementer. When using distribution strategies, the call to this method takes place in the "cross-replica context" for generality, to allow e.g. multiple iterator dequeues and calls to strategy.run.

Note that if use_tf_function=True, all the code inside train_step should be compatible with tf.function tracing (and in particular, any state modifications involving self should be avoided). In some cases, code that is not tf.function-compatible can be moved to train_loop_begin or train_loop_end, which always execute eagerly.

Args
iterator A tf.nest-compatible structure of tf.data.Iterator or DistributedIterator. The structure of this input matches the structure of train_dataset as passed to __init__.
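The relationship between the structure of train_dataset and the structure of iterator can be sketched without TensorFlow. In this sketch, ToyDataset and the local map_structure are hypothetical stand-ins for tf.data.Dataset and tf.nest.map_structure, and the "images" and "labels" keys are made up for illustration.

```python
class ToyDataset:
    """Hypothetical leaf object standing in for a tf.data.Dataset."""

    def __init__(self, elements):
        self._elements = list(elements)

    def __iter__(self):
        return iter(self._elements)


def map_structure(fn, structure):
    """Tiny stand-in for tf.nest.map_structure: dicts, lists, tuples, and leaves."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)


train_dataset = {
    "images": ToyDataset([1, 2, 3]),
    "labels": ToyDataset(["a", "b", "c"]),
}
# The iterator structure has the same keys as the dataset structure.
iterators = map_structure(iter, train_dataset)


def train_step(iterator):
    # One "step" may dequeue from several iterators in the structure.
    return next(iterator["images"]), next(iterator["labels"])


first = train_step(iterators)
second = train_step(iterators)
print(first, second)
```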

with_name_scope

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.