![]() |
![]() |
![]() |
![]() |
Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program/machine failure during training.
This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator
in TensorFlow 1 by specifying metric saving with tf.estimator.RunConfig
. Then, you will learn how to implement fault tolerance for training in Tensorflow 2 in two ways:
- If you use the Keras
Model.fit
API, you can pass thetf.keras.callbacks.BackupAndRestore
callback to it. - If you use a custom training loop (with
tf.GradientTape
), you can arbitrarily save checkpoints using thetf.train.Checkpoint
andtf.train.CheckpointManager
APIs.
Both of these methods will back up and restore the training states in checkpoint files.
Setup
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Save checkpoints with tf.estimator.RunConfig
In TensorFlow 1, you can configure a tf.estimator
to save checkpoints every step by configuring tf.estimator.RunConfig
.
In this example, start by writing a hook that artificially throws an error during the fifth checkpoint:
class InterruptHook(tf1.train.SessionRunHook):
# A hook for artificially interrupting training.
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
def after_run(self, run_context, run_values):
if self._step == 5:
raise RuntimeError('Interruption')
Next, configure tf.estimator.Estimator
to save every checkpoint and use the MNIST dataset:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.
Begin training the model. An artificial exception will be raised by the hook you defined earlier.
try:
classifier.train(input_fn=train_input_fn,
hooks=[InterruptHook()],
max_steps=10)
except Exception as e:
print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 118.92192, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
Rebuild the tf.estimator.Estimator
using the last saved checkpoint and continue training:
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:loss = 105.44863, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 100.47882. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>
TensorFlow 2: Back up and restore with a callback and Model.fit
In TensorFlow 2, if you use the Keras Model.fit
API for training, you can provide the tf.keras.callbacks.BackupAndRestore
callback to add the fault tolerance functionality.
To help demonstrate this, let's first start by defining a callback class that artificially throws an error during the fifth checkpoint:
class InterruptingCallback(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def on_epoch_end(self, epoch, log=None):
if epoch == 4:
raise RuntimeError('Interruption')
Then, define and instantiate a simple Keras model, define the loss function, call Model.compile
, and set up a tf.keras.callbacks.BackupAndRestore
callback that will save the checkpoints in a temporary directory:
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir
)
Now, start training the model with Model.fit
. During training, checkpoints will be saved thanks to the backup_restore_callback
defined above, while the InterruptingCallback
will raise an artificial exception to simulate a failure.
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
Epoch 1/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615 Epoch 2/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718 Epoch 3/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797 Epoch 4/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption
Next, instantiate the Keras model, call Model.compile
, and continue training the model with Model.fit
from a previously saved checkpoint:
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
Epoch 6/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791 Epoch 7/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827 Epoch 8/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819 Epoch 9/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812 Epoch 10/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813 <keras.callbacks.History at 0x7fcfe0647350>
TensorFlow 2: Write manual checkpoints with a custom training loop
If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint
and tf.train.CheckpointManager
APIs.
This example demonstrates how to:
- Use a
tf.train.Checkpoint
object to manually create a checkpoint, where the trackable objects you want to save are set as attributes. - Use a
tf.train.CheckpointManager
to manage multiple checkpoints.
Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint
that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager
for logging and keeping several checkpoints in a temporary directory.
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
Now, implement a custom training loop where after the first epoch every time a new epoch starts the last checkpoint is loaded:
for epoch in range(epochs):
if epoch > 0:
tf.train.load_checkpoint(save_path)
print(f"\nStart of epoch {epoch}")
for step in range(steps_per_epoch):
with tf.GradientTape() as tape:
logits = model(x_train, training=True)
loss_value = loss_fn(y_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
save_path = checkpoint_manager.save()
print(f"Checkpoint saved to {save_path}")
print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1 Training loss at step 0: 2.3636362552642822 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2 Training loss at step 1: 2.3626415729522705 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3 Training loss at step 2: 2.3613197803497314 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4 Training loss at step 3: 2.360600233078003 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5 Training loss at step 4: 2.3589422702789307 Start of epoch 1 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6 Training loss at step 0: 2.3563339710235596 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7 Training loss at step 1: 2.3568854331970215 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8 Training loss at step 2: 2.354109287261963 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9 Training loss at step 3: 2.3532731533050537 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10 Training loss at step 4: 2.351112127304077 Start of epoch 2 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11 Training loss at step 0: 2.348905563354492 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12 Training loss at step 1: 2.349478006362915 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13 Training loss at step 2: 2.3487260341644287 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14 Training loss at step 3: 2.345991611480713 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15 Training loss at step 4: 2.3451104164123535 Start of epoch 3 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16 Training loss at step 0: 2.3441312313079834 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17 Training loss at step 1: 2.341529130935669 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18 Training loss at step 2: 2.342329263687134 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19 Training loss at step 3: 2.340449571609497 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20 Training loss at step 4: 2.3367927074432373 Start of epoch 4 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21 Training loss at step 0: 2.3366076946258545 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22 Training loss at step 1: 2.335028886795044 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23 Training loss at step 2: 2.3338520526885986 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24 Training loss at step 3: 2.3345272541046143 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25 Training loss at step 4: 2.332385301589966
Next steps
To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:
- The
tf.keras.callbacks.BackupAndRestore
callback API docs. - The
tf.train.Checkpoint
andtf.train.CheckpointManager
API docs. - The Training checkpoints guide, including the Writing checkpoints section.
You may also find the following material related to distributed training useful:
- The Fault tolerance section in the Multi-worker training with Keras tutorial.
- The Handing task failure section in the Parameter server training tutorial.