orbit.Controller

Class that controls the outer loop of model training and evaluation.

Orbit divides training and evaluation into "inner" and "outer" loops. Inner loops are implemented by users in the form of AbstractTrainer and AbstractEvaluator subclasses, and define how to run a given number of training or evaluation steps. The outer loop is provided by this Controller, and interleaves calls to the user-provided inner loops with additional actions such as saving checkpoints, running evaluations, and writing summaries, as well as (optionally) user-provided Actions (see below).
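
As a rough illustration of the inner-loop contract, the sketch below uses two hypothetical pure-Python classes, ToyTrainer and ToyEvaluator, in place of real orbit.AbstractTrainer and orbit.AbstractEvaluator subclasses. It imports neither TensorFlow nor Orbit. Only the method names train(num_steps) and evaluate(num_steps) are meant to mirror the roles described above. The fake loss and the plain-int global step are assumptions made for brevity.

```python
from typing import Dict


class ToyTrainer:
    """Hypothetical stand-in for an orbit.AbstractTrainer subclass.

    It only knows how to run `num_steps` training steps and report results;
    deciding when to checkpoint, evaluate, or stop is left to the outer loop.
    """

    def __init__(self) -> None:
        self.global_step = 0  # a plain int standing in for a tf.Variable
        self.loss = 10.0

    def train(self, num_steps: int) -> Dict[str, float]:
        for _ in range(num_steps):
            self.loss *= 0.5  # fake optimization step
            self.global_step += 1
        return {"loss": self.loss}


class ToyEvaluator:
    """Hypothetical stand-in for an orbit.AbstractEvaluator subclass."""

    def __init__(self, trainer: ToyTrainer) -> None:
        self.trainer = trainer

    def evaluate(self, num_steps: int) -> Dict[str, float]:
        # A real evaluator would iterate over `num_steps` eval batches
        # (or the whole dataset when num_steps == -1).
        return {"eval_loss": self.trainer.loss}


trainer = ToyTrainer()
evaluator = ToyEvaluator(trainer)
print(trainer.train(num_steps=3))        # {'loss': 1.25}
print(trainer.global_step)               # 3
print(evaluator.evaluate(num_steps=-1))  # {'eval_loss': 1.25}
```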

There are four top-level "outer loops" provided:

  • train, which trains until a specified number of global steps is reached;
  • evaluate, for one-off model evaluation;
  • train_and_evaluate, for interleaved training and evaluation;
  • evaluate_continuously, for monitoring a given directory and running evaluations on new model checkpoints.

While this class attempts to provide out-of-the-box solutions for common training and evaluation use cases, the internal details and method implementations are also intended to be simple enough to make subclassing or other custom outer loop implementations easy to achieve.

Some additional customization can be achieved by supplying train_actions or eval_actions when constructing the Controller. Actions are arbitrary callables that the Controller applies to the output of train steps (after each inner loop of steps_per_loop steps) or of an evaluation. This provides a hook mechanism, enabling things like reporting metrics to Vizier, model exporting, additional logging, etc. See the orbit.actions package for a small handful of predefined actions and some utility classes that may be useful in defining your own.
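
An action is just a callable that receives an output dictionary, so the hook mechanism can be sketched in plain Python. RecordingAction, make_threshold_action, and apply_actions below are hypothetical names and are not part of orbit.actions. apply_actions only approximates the point where the Controller hands each train or eval output to its actions.

```python
from typing import Callable, Dict, List

Output = Dict[str, float]


class RecordingAction:
    """Hypothetical action that records every output it is called with."""

    def __init__(self) -> None:
        self.seen: List[Output] = []

    def __call__(self, output: Output) -> None:
        self.seen.append(dict(output))


def make_threshold_action(key: str, threshold: float, log: List[str]) -> Callable[[Output], None]:
    """Hypothetical action that fires only when output[key] drops below threshold."""

    def action(output: Output) -> None:
        if output.get(key, float("inf")) < threshold:
            log.append(f"{key}={output[key]} below {threshold}")

    return action


def apply_actions(actions: List[Callable[[Output], None]], output: Output) -> None:
    """Calls each action, in order, with the same output dictionary."""
    for action in actions:
        action(output)


recorder = RecordingAction()
messages: List[str] = []
actions = [recorder, make_threshold_action("loss", 0.5, messages)]

# Pretend these are the outputs of two consecutive inner training loops.
for out in ({"loss": 0.9}, {"loss": 0.4}):
    apply_actions(actions, out)

print(len(recorder.seen))  # 2
print(messages)            # ['loss=0.4 below 0.5']
```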

Args
global_step An integer tf.Variable storing the global training step number. Usually this can be obtained from the iterations property of the model's optimizer (e.g. trainer.optimizer.iterations). In cases where multiple optimizers are used, or if one model "step" corresponds to more than one update to model parameters, users can create and increment their own global step variable as well. In this case it is recommended to create the tf.Variable inside the distribution strategy scope, with aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA (see also orbit.utils.create_global_step()).
trainer An instance of orbit.AbstractTrainer, which implements the inner training loop.
evaluator An instance of orbit.AbstractEvaluator, which implements evaluation.
strategy An instance of tf.distribute.Strategy. If not provided, the strategy will be initialized from the current in-scope strategy using tf.distribute.get_strategy().
train_actions Optional orbit.Actions to call after each block of steps_per_loop training steps is run. These will be called with the output of trainer.train.
eval_actions Optional orbit.Actions to call after each evaluation. These will be called with the output of evaluator.evaluate.
steps_per_loop Optional integer indicating the number of steps to run in each inner loop of training (passed as the num_steps parameter of trainer.train). It can also be a callable that takes the current global step value as input and returns the number of steps to run as output.
checkpoint_manager An instance of tf.train.CheckpointManager. If provided and there are checkpoints in the associated model directory, the model will be restored from the most recent checkpoint inside this __init__ method. If not provided, the Controller will not automatically save to or restore from checkpoints.
enable_async_checkpointing Optional bool indicating whether to enable async checkpoint saving.
summary_interval Step interval for training summaries. Note that this argument only applies to tf.summary calls inside the trainer.train function. Summaries written by the Controller (specifically "steps_per_second" and output from the trainer.train method) will always be enabled unless the summary_dir parameter is None. If set, the value must be divisible by steps_per_loop.
summary_dir The directory to write summaries to. To use the same directory as for checkpointing, pass checkpoint_manager.directory. If None, no training summaries will be written.
eval_summary_dir The directory to write eval summaries to. If None, it will be set to summary_dir. If both summary_dir and eval_summary_dir are None, no eval summaries will be written.
summary_manager Instance of the summary manager. If set, summary_dir is ignored. Otherwise, a summary manager that writes TensorBoard summaries to summary_dir is created internally.
eval_summary_manager Instance of the eval summary manager. If set, eval_summary_dir is ignored. Otherwise, an eval summary manager that writes TensorBoard summaries to eval_summary_dir is created internally.
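
Here is a minimal sketch of the callable form of steps_per_loop. It assumes a schedule that runs short inner loops early in training and longer ones later. steps_per_loop_schedule and resolve_steps_per_loop are hypothetical helpers. The second one only shows how an int and a callable could be handled uniformly.

```python
from typing import Callable, Union


def steps_per_loop_schedule(global_step: int) -> int:
    """Hypothetical schedule: short inner loops early, longer ones later."""
    if global_step < 1000:
        return 100
    return 1000


def resolve_steps_per_loop(
    steps_per_loop: Union[int, Callable[[int], int]], global_step: int
) -> int:
    """Accepts either an int or a callable, as the steps_per_loop argument does."""
    if callable(steps_per_loop):
        return steps_per_loop(global_step)
    return steps_per_loop


print(resolve_steps_per_loop(500, 0))                         # 500
print(resolve_steps_per_loop(steps_per_loop_schedule, 0))     # 100
print(resolve_steps_per_loop(steps_per_loop_schedule, 2000))  # 1000
```

Under these assumptions, passing steps_per_loop=steps_per_loop_schedule to the Controller should produce 100-step inner loops until global step 1000 and 1000-step loops afterwards.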

Raises
ValueError If both trainer and evaluator are None.
ValueError If steps_per_loop is not a positive integer or a callable.
ValueError If summary_interval is not a positive integer or is not divisible by steps_per_loop.
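
The ValueError conditions listed above can be restated as a small validation function. validate_controller_args is hypothetical and only mirrors the documented rules. One part is an inference rather than something stated here: this sketch skips the divisibility check when steps_per_loop is a callable, because a fixed interval cannot be checked against a schedule ahead of time.

```python
from typing import Callable, Optional, Union


def validate_controller_args(
    trainer: object,
    evaluator: object,
    steps_per_loop: Union[int, Callable[[int], int]],
    summary_interval: Optional[int] = None,
) -> None:
    """Raises ValueError for the documented invalid argument combinations."""
    if trainer is None and evaluator is None:
        raise ValueError("trainer and evaluator should not both be None.")
    if not callable(steps_per_loop) and (
        not isinstance(steps_per_loop, int) or steps_per_loop < 1
    ):
        raise ValueError("steps_per_loop must be a positive integer or a callable.")
    if summary_interval is not None:
        if not isinstance(summary_interval, int) or summary_interval < 1:
            raise ValueError("summary_interval must be a positive integer.")
        # Assumption: divisibility is only checkable for a fixed integer steps_per_loop.
        if not callable(steps_per_loop) and summary_interval % steps_per_loop != 0:
            raise ValueError("summary_interval must be divisible by steps_per_loop.")


validate_controller_args(trainer=object(), evaluator=None, steps_per_loop=100, summary_interval=500)
try:
    validate_controller_args(trainer=object(), evaluator=None, steps_per_loop=100, summary_interval=250)
except ValueError as err:
    print(err)  # summary_interval must be divisible by steps_per_loop.
```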

Attributes
steps_per_loop Returns the current steps_per_loop value in a training loop.

Methods

evaluate

Runs evaluation for the given number of steps.

This method calls self.evaluator.evaluate(steps), then writes the returned summaries (if any).

Args
steps The number of evaluation steps to run. The value -1 is reserved as a special sentinel to indicate a "complete" evaluation that runs until the underlying dataset is exhausted. Support for this is dependent on the specific evaluator being used.

Returns
The evaluation results as a dictionary mapping names to NumPy values.

Raises
ValueError If evaluator was not provided to Controller.__init__.
ValueError If no checkpoint is present in checkpoint_manager.directory.
ValueError If steps is not a positive value or -1.
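
To make the -1 sentinel concrete, here is a hypothetical evaluation loop over a finite dataset. Each element stands in for a per-batch loss. run_evaluation is not an Orbit function. It shows one way an evaluator might treat -1 as "run until the dataset is exhausted" while rejecting other non-positive values, which matches the ValueError condition documented above.

```python
from itertools import islice
from typing import Iterable


def run_evaluation(dataset: Iterable[float], steps: int = -1) -> dict:
    """Hypothetical evaluator inner loop that honors the -1 sentinel."""
    if steps != -1 and steps <= 0:
        raise ValueError(f"steps must be positive or -1, got {steps}")
    # -1 means "until the dataset is exhausted"; otherwise take `steps` batches.
    batches = dataset if steps == -1 else islice(dataset, steps)
    total, count = 0.0, 0
    for batch_loss in batches:
        total += batch_loss
        count += 1
    return {"mean_loss": total / count if count else float("nan"), "num_batches": count}


data = [1.0, 2.0, 3.0, 4.0]
print(run_evaluation(data, steps=2))   # {'mean_loss': 1.5, 'num_batches': 2}
print(run_evaluation(data, steps=-1))  # {'mean_loss': 2.5, 'num_batches': 4}
```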

evaluate_continuously

Continuously monitors a directory and evaluates new checkpoints in it.

This method continuously monitors the directory of the CheckpointManager passed to this Controller's constructor and runs evaluation on the checkpoints found there.

Args
steps The number of steps to run when evaluating. If -1, this method will evaluate over the entire evaluation dataset.
timeout The maximum number of seconds to wait between checkpoints. See tf.train.checkpoints_iterator documentation.
timeout_fn Optional callable to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit.

Returns
The evaluation results as a dictionary mapping names to NumPy values.

Raises
ValueError If no checkpoint is found in self.checkpoint_manager.directory.
ValueError If evaluator was not provided to Controller.__init__.
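
The polling behavior of evaluate_continuously follows tf.train.checkpoints_iterator. The sketch below re-implements a simplified version of that pattern in pure Python so it can run without a real directory. Its assumptions are:

- list_checkpoints, clock, and sleep are hypothetical injected callables, and make_fake_env scripts them.
- Only the newest checkpoint is considered on each poll.
- After a timeout, iteration stops unless timeout_fn is provided and returns False, following the timeout and timeout_fn descriptions above.

```python
import time
from typing import Callable, Iterator, List, Optional, Tuple


def checkpoints_iterator_sketch(
    list_checkpoints: Callable[[], List[str]],
    timeout: Optional[float] = None,
    timeout_fn: Optional[Callable[[], bool]] = None,
    poll_interval: float = 1.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> Iterator[str]:
    """Yields each newly observed checkpoint path once (newest-last listing assumed)."""
    last_seen: Optional[str] = None
    while True:
        wait_start = clock()
        newest: Optional[str] = None
        while True:
            paths = list_checkpoints()
            if paths and paths[-1] != last_seen:
                newest = paths[-1]
                break
            if timeout is not None and clock() - wait_start >= timeout:
                break  # timed out; newest stays None
            sleep(poll_interval)
        if newest is None:
            # Exit after a timeout unless timeout_fn says more checkpoints may come.
            if timeout_fn is None or timeout_fn():
                return
            continue
        last_seen = newest
        yield newest


def evaluate_continuously_sketch(evaluate_checkpoint, **iterator_kwargs):
    """Evaluates every new checkpoint and returns the results for the last one."""
    results = None
    for path in checkpoints_iterator_sketch(**iterator_kwargs):
        results = evaluate_checkpoint(path)
    return results


def make_fake_env(snapshots: List[List[str]]) -> Tuple[Callable, Callable, Callable]:
    """Scripted directory listings plus a fake clock so the sketch runs instantly."""
    state = {"calls": 0, "now": 0.0}

    def list_checkpoints() -> List[str]:
        i = min(state["calls"], len(snapshots) - 1)
        state["calls"] += 1
        return snapshots[i]

    def clock() -> float:
        return state["now"]

    def sleep(seconds: float) -> None:
        state["now"] += seconds

    return list_checkpoints, clock, sleep


list_fn, clock, sleep = make_fake_env([["ckpt-1"], ["ckpt-1", "ckpt-2"]])
final = evaluate_continuously_sketch(
    lambda path: {"evaluated": path},
    list_checkpoints=list_fn, timeout=3.0, clock=clock, sleep=sleep,
)
print(final)  # {'evaluated': 'ckpt-2'}
```

With timeout=None and no new checkpoints, this sketch polls indefinitely, which suits a long-running evaluation job.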

restore_checkpoint

Restores the model from a checkpoint.

Args
checkpoint_path An optional string specifying the checkpoint path to restore from. If None, will restore from the most recent checkpoint (or initialize the model using a custom init_fn if no checkpoints can be found) using self.checkpoint_manager.restore_or_initialize().

Returns
The path to the restored checkpoint if a restore happened, or None if no restore occurred.
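
The decision made by restore_checkpoint can be summarized in a few branches. restore_checkpoint_sketch is a hypothetical pure-Python stand-in. restore and init_fn are placeholder callables, and latest_checkpoint plays the role of the most recent checkpoint the manager knows about. The fallback branch mimics tf.train.CheckpointManager.restore_or_initialize(). That method calls the manager's init_fn when no checkpoint exists and returns None in that case.

```python
from typing import Callable, Optional


def restore_checkpoint_sketch(
    checkpoint_path: Optional[str],
    latest_checkpoint: Optional[str],
    restore: Callable[[str], None],
    init_fn: Optional[Callable[[], None]] = None,
) -> Optional[str]:
    """Hypothetical decision logic; returns the restored path or None."""
    if checkpoint_path is not None:
        restore(checkpoint_path)  # an explicit path takes precedence
        return checkpoint_path
    # Like restore_or_initialize(): prefer the most recent checkpoint...
    if latest_checkpoint is not None:
        restore(latest_checkpoint)
        return latest_checkpoint
    # ...otherwise run custom initialization, if any, and report that no restore happened.
    if init_fn is not None:
        init_fn()
    return None


restored = []
print(restore_checkpoint_sketch(None, "ckpt-7", restored.append))      # ckpt-7
print(restore_checkpoint_sketch("ckpt-3", "ckpt-7", restored.append))  # ckpt-3
print(restore_checkpoint_sketch(None, None, restored.append, init_fn=lambda: restored.append("init")))  # None
print(restored)  # ['ckpt-7', 'ckpt-3', 'init']
```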

save_checkpoint

Saves the model to a checkpoint.

This method will save a checkpoint containing the current state of the model.

Raises
ValueError If no checkpoint_manager was provided to Controller.__init__.

train

Runs training until the specified global step count has been reached.

This method makes calls to self.trainer.train() until the global step count is equal to steps. It will additionally save checkpoints (if a CheckpointManager was passed to Controller.__init__) and summarize training output (if summary_dir is set).

When async checkpointing is enabled, a sync is triggered at the end of this method to make sure any ongoing async checkpoint saving is finished before returning.

Args
steps The global step count to train up to.
checkpoint_at_completion Whether to save a checkpoint when this method returns (regardless of the checkpointing interval). Defaults to True.
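
Here is a minimal sketch of the train outer loop. It assumes a hypothetical CountingTrainer whose train(num_steps) simply advances a plain-int global step. train_sketch runs inner loops of steps_per_loop steps, or of whatever a steps_per_loop callable returns. It shortens the last loop so that the global step lands exactly on steps. Then it saves one checkpoint if checkpoint_at_completion is True. Periodic checkpointing at the CheckpointManager's interval, summaries, actions, and async checkpoint syncing are deliberately omitted.

```python
from typing import Callable, Dict, List, Union


class CountingTrainer:
    """Hypothetical trainer whose only state is the global step."""

    def __init__(self) -> None:
        self.global_step = 0

    def train(self, num_steps: int) -> Dict[str, int]:
        self.global_step += num_steps
        return {"steps_run": num_steps}


def train_sketch(
    trainer: CountingTrainer,
    steps: int,
    steps_per_loop: Union[int, Callable[[int], int]],
    save_checkpoint: Callable[[int], None],
    checkpoint_at_completion: bool = True,
) -> List[int]:
    """Runs inner loops until trainer.global_step == steps; returns the loop sizes."""
    loop_sizes = []
    while trainer.global_step < steps:
        current = trainer.global_step
        per_loop = steps_per_loop(current) if callable(steps_per_loop) else steps_per_loop
        if per_loop < 1:
            raise ValueError("steps_per_loop must yield a positive number of steps.")
        # Never overshoot the target global step.
        num_steps = min(per_loop, steps - current)
        trainer.train(num_steps)
        loop_sizes.append(num_steps)
    if checkpoint_at_completion:
        save_checkpoint(trainer.global_step)
    return loop_sizes


saved = []
t = CountingTrainer()
print(train_sketch(t, steps=250, steps_per_loop=100, save_checkpoint=saved.append))  # [100, 100, 50]
print(t.global_step, saved)  # 250 [250]

schedule = lambda step: 10 if step < 20 else 50  # hypothetical callable schedule
t2 = CountingTrainer()
print(train_sketch(t2, steps=100, steps_per_loop=schedule, save_checkpoint=saved.append))  # [10, 10, 50, 30]
```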

train_and_evaluate

Runs interleaved training and evaluation.

This method interleaves calls to self.train() and self.evaluate(), training the model until the global step count equals train_steps, and running an evaluation for eval_steps every eval_interval training steps. In addition, this method will run a final evaluation at the end of the training sequence.

When async checkpointing is enabled, a sync is triggered at the end of this method to make sure any ongoing async checkpoint saving is finished before returning.

Args
train_steps The global step count to train up to.
eval_steps The number of steps to run during an evaluation. If -1, this method will evaluate over the entire evaluation dataset.
eval_interval The number of training steps to run between evaluations. If set, training will always stop every eval_interval steps, even if this results in a shorter inner loop than specified by the steps_per_loop setting. If None, evaluation will only be performed after training is complete.
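
To illustrate how eval_interval interacts with steps_per_loop, the hypothetical planner below lists the inner training loops and evaluations that train_and_evaluate would perform, without training anything. It assumes that each training segment covers at most eval_interval steps (capped at train_steps). It also assumes the final evaluation is simply the one that follows the last segment. The real Controller may organize this differently.

```python
from typing import List, Optional, Tuple


def train_and_evaluate_schedule(
    current_step: int,
    train_steps: int,
    steps_per_loop: int,
    eval_interval: Optional[int] = None,
) -> List[Tuple[str, int]]:
    """Lists ('train', num_steps) and ('evaluate', at_step) events in order."""
    if steps_per_loop < 1:
        raise ValueError("steps_per_loop must be a positive integer.")
    if eval_interval is not None and eval_interval < 1:
        raise ValueError("eval_interval must be a positive integer or None.")
    # Without eval_interval, train straight to train_steps and evaluate once.
    interval = eval_interval if eval_interval is not None else train_steps - current_step
    events: List[Tuple[str, int]] = []
    step = current_step
    while step < train_steps:
        segment_end = min(step + interval, train_steps)
        # Inner loops never cross segment_end, so the last one in a segment
        # can be shorter than steps_per_loop.
        while step < segment_end:
            num_steps = min(steps_per_loop, segment_end - step)
            events.append(("train", num_steps))
            step += num_steps
        events.append(("evaluate", step))
    return events


print(train_and_evaluate_schedule(0, 300, steps_per_loop=100, eval_interval=150))
# [('train', 100), ('train', 50), ('evaluate', 150), ('train', 100), ('train', 50), ('evaluate', 300)]
print(train_and_evaluate_schedule(0, 300, steps_per_loop=100))
# [('train', 100), ('train', 100), ('train', 100), ('evaluate', 300)]
```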

Returns
The evaluation results as a dictionary mapping names to NumPy values.