tf.distribute.experimental.PreemptionCheckpointHandler

Preemption and error handler for synchronous training.

A PreemptionCheckpointHandler coordinates all workers to save a checkpoint upon receiving a preemption signal. It also helps disseminate application error messages accurately among the cluster. When a PreemptionCheckpointHandler object is created, it restores values from the latest checkpoint file if any exists.

Right after initialization, the object starts watching for a termination signal sent to any member of the cluster. If a signal has been received, then the next time a worker executes PreemptionCheckpointHandler.run, the PreemptionCheckpointHandler aligns all workers to save a checkpoint. Then, if an exit_fn is configured via tf.distribute.experimental.TerminationConfig, it is invoked. Otherwise, the process simply exits, and the platform is expected to restart it later.
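
The sequence described above (signal observed, checkpoint saved on the next run call, then exit_fn or process exit) can be illustrated with a toy, single-process model. This is a sketch only, not TensorFlow's implementation: the class ToyPreemptionHandler, its notify_preemption method, and the save_fn argument are hypothetical names introduced for illustration, and the cross-worker alignment is not modeled.

```python
class ToyPreemptionHandler:
    """Toy stand-in for the control flow only; NOT TensorFlow's implementation."""

    def __init__(self, save_fn, exit_fn=None):
        self._save_fn = save_fn    # hypothetical: saves a checkpoint
        self._exit_fn = exit_fn    # optional user-supplied exit behavior
        self._signal_received = False

    def notify_preemption(self):
        # Hypothetical hook. In the real handler, a watcher started at
        # initialization detects the platform's termination signal.
        self._signal_received = True

    def run(self, train_fn, *args, **kwargs):
        # Preemption is checked before the training function is called.
        if self._signal_received:
            self._save_fn()
            if self._exit_fn is not None:
                self._exit_fn()
            else:
                raise SystemExit(0)  # the platform is expected to restart the job
        # Without a pending signal this is a plain pass-through.
        return train_fn(*args, **kwargs)


events = []
handler = ToyPreemptionHandler(
    save_fn=lambda: events.append("save"),
    exit_fn=lambda: events.append("exit"))

first = handler.run(lambda x: x + 1, 1)   # no signal yet: train_fn runs as usual
handler.notify_preemption()
second = handler.run(lambda x: x + 1, 1)  # signal pending: save, then exit_fn
```

Here exit_fn only records an event so the flow can be observed; a real exit_fn would normally terminate the process, so the training function would not run after it.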

For users of tf.distribute.MultiWorkerMirroredStrategy, the core API is PreemptionCheckpointHandler.run:

strategy = tf.distribute.MultiWorkerMirroredStrategy()

trained_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='step_in_epoch')

with strategy.scope():
  dataset, model, optimizer = ...

  checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                   model=model,
                                   trained_epoch=trained_epoch,
                                   step_in_epoch=step_in_epoch)

  preemption_checkpoint_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_dir)

while trained_epoch.numpy() < NUM_EPOCH:

  while step_in_epoch.numpy() < STEPS_PER_EPOCH:

    # distributed_train_function contains a call to strategy.run.
    loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))
    # For users of MultiWorkerMirroredStrategy, usually
    # STEPS_PER_TRAIN_FUNCTION = 1.
    step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION)
    ...

  trained_epoch.assign_add(1)
  step_in_epoch.assign(0)

For users of tf.distribute.TPUStrategy, the core APIs are PreemptionCheckpointHandler.run and PreemptionCheckpointHandler.watch_preemption_scope:


strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)

# Rest of TPU init omitted, see documentation for TPUStrategy.

with preemption_checkpoint_handler.watch_preemption_scope():
  while trained_epoch.numpy() < NUM_EPOCH:

    while step_in_epoch.numpy() < STEPS_PER_EPOCH:

      # distributed_train_function contains a call to strategy.run.
      loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))

      # For users of TPUStrategy, usually STEPS_PER_TRAIN_FUNCTION >> 1 since
      # clustering multiple steps within a tf.function amortizes the overhead
      # of launching a multi-device function on TPU Pod.
      step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION)
      ...

    trained_epoch.assign_add(1)
    step_in_epoch.assign(0)

Not all interruptions come with advance notice, so the PreemptionCheckpointHandler cannot handle every one of them; interruptions caused by hardware failure are one example. Users who save their own checkpoints outside the PreemptionCheckpointHandler to cover such cases, and who use a tf.train.CheckpointManager to do so, should pass it as the checkpoint_or_checkpoint_manager argument to the PreemptionCheckpointHandler. For users who do not have a tf.train.CheckpointManager and work directly with tf.train.Checkpoint, we advise saving the checkpoints in the directory that is passed as the checkpoint_dir argument. Either way, when the program starts, the PreemptionCheckpointHandler can restore the latest checkpoint from that directory, whether it was saved by the user or by the PreemptionCheckpointHandler before a preemption.

A note on the platform:

PreemptionCheckpointHandler can only handle terminations that come with advance notice. For now, the API recognizes the termination signal for CPU, GPU, and TPU on Google Borg, and for CPU and GPU on the Google Cloud Platform. In these cases, PreemptionCheckpointHandler automatically adopts the correct preemption/maintenance notification detection mechanism. Users of other platforms can configure the detection and monitoring behavior through tf.distribute.experimental.TerminationConfig. The exit behavior and the grace period length can also be customized there.
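
For a platform other than Borg or GCP, a configuration might look like the sketch below. It assumes a purely hypothetical setup in which a platform agent creates a notice file shortly before terminating the job; the path NOTICE_FILE, the exit code RESTART_CODE, the 60-second grace period, and the helpers make_file_watcher and build_termination_config are all illustrative assumptions, not values or functions defined by TensorFlow. The watcher returns immediately instead of blocking, since it is meant to be polled.

```python
import os
import sys

# Hypothetical values; replace them with whatever your platform provides.
NOTICE_FILE = "/tmp/preemption_notice"  # assumed to be created by a platform agent
RESTART_CODE = 42                       # assumed exit code that triggers a restart
GRACE_PERIOD_SECONDS = 60               # assumed time between notice and termination


def make_file_watcher(path):
    """Builds a non-blocking watcher that reports whether `path` exists."""
    def termination_watcher_fn():
        # Return True once a preemption notice is visible, False otherwise.
        return os.path.exists(path)
    return termination_watcher_fn


def exit_fn():
    # Called after the checkpoint has been saved.
    sys.exit(RESTART_CODE)


def build_termination_config():
    # Imported here so the callbacks above can be exercised without TensorFlow.
    import tensorflow as tf
    return tf.distribute.experimental.TerminationConfig(
        termination_watcher_fn=make_file_watcher(NOTICE_FILE),
        exit_fn=exit_fn,
        grace_period=GRACE_PERIOD_SECONDS)
```

The object returned by build_termination_config() would then be passed as the termination_config argument when constructing the PreemptionCheckpointHandler.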

Args
cluster_resolver A tf.distribute.cluster_resolver.ClusterResolver object. You may also obtain it through the cluster_resolver attribute of the distribution strategy in use.
checkpoint_or_checkpoint_manager A tf.train.CheckpointManager or a tf.train.Checkpoint. If you also use a tf.train.CheckpointManager to manage checkpoints outside the PreemptionCheckpointHandler for backup purposes, pass it as the checkpoint_or_checkpoint_manager argument. Otherwise, pass a tf.train.Checkpoint and the PreemptionCheckpointHandler will create a tf.train.CheckpointManager to manage it in the checkpoint_dir.
checkpoint_dir A directory where the PreemptionCheckpointHandler saves and restores checkpoints. When a PreemptionCheckpointHandler is created, the latest checkpoint in the checkpoint_dir is restored. (This is not needed if a tf.train.CheckpointManager, rather than a tf.train.Checkpoint, is passed as the checkpoint_or_checkpoint_manager argument.)
termination_config Optional. A tf.distribute.experimental.TerminationConfig object to configure the handler for a platform other than Google Borg or GCP.

Methods

run

Runs a training function with error and preemption handling.

This function handles the preemption signal from any peer in the cluster by saving the training progress and exiting gracefully. It also broadcasts any program error encountered during the execution of distributed_train_function to all workers so that they can raise the same error.

The distributed_train_function argument should be a distributed train function (i.e., containing a call to tf.distribute.Strategy.run). For tf.distribute.MultiWorkerMirroredStrategy users, we recommend passing in a single-step distributed_train_function to PreemptionCheckpointHandler.run so that the checkpoint can be saved in time in case a preemption signal or maintenance notice is sent.

Besides the preemption and error handling part, PreemptionCheckpointHandler.run(distributed_train_function, *args, **kwargs) has the same effect and output as distributed_train_function(*args, **kwargs). distributed_train_function can return either some or no result. The following is a shortened example:


@tf.function
def distributed_train_step(iterator):
  # A distributed single-step training function.

  def step_fn(inputs):
    # A per-replica single-step training function.
    x, y = inputs
    ...
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                   EPOCHS_TO_RUN):
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                    STEPS_PER_EPOCH):
    total_loss += preemption_handler.run(distributed_train_step)
    num_batches += 1

  train_loss = total_loss / num_batches
  # epoch is a plain Python int from range(), so it has no .numpy() method.
  print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))

  train_accuracy.reset_states()
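
The two range() start values in the loop above are the quotient and remainder of total_run_calls divided by STEPS_PER_EPOCH. The sketch below spells out that arithmetic in plain Python; the helper name resume_position is hypothetical, and it assumes that each run call executes exactly one training step.

```python
def resume_position(total_run_calls, steps_per_epoch):
    """Returns (epoch, step_in_epoch) from which training should resume.

    Assumes one training step per completed run call.
    """
    # divmod(a, b) returns (a // b, a % b).
    return divmod(total_run_calls, steps_per_epoch)


# 7 completed run calls at 3 steps per epoch: epochs 0 and 1 are finished
# (6 calls) and one step of epoch 2 has already run.
start_epoch, start_step = resume_position(total_run_calls=7, steps_per_epoch=3)
```

With these values, training resumes at epoch 2, step 1. In a fresh job, where total_run_calls is 0, both start values are 0.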

Args
distributed_train_function A (single-step) distributed training function.
*args args for distributed_train_function.
**kwargs kwargs for distributed_train_function.

Raises
Program error encountered by any member in the cluster while executing the distributed_train_function, or any error from the program error propagation process.

Returns
Result of running the distributed_train_function.

save_checkpoint_if_preempted

Saves a checkpoint if a preemption signal has been made available.

This is an alternative API to PreemptionCheckpointHandler.run and PreemptionCheckpointHandler.watch_preemption_scope. This method works for both tf.distribute.MultiWorkerMirroredStrategy and tf.distribute.TPUStrategy. However, for TPUStrategy, this method adds a synchronization point between workers and the coordinator and thus may have performance implications. If this is a concern, use the combination of PreemptionCheckpointHandler.watch_preemption_scope and PreemptionCheckpointHandler.run instead.

strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
# initialization omitted

with strategy.scope():
  # Save in the checkpoint.
  trained_step = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='trained_step', aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

  checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory, max_to_keep=1)
  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint_manager)

while trained_step.numpy() < NUM_STEPS:
  # Train STEPS_IN_FUNCTION steps at once.
  train_multi_step_function()
  trained_step.assign_add(STEPS_IN_FUNCTION)
  preemption_handler.save_checkpoint_if_preempted()

Args
*args args passed to tf.train.CheckpointManager.save() when saving the checkpoint.
**kwargs kwargs passed to tf.train.CheckpointManager.save() when saving the checkpoint.

watch_preemption_scope

Syncs errors and, if needed, saves a checkpoint, for usage with TPUStrategy.

Example usage:

with preemption_checkpoint_handler.watch_preemption_scope():
  while trained_step.numpy() < NUM_STEPS:

    # distributed_train_function contains a call to strategy.run.
    loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))
    trained_step.assign_add(STEPS_PER_TRAIN_FUNCTION)

In this workflow, PreemptionCheckpointHandler.run flags that a preemption signal has been received, and watch_preemption_scope handles the signal by saving a checkpoint and then either exiting so the job can be restarted or executing the exit_fn passed by the user in tf.distribute.experimental.TerminationConfig. If no preemption signal is received while the ops and functions inside the scope execute, watch_preemption_scope ensures on exit that all async op and function executions have completed, and it raises an exception if async execution results in an error state.

Yields
None