Preemption and error handler for synchronous training.
tf.distribute.experimental.PreemptionCheckpointHandler(
    cluster_resolver,
    checkpoint_or_checkpoint_manager,
    checkpoint_dir=None,
    termination_config=None
)
A PreemptionCheckpointHandler coordinates all workers to save a checkpoint upon receiving a preemption signal. It also helps propagate application error messages accurately across the cluster. When a PreemptionCheckpointHandler object is created, it restores values from the latest checkpoint file, if one exists.
Right after initialization, a thread starts watching for a termination signal sent to any member of the cluster. If a signal is received, then the next time the worker enters a PreemptionCheckpointHandler.run call, the PreemptionCheckpointHandler aligns the worker steps, saves a checkpoint, and possibly exits, depending on the exit_fn in tf.distribute.experimental.TerminationConfig.
Example usage:
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  dataset, model, optimizer = ...
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(
      cluster_resolver, checkpoint, checkpoint_directory)
# preemption_handler.total_run_calls will be restored to its saved value if
# training is restored after interruption.
loss = 0.0
for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH, num_epochs):
  for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH, STEPS_PER_EPOCH):
    # distributed_train_step is a single-step training function wrapped by
    # tf.distribute.Strategy.run. Its arguments are passed positionally
    # because run forwards *args and **kwargs to the function.
    loss += preemption_handler.run(distributed_train_step, next(dataset))
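The resume arithmetic in the loop bounds above can be checked without a cluster. The following is a minimal pure-Python sketch, not TensorFlow code: `ToyCounter` is a hypothetical stand-in for the restored total_run_calls value, and `train` with its `stop_after` argument is an illustrative helper that mimics a preemption. It assumes a fixed STEPS_PER_EPOCH and exactly one run call per training step.

```python
STEPS_PER_EPOCH = 4
NUM_EPOCHS = 3


class ToyCounter:
    """Hypothetical stand-in for the handler's restored run-call counter."""

    def __init__(self, total_run_calls=0):
        self.total_run_calls = total_run_calls

    def run(self, visited, epoch, step):
        # Record the (epoch, step) pair, then count the call, like one train step.
        visited.append((epoch, step))
        self.total_run_calls += 1


def train(counter, stop_after=None):
    """Runs the nested loops; optionally stops early to mimic a preemption."""
    visited = []
    for epoch in range(counter.total_run_calls // STEPS_PER_EPOCH, NUM_EPOCHS):
        for step in range(counter.total_run_calls % STEPS_PER_EPOCH, STEPS_PER_EPOCH):
            if stop_after is not None and len(visited) == stop_after:
                return visited
            counter.run(visited, epoch, step)
    return visited


# The first run is interrupted after 6 calls, partway through epoch 1.
first = train(ToyCounter(), stop_after=6)
# A restarted run begins from the saved counter value instead of zero.
second = train(ToyCounter(total_run_calls=len(first)))
# Every (epoch, step) pair should be visited exactly once across both runs.
expected = [(e, s) for e in range(NUM_EPOCHS) for s in range(STEPS_PER_EPOCH)]
```

The integer division picks the epoch to resume in and the remainder picks the step within it. In later epochs the counter is a multiple of STEPS_PER_EPOCH when the inner loop starts, so the remainder is zero and the epoch begins at step 0.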
Not all interruptions come with advance notice that lets the PreemptionCheckpointHandler handle them; those caused by hardware failure, for example, do not. Users who save their own checkpoints outside the PreemptionCheckpointHandler to cover such cases, and who use a tf.train.CheckpointManager, should pass it as the checkpoint_or_checkpoint_manager argument to the PreemptionCheckpointHandler. Users who do not have a tf.train.CheckpointManager and work directly with tf.train.Checkpoint are advised to save their checkpoints in the directory passed as the checkpoint_dir argument. That way, when the program starts, the PreemptionCheckpointHandler can restore the latest checkpoint from the directory, whether it was saved by the user or by the PreemptionCheckpointHandler before preemption happened.
If a user cannot infer the start epoch and start step from PreemptionCheckpointHandler.total_run_calls (e.g., if STEPS_PER_EPOCH is not known in advance or may vary from epoch to epoch), we recommend tracking the epoch and step numbers themselves and saving them in the passed-in checkpoint:
strategy = tf.distribute.MultiWorkerMirroredStrategy()
trained_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='step_in_epoch')
with strategy.scope():
  dataset, model, optimizer = ...
  checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                   model=model,
                                   trained_epoch=trained_epoch,
                                   step_in_epoch=step_in_epoch)
  preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_dir)
while trained_epoch.numpy() < NUM_EPOCH:
  while step_in_epoch.numpy() < STEPS_PER_EPOCH:
    # Arguments are passed positionally; run forwards them to train_step.
    loss += preemption_handler.run(train_step, next(iterator))
    step_in_epoch.assign_add(1)
    ...
  # Advance the checkpointed epoch counter and reset the step counter.
  trained_epoch.assign_add(1)
  step_in_epoch.assign(0)
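To see why explicit counters help when the epoch length varies, here is a minimal pure-Python sketch. It is not TensorFlow code: a plain dict plays the role of the checkpointed trained_epoch and step_in_epoch variables, and `steps_per_epoch`, `train`, and the `stop_after` interruption are hypothetical names introduced for illustration.

```python
import copy

# Hypothetical epoch lengths that differ from epoch to epoch.
steps_per_epoch = [3, 5, 2]


def train(state, stop_after=None):
    """Resumes from state and returns the (epoch, step) pairs it executed."""
    executed = []
    while state["trained_epoch"] < len(steps_per_epoch):
        while state["step_in_epoch"] < steps_per_epoch[state["trained_epoch"]]:
            if stop_after is not None and len(executed) == stop_after:
                return executed  # mimic a preemption between steps
            executed.append((state["trained_epoch"], state["step_in_epoch"]))
            state["step_in_epoch"] += 1
        state["trained_epoch"] += 1
        state["step_in_epoch"] = 0
    return executed


state = {"trained_epoch": 0, "step_in_epoch": 0}
first = train(state, stop_after=4)      # interrupted inside epoch 1
saved = copy.deepcopy(state)            # stands in for the saved checkpoint
second = train(copy.deepcopy(saved))    # restart from the restored counters
```

Because the counters are stored alongside the model state, the restart position does not depend on dividing a global call count by a fixed epoch length.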
A note on the platform:
PreemptionCheckpointHandler can only handle terminations that come with advance notice. For now, the API recognizes Google Borg and Google Cloud Platform, where it automatically adopts the correct preemption/maintenance notification detection mechanism. Users of other platforms can configure detection through a tf.distribute.experimental.TerminationConfig. The exit behavior and the grace period length can also be customized there.
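On a platform the handler does not recognize, some custom way of detecting a termination notice is needed. As a sketch only, the function below polls for a sentinel file, which is a hypothetical notification mechanism chosen for illustration; `make_file_watcher` and `notice_path` are not part of TensorFlow. The assumption is that a non-blocking function of this shape (it returns True once a notice is present) could be supplied through tf.distribute.experimental.TerminationConfig; consult that class's documentation for the exact argument names before relying on it.

```python
import os
import tempfile


def make_file_watcher(notice_path):
    """Returns a non-blocking function that reports whether a notice exists."""

    def termination_notice_received():
        # True once the (hypothetical) platform agent has created the file.
        return os.path.exists(notice_path)

    return termination_notice_received


# Demonstration with a temporary directory standing in for the platform.
notice_path = os.path.join(tempfile.mkdtemp(), "termination_notice")
watcher = make_file_watcher(notice_path)
before = watcher()               # no notice has been posted yet
open(notice_path, "w").close()   # simulate the platform posting a notice
after = watcher()
```

The watcher only checks state and returns immediately, so it can be polled repeatedly from a background thread without stalling training.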
Args | |
---|---|
cluster_resolver | A tf.distribute.cluster_resolver.ClusterResolver object. You may also obtain it through the cluster_resolver attribute of the distribution strategy in use. |
checkpoint_or_checkpoint_manager | A tf.train.CheckpointManager or a tf.train.Checkpoint. If you are also using a tf.train.CheckpointManager to manage backup checkpoints outside the PreemptionCheckpointHandler, pass it as the checkpoint_or_checkpoint_manager argument. Otherwise, pass a tf.train.Checkpoint and the PreemptionCheckpointHandler will create a tf.train.CheckpointManager to manage it in the checkpoint_dir. |
checkpoint_dir | A directory where the PreemptionCheckpointHandler saves and restores checkpoints. When a PreemptionCheckpointHandler is created, the latest checkpoint in the checkpoint_dir is restored. (This is not needed if a tf.train.CheckpointManager, rather than a tf.train.Checkpoint, is passed as the checkpoint_or_checkpoint_manager argument.) |
termination_config | Optional. A tf.distribute.experimental.TerminationConfig object to configure the handler for a platform other than Google Borg or GCP. |
Attributes | |
---|---|
total_run_calls | Returns the number of times PreemptionCheckpointHandler.run is called. This value tracks the number of all calls to |
Methods
run
run(
    distributed_train_function, *args, **kwargs
)
Runs a training function with error and preemption handling.
This function handles the preemption signal from any peer in the cluster by saving the training progress and exiting gracefully. It also broadcasts any program error encountered during the execution of distributed_train_function to all workers so that they can raise the same error.
The distributed_train_function argument should be a distributed train function (i.e., one containing a call to tf.distribute.Strategy.run). For tf.distribute.MultiWorkerMirroredStrategy users, we recommend passing a single-step distributed_train_function to PreemptionCheckpointHandler.run so that the checkpoint can be saved in time if a preemption signal or maintenance notice is sent.
Apart from the preemption and error handling, PreemptionCheckpointHandler.run(distributed_train_function, *args, **kwargs) has the same effect and output as distributed_train_function(*args, **kwargs). distributed_train_function may or may not return a result. The following is a shortened example:
@tf.function
def distributed_train_step(iterator):
  # A distributed single-step training function.
  def step_fn(inputs):
    # A per-replica single-step training function.
    x, y = inputs
    ...
    return loss
  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                   EPOCHS_TO_RUN):
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0
  for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                    STEPS_PER_EPOCH):
    # Pass the iterator positionally; run forwards *args to the function.
    total_loss += preemption_handler.run(distributed_train_step, iterator)
    num_batches += 1
  train_loss = total_loss / num_batches
  # epoch is a plain Python int here, so it is printed directly.
  print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))
  # train_accuracy is a metric defined outside this shortened example.
  train_accuracy.reset_states()
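The contract described above (the same result as calling the function directly, plus a checkpoint opportunity each time run is entered) can be illustrated with a toy wrapper. This is a pure-Python sketch and not how TensorFlow implements it: `ToyHandler`, `signal_received`, `exited`, and `save_fn` are hypothetical names, the handler simply returns None after saving, and cross-worker coordination and error propagation are omitted.

```python
class ToyHandler:
    """Toy single-process illustration of the run() contract."""

    def __init__(self, save_fn):
        self._save_fn = save_fn          # hypothetical checkpoint callback
        self.signal_received = False     # a watcher thread would set this
        self.total_run_calls = 0
        self.exited = False

    def run(self, train_function, *args, **kwargs):
        # A pending signal is acted on when the next run call is entered,
        # so shorter steps mean the checkpoint is saved sooner.
        if self.signal_received:
            self._save_fn(self.total_run_calls)
            self.exited = True
            return None
        self.total_run_calls += 1
        # Otherwise behave exactly like calling the function directly.
        return train_function(*args, **kwargs)


def step(x, scale=1.0):
    return x * scale


saved_at = []
handler = ToyHandler(save_fn=saved_at.append)
result = handler.run(step, 2.0, scale=3.0)   # same as step(2.0, scale=3.0)
handler.signal_received = True               # simulate a preemption notice
handler.run(step, 2.0)                       # saves instead of training
```

In this toy model a notice that arrives mid-step waits until the step returns, which is the reason a single-step training function keeps the delay before saving short.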
Args | |
---|---|
distributed_train_function | A (single-step) distributed training function. |
*args | args for distributed_train_function. |
**kwargs | kwargs for distributed_train_function. |
Raises | |
---|---|
Program error encountered by any member in the cluster while executing the distributed_train_function, or any error from the program error propagation process. |
Returns | |
---|---|
Result of running the distributed_train_function. |