Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program/machine failure during training.
This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by specifying checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:
- If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.BackupAndRestore callback to it.
- If you use a custom training loop (with tf.GradientTape), you can arbitrarily save checkpoints using the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
Both of these methods will back up and restore the training states in checkpoint files.
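The idea common to both approaches can be illustrated without TensorFlow. The following is a minimal, framework-free sketch and not how TensorFlow stores checkpoints: the train function, the JSON file layout, and the fail_at parameter are all hypothetical names chosen for illustration. It saves the training state after every step and, when restarted, resumes from the last saved state.

```python
import json
import os
import tempfile


def train(ckpt_path, total_steps, fail_at=None):
  """Toy training loop that checkpoints after every step (illustrative only)."""
  # Resume from the checkpoint file if one exists; otherwise start fresh.
  if os.path.exists(ckpt_path):
    with open(ckpt_path) as f:
      state = json.load(f)
  else:
    state = {"step": 0, "weight": 0.0}

  while state["step"] < total_steps:
    if fail_at is not None and state["step"] == fail_at:
      raise RuntimeError("Interruption")  # Simulated failure.
    state["weight"] += 0.5  # Stand-in for one optimizer update.
    state["step"] += 1
    # Write to a temporary file first, then replace the old checkpoint,
    # so a crash in the middle of a write cannot corrupt the previous one.
    tmp_path = ckpt_path + ".tmp"
    with open(tmp_path, "w") as f:
      json.dump(state, f)
    os.replace(tmp_path, ckpt_path)
  return state


ckpt_path = os.path.join(tempfile.mkdtemp(), "state.json")
try:
  train(ckpt_path, total_steps=10, fail_at=6)
except RuntimeError as e:
  print(f"{type(e).__name__}: {e}")

# The second call picks up at step 6 instead of starting over.
final_state = train(ckpt_path, total_steps=10)
print(final_state)
```

The second call does only the four remaining steps, because the state written before the simulated failure already records six completed steps.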
Setup
Install tf-nightly, because the save_freq argument of tf.keras.callbacks.BackupAndRestore, which sets how often (in steps) checkpoints are saved, was introduced in TensorFlow 2.10:
pip install tf-nightly
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 0s 0us/step
TensorFlow 1: Save checkpoints with tf.estimator.RunConfig
In TensorFlow 1, you can configure a tf.estimator.Estimator to save checkpoints at every step through tf.estimator.RunConfig.
In this example, start by writing a hook that artificially throws an error after the sixth training step, when its zero-based step counter reaches 5:
class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')
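To see exactly when this hook fires, the sketch below replays its counting logic in plain Python. FakeInterruptHook and the driver loop are hypothetical stand-ins that only mimic the order in which begin, before_run, and after_run are called around each training step; they are not TensorFlow APIs.

```python
class FakeInterruptHook:
  """Same counting logic as InterruptHook, without the TensorFlow base class."""

  def begin(self):
    self._step = -1

  def before_run(self):
    self._step += 1

  def after_run(self):
    if self._step == 5:
      raise RuntimeError("Interruption")


hook = FakeInterruptHook()
hook.begin()
completed_steps = 0
try:
  for _ in range(10):
    hook.before_run()
    completed_steps += 1  # Stand-in for one training step.
    hook.after_run()
except RuntimeError as e:
  print(f"{type(e).__name__}: {e} after {completed_steps} completed steps")
```

Because the counter starts at -1 and is incremented before each step, the value 5 is reached on the sixth step, so six steps complete before the exception is raised.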
Next, configure tf.estimator.Estimator to save a checkpoint at every step and use the MNIST dataset:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/314197976.py:1: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/314197976.py:2: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/314197976.py:7: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. 
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp4hyoyuxs', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/314197976.py:17: numpy_input_fn (from tensorflow_estimator.python.estimator.inputs.numpy_io) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead.
Begin training the model. An artificial exception will be raised by the hook you defined earlier.
try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_23249/2587623597.py:3: object.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version. 
Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1414: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1417: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1454: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. 
Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. 2023-02-23 02:26:35.865225: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 118.972084, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1067: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version. 
Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:
classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp4hyoyuxs', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp4hyoyuxs/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1176: get_checkpoint_mtimes (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... 
INFO:tensorflow:loss = 104.51564, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp4hyoyuxs/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 90.682465. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f907476af70>
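Because max_steps limits the global step rather than the number of steps run in the current call, the resumed run only performs the steps that are still missing. The arithmetic can be sketched as follows; the remaining_steps helper is hypothetical and not part of the Estimator API.

```python
def remaining_steps(max_steps, restored_global_step):
  """Steps still to run when training resumes at `restored_global_step`."""
  return max(0, max_steps - restored_global_step)


# The log above restores from model.ckpt-6 and trains up to max_steps=10.
print(remaining_steps(max_steps=10, restored_global_step=6))   # → 4
# Once the limit has been reached, another call runs no further steps.
print(remaining_steps(max_steps=10, restored_global_step=10))  # → 0
```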
TensorFlow 2: Back up and restore with a callback and Model.fit
In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.BackupAndRestore callback to add the fault tolerance functionality.
To help demonstrate this, start by defining a Keras Callback class that artificially throws an error at the end of the fourth epoch (epoch index 3):
class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    # `epoch` is zero-based, so 3 means the end of the fourth epoch.
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')
Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that will save the checkpoints in a temporary directory at epoch boundaries:
def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir)
Start training the model with Model.fit. During training, checkpoints will be saved thanks to the tf.keras.callbacks.BackupAndRestore callback instantiated above, while the InterruptAtEpoch callback will raise an artificial exception to simulate a failure after the fourth epoch.
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10 100/100 [==============================] - 2s 11ms/step - loss: 0.4589 - accuracy: 0.8716 - val_loss: 0.2179 - val_accuracy: 0.9394 Epoch 2/10 100/100 [==============================] - 1s 8ms/step - loss: 0.1984 - accuracy: 0.9430 - val_loss: 0.1544 - val_accuracy: 0.9570 Epoch 3/10 100/100 [==============================] - 1s 8ms/step - loss: 0.1460 - accuracy: 0.9575 - val_loss: 0.1221 - val_accuracy: 0.9643 Epoch 4/10 90/100 [==========================>...] - ETA: 0s - loss: 0.1144 - accuracy: 0.9677RuntimeError:Interruption
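The order of the callbacks list matters in this example. Keras invokes callbacks in the order they are listed, so the backup for an epoch is written before InterruptAtEpoch raises. The toy model below illustrates that ordering only; ToyBackup, ToyInterrupt, and run_epochs are hypothetical stand-ins and do not reproduce Keras internals.

```python
class ToyBackup:
  """Remembers the last epoch for which a backup was written."""

  def __init__(self):
    self.last_backed_up_epoch = None

  def on_epoch_end(self, epoch):
    self.last_backed_up_epoch = epoch


class ToyInterrupt:
  """Raises at the end of a chosen zero-based epoch."""

  def __init__(self, interrupting_epoch=3):
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch):
    if epoch == self.interrupting_epoch:
      raise RuntimeError("Interruption")


def run_epochs(callbacks, epochs=10):
  """Calls every callback at the end of each epoch, in list order."""
  try:
    for epoch in range(epochs):
      for callback in callbacks:
        callback.on_epoch_end(epoch)
  except RuntimeError:
    pass  # Simulated failure; training stops here.


backup_first = ToyBackup()
run_epochs([backup_first, ToyInterrupt()])
print(backup_first.last_backed_up_epoch)  # → 3

backup_last = ToyBackup()
run_epochs([ToyInterrupt(), backup_last])
print(backup_last.last_backed_up_epoch)  # → 2
```

With the backup listed first, zero-based epoch 3 is already backed up when the failure occurs, which is consistent with the restarted run in this guide continuing at Epoch 5/10.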
Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint:
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 5/10 100/100 [==============================] - 2s 21ms/step - loss: 0.0928 - accuracy: 0.9735 - val_loss: 0.0883 - val_accuracy: 0.9734 Epoch 6/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0785 - accuracy: 0.9778 - val_loss: 0.0806 - val_accuracy: 0.9763 Epoch 7/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0672 - accuracy: 0.9806 - val_loss: 0.0723 - val_accuracy: 0.9780 Epoch 8/10 100/100 [==============================] - 1s 5ms/step - loss: 0.0572 - accuracy: 0.9832 - val_loss: 0.0687 - val_accuracy: 0.9792 Epoch 9/10 100/100 [==============================] - 1s 6ms/step - loss: 0.0503 - accuracy: 0.9855 - val_loss: 0.0666 - val_accuracy: 0.9792 Epoch 10/10 100/100 [==============================] - 1s 6ms/step - loss: 0.0441 - accuracy: 0.9872 - val_loss: 0.0642 - val_accuracy: 0.9796 <keras.callbacks.History at 0x7f90507d0c70>
Define another Callback class that artificially throws an error during the 140th step:
class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    # Count steps across epochs, not just within the current epoch.
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')
To make sure the checkpoints are saved every 30 steps, set save_freq in the BackupAndRestore callback to 30. InterruptAtStep will raise an artificial exception to simulate a failure at epoch 1 and step 40 (total step count 140), where epochs are counted from zero. The last checkpoint would have been saved at epoch 1 and step 20 (total step count 120).
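The step numbers quoted above follow from integer arithmetic. The sketch below assumes, as in this example, that a backup is written whenever the total step count is a multiple of save_freq; the locate helper is hypothetical and not a Keras function.

```python
def locate(total_step, steps_per_epoch):
  """Splits a total step count into (zero-based epoch, step within epoch)."""
  return divmod(total_step, steps_per_epoch)


steps_per_epoch = 100
save_freq = 30
interrupting_step = 140

# Most recent multiple of save_freq at or before the interruption.
last_backup_step = (interrupting_step // save_freq) * save_freq

print(locate(interrupting_step, steps_per_epoch))  # → (1, 40)
print(last_backup_step)                            # → 120
print(locate(last_backup_step, steps_per_epoch))   # → (1, 20)
```

The 20 steps between the last backup and the interruption are the work that has to be repeated after a restart.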
log_dir_2 = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir_2, save_freq=30
)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10 100/100 [==============================] - 2s 11ms/step - loss: 0.4662 - accuracy: 0.8687 - val_loss: 0.2200 - val_accuracy: 0.9363 Epoch 2/10 25/100 [======>.......................] - ETA: 0s - loss: 0.2307 - accuracy: 0.9355 Interrupting at step count 140 RuntimeError:Interruption
Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint. Notice that the training starts from epoch 2 and step 21.
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 2/10 100/100 [==============================] - 2s 18ms/step - loss: 0.1923 - accuracy: 0.9450 - val_loss: 0.1600 - val_accuracy: 0.9538 Epoch 3/10 100/100 [==============================] - 0s 5ms/step - loss: 0.1459 - accuracy: 0.9579 - val_loss: 0.1264 - val_accuracy: 0.9624 Epoch 4/10 100/100 [==============================] - 0s 5ms/step - loss: 0.1144 - accuracy: 0.9678 - val_loss: 0.1028 - val_accuracy: 0.9697 Epoch 5/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0943 - accuracy: 0.9735 - val_loss: 0.0920 - val_accuracy: 0.9721 Epoch 6/10 100/100 [==============================] - 1s 5ms/step - loss: 0.0815 - accuracy: 0.9765 - val_loss: 0.0844 - val_accuracy: 0.9757 Epoch 7/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0691 - accuracy: 0.9802 - val_loss: 0.0770 - val_accuracy: 0.9764 Epoch 8/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0607 - accuracy: 0.9826 - val_loss: 0.0713 - val_accuracy: 0.9772 Epoch 9/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0516 - accuracy: 0.9856 - val_loss: 0.0661 - val_accuracy: 0.9814 Epoch 10/10 100/100 [==============================] - 0s 5ms/step - loss: 0.0472 - accuracy: 0.9863 - val_loss: 0.0649 - val_accuracy: 0.9791 <keras.callbacks.History at 0x7f905068bd30>
TensorFlow 2: Write manual checkpoints with a custom training loop
If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
This example demonstrates how to:
- Use a tf.train.Checkpoint object to manually create a checkpoint, where the trackable objects you want to save are set as attributes.
- Use a tf.train.CheckpointManager to manage multiple checkpoints.
Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager for logging and keeping several checkpoints in a temporary directory.
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, log_dir, max_to_keep=2)
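With max_to_keep=2, the manager keeps only the two most recent checkpoints and deletes older ones as new ones are saved. That retention rule can be modeled in plain Python with a bounded deque. This illustrates the policy only, not how tf.train.CheckpointManager is implemented, and save_checkpoint is a hypothetical stand-in.

```python
from collections import deque

max_to_keep = 2
kept = deque(maxlen=max_to_keep)  # The oldest entry is dropped automatically.


def save_checkpoint(save_counter):
  """Records a checkpoint name following the ckpt-<n> naming pattern."""
  name = f"ckpt-{save_counter}"
  kept.append(name)
  return name


# 5 epochs x 5 steps with one save per step gives 25 saves in total.
for save_counter in range(1, 26):
  save_checkpoint(save_counter)

print(list(kept))  # → ['ckpt-24', 'ckpt-25']
```

Keeping more than one checkpoint is a common safeguard: if the newest one is incomplete because the failure happened during a save, the previous one is still available.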
Now, implement a custom training loop in which, from the second epoch onward, the last saved checkpoint is loaded at the start of each epoch:
for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer state from the last saved checkpoint.
    # (tf.train.load_checkpoint only returns a reader and restores nothing.)
    checkpoint.restore(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Save a checkpoint after every training step.
    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-1 Training loss at step 0: 2.3765370845794678 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-2 Training loss at step 1: 2.3751797676086426 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-3 Training loss at step 2: 2.372164249420166 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-4 Training loss at step 3: 2.3722121715545654 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-5 Training loss at step 4: 2.37064790725708 Start of epoch 1 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-6 Training loss at step 0: 2.369161367416382 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-7 Training loss at step 1: 2.367633819580078 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-8 Training loss at step 2: 2.3669111728668213 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-9 Training loss at step 3: 2.3661725521087646 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-10 Training loss at step 4: 2.364102840423584 Start of epoch 2 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-11 Training loss at step 0: 2.363913059234619 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-12 Training loss at step 1: 2.361424684524536 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-13 Training loss at step 2: 2.358936309814453 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-14 Training loss at step 3: 2.359144687652588 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-15 Training loss at step 4: 2.3580069541931152 Start of epoch 3 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-16 Training loss at step 0: 2.3552677631378174 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-17 Training loss at step 1: 2.354823350906372 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-18 Training loss at step 2: 2.3539443016052246 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-19 Training loss at step 3: 2.351090908050537 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-20 Training loss at step 4: 2.3506886959075928 Start of epoch 4 Checkpoint saved to 
/tmpfs/tmp/tmpnf0rsxnw/ckpt-21 Training loss at step 0: 2.3491249084472656 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-22 Training loss at step 1: 2.3474254608154297 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-23 Training loss at step 2: 2.346055269241333 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-24 Training loss at step 3: 2.343763828277588 Checkpoint saved to /tmpfs/tmp/tmpnf0rsxnw/ckpt-25 Training loss at step 4: 2.3445236682891846
Next steps
To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:
- The tf.keras.callbacks.BackupAndRestore callback API docs.
- The tf.train.Checkpoint and tf.train.CheckpointManager API docs.
- The Training checkpoints guide, including the Writing checkpoints section.
You may also find the following material related to distributed training useful:
- The Fault tolerance section in the Multi-worker training with Keras tutorial.
- The Handling task failure section in the Parameter server training tutorial.