Continually saving the "best" model or model weights/parameters has many benefits. These include being able to track the training progress and load saved models from different saved states.
In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to TensorFlow 2 Keras APIs.
In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:
- Save the "best" version according to a monitored metric by setting save_best_only=True, where the monitor parameter names the metric, for example 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
- Save continually at a certain frequency using the save_freq argument: 'epoch' (the default) saves at the end of every epoch, while an integer saves after that many batches.
- Save only the weights/parameters instead of the whole model by setting save_weights_only to True.
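To make these options concrete, here is a minimal sketch that keeps only the best weights according to validation loss. It assumes TensorFlow 2.x with tf.keras and uses hypothetical random data and a hypothetical file name, best.weights.h5 (recent Keras versions expect the .weights.h5 suffix when save_weights_only=True).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical random data shaped like MNIST: 28x28 inputs, 10 classes.
rng = np.random.default_rng(0)
x = rng.random((256, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(256,))
x_val, y_val = x[:64], y[:64]

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()
best_path = os.path.join(ckpt_dir, "best.weights.h5")  # hypothetical file name

best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_loss",      # metric computed on validation_data
    mode="min",              # a lower val_loss counts as an improvement
    save_best_only=True,     # overwrite the file only when val_loss improves
    save_weights_only=True,  # write weights only, not the whole model
    save_freq="epoch",       # check the metric at the end of every epoch
)

history = model.fit(x, y,
                    epochs=3,
                    batch_size=32,
                    validation_data=(x_val, y_val),
                    callbacks=[best_checkpoint],
                    verbose=0)
print("best val_loss:", best_checkpoint.best)
```

With mode="min", the tracked best value starts at infinity by default, so the first epoch always writes the file; later epochs overwrite it only when val_loss improves.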
For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the Checkpoint format in the TF Checkpoint format section in the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
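For manual checkpointing with tf.train.Checkpoint, the save/restore cycle can be sketched as follows. This is a minimal illustration, not a full training loop: the step and weight variables are hypothetical stand-ins for the model and optimizer state you would normally track, and the "crash" is simulated by overwriting them.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical training state: a step counter and a single weight.
step = tf.Variable(0, dtype=tf.int64)
weight = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(step=step, weight=weight)

ckpt_dir = tempfile.mkdtemp()
for _ in range(3):
    weight.assign_add(0.5)  # stand-in for an optimizer update
    step.assign_add(1)

# save() appends a counter to the prefix and returns the full checkpoint path.
save_path = checkpoint.save(os.path.join(ckpt_dir, "ckpt"))

# Simulate a failure that leaves the in-memory state unusable.
step.assign(0)
weight.assign(-100.0)

# Resume: restore() writes the saved values back into the tracked variables.
checkpoint.restore(save_path)
print(int(step.numpy()), float(weight.numpy()))
```

In a real custom training loop you would typically pass the Keras model and optimizer to tf.train.Checkpoint instead of individual variables.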
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.
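As a minimal illustration of those call points, the hypothetical EpochEndLogger callback below overrides two hooks, on_train_begin and on_epoch_end, and records the loss that Model.fit reports after each epoch. The tiny regression model and random data are assumptions made only to keep the sketch self-contained.

```python
import numpy as np
import tensorflow as tf


class EpochEndLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the loss reported after each epoch."""

    def __init__(self):
        super().__init__()
        self.records = []

    def on_train_begin(self, logs=None):
        # Called once by Model.fit before the first epoch starts.
        self.records.clear()

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-based; `logs` holds the metrics computed for this epoch.
        self.records.append((epoch, float(logs["loss"])))


rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.random((64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

logger = EpochEndLogger()
model.fit(x, y, epochs=2, batch_size=16, callbacks=[logger], verbose=0)
print(logger.records)
```

tf.keras.callbacks.ModelCheckpoint is itself a Callback subclass that does its saving from hooks of this kind (at the end of an epoch or a batch).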
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Save checkpoints with tf.estimator APIs
This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_test},
y=y_test.astype(np.int32),
num_epochs=10,
shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
steps=10,
throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
train_spec=train_spec,
eval_spec=eval_spec)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpg36p1m5q', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/tmp/ipykernel_109103/3980459272.py:18: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_109103/3980459272.py:18: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1 or save_checkpoints_secs None. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. 2022-12-14 03:31:52.550049: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... 
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:53 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-1 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.30245s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:53 INFO:tensorflow:Saving dict for global step 1: accuracy = 0.12734374, average_loss = 2.3120413, global_step = 1, loss = 295.94128 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-1 INFO:tensorflow:loss = 123.480286, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:54 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-2 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.29533s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:54 INFO:tensorflow:Saving dict for global step 2: accuracy = 0.1859375, average_loss = 2.243344, global_step = 2, loss = 287.14804 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-2 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:54 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28706s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:55 INFO:tensorflow:Saving dict for global step 3: accuracy = 0.24375, average_loss = 2.1928506, global_step = 3, loss = 280.68488 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-3 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... 
INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:55 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-4 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28551s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:55 INFO:tensorflow:Saving dict for global step 4: accuracy = 0.296875, average_loss = 2.1503491, global_step = 4, loss = 275.2447 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-4 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1064: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:56 INFO:tensorflow:Graph was finalized. 
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-5 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.29568s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:56 INFO:tensorflow:Saving dict for global step 5: accuracy = 0.33984375, average_loss = 2.110334, global_step = 5, loss = 270.12274 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-5 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:56 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-6 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.29604s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:56 INFO:tensorflow:Saving dict for global step 6: accuracy = 0.38359374, average_loss = 2.070732, global_step = 6, loss = 265.0537 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:57 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-7 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28699s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:57 INFO:tensorflow:Saving dict for global step 7: accuracy = 0.4171875, average_loss = 2.0279684, global_step = 7, loss = 259.57996 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-7 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... 
INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:57 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-8 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28880s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:58 INFO:tensorflow:Saving dict for global step 8: accuracy = 0.45, average_loss = 1.9848177, global_step = 8, loss = 254.05667 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-8 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:58 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-9 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28838s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:58 INFO:tensorflow:Saving dict for global step 9: accuracy = 0.490625, average_loss = 1.9399147, global_step = 9, loss = 248.30908 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-9 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmpg36p1m5q/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T03:31:58 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpg36p1m5q/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.28851s INFO:tensorflow:Finished evaluation at 2022-12-14-03:31:59 INFO:tensorflow:Saving dict for global step 10: accuracy = 0.528125, average_loss = 1.8967623, global_step = 10, loss = 242.78557 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmpfs/tmp/tmpg36p1m5q/model.ckpt-10 INFO:tensorflow:Loss for final step: 93.67593. 
({'accuracy': 0.528125, 'average_loss': 1.8967623, 'loss': 242.78557, 'global_step': 10}, [])
%ls {classifier.model_dir}
checkpoint eval/ events.out.tfevents.1670988712.kokoro-gcp-ubuntu-prod-1844974797 graph.pbtxt model.ckpt-10.data-00000-of-00001 model.ckpt-10.index model.ckpt-10.meta model.ckpt-6.data-00000-of-00001 model.ckpt-6.index model.ckpt-6.meta model.ckpt-7.data-00000-of-00001 model.ckpt-7.index model.ckpt-7.meta model.ckpt-8.data-00000-of-00001 model.ckpt-8.index model.ckpt-8.meta model.ckpt-9.data-00000-of-00001 model.ckpt-9.index model.ckpt-9.meta
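The listing contains only model.ckpt-6 through model.ckpt-10, even though a checkpoint was saved at every step. This is because RunConfig's keep_checkpoint_max was 5 (see '_keep_checkpoint_max': 5 in the log above), so older checkpoints were deleted. For manual checkpointing in TensorFlow 2, a roughly analogous retention policy is the max_to_keep argument of tf.train.CheckpointManager. A minimal sketch, assuming a hypothetical global_step variable as the only tracked state:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical state to checkpoint: just a global step counter.
global_step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(global_step=global_step)

ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(checkpoint,
                                     directory=ckpt_dir,
                                     max_to_keep=5)  # like keep_checkpoint_max=5

for _ in range(10):
    global_step.assign_add(1)
    # Number each checkpoint by the step, as the Estimator did (model.ckpt-<step>).
    manager.save(checkpoint_number=int(global_step.numpy()))

# Oldest to newest; checkpoints beyond max_to_keep have been deleted from disk.
print([os.path.basename(p) for p in manager.checkpoints])
```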
TensorFlow 2: Save checkpoints with a Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and the Using callbacks section in the Training and evaluation with the built-in methods guide.)
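The TensorFlow 1 example saved a checkpoint at every training step (save_checkpoints_steps=1). In ModelCheckpoint, an integer save_freq counts batches instead of epochs, so a roughly analogous setup is save_freq=1. The sketch below assumes hypothetical random data (64 samples with a batch size of 32, so two batches per epoch) and a hypothetical file-name pattern; the {epoch} and {batch} placeholders give each save its own file instead of overwriting a single path.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical data: 64 samples -> 2 batches of 32 per epoch.
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(64,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

step_dir = tempfile.mkdtemp()
every_step = tf.keras.callbacks.ModelCheckpoint(
    # The callback fills in 1-based {epoch} and {batch} values.
    filepath=os.path.join(step_dir, "ckpt-e{epoch:02d}-b{batch:02d}.weights.h5"),
    save_weights_only=True,
    save_freq=1,  # an integer means "every N batches" rather than "every epoch"
)
model.fit(x, y, epochs=1, batch_size=32, callbacks=[every_step], verbose=0)

saved = sorted(f for f in os.listdir(step_dir) if f.endswith(".weights.h5"))
print(saved)
```

Saving after every batch is usually far more frequent than needed outside of demonstrations like this one; a larger integer or the default 'epoch' keeps I/O overhead down.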
In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory:
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=log_dir)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
Epoch 1/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.2209 - accuracy: 0.9345INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 5s 2ms/step - loss: 0.2200 - accuracy: 0.9347 - val_loss: 0.1060 - val_accuracy: 0.9677 Epoch 2/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0973 - accuracy: 0.9697INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0971 - accuracy: 0.9698 - val_loss: 0.0783 - val_accuracy: 0.9746 Epoch 3/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0703 - accuracy: 0.9778INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0702 - accuracy: 0.9778 - val_loss: 0.0754 - val_accuracy: 0.9766 Epoch 4/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0534 - accuracy: 0.9829INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0536 - accuracy: 0.9828 - val_loss: 0.0708 - val_accuracy: 0.9788 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0425 - accuracy: 0.9858INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0427 - accuracy: 0.9857 - val_loss: 0.0630 - val_accuracy: 0.9809 Epoch 6/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0351 - accuracy: 0.9888INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0353 - accuracy: 0.9888 - val_loss: 0.0678 - val_accuracy: 0.9800 Epoch 7/10 1860/1875 [============================>.] 
- ETA: 0s - loss: 0.0303 - accuracy: 0.9892INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0303 - accuracy: 0.9892 - val_loss: 0.0654 - val_accuracy: 0.9827 Epoch 8/10 1850/1875 [============================>.] - ETA: 0s - loss: 0.0272 - accuracy: 0.9911INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0275 - accuracy: 0.9911 - val_loss: 0.0711 - val_accuracy: 0.9812 Epoch 9/10 1850/1875 [============================>.] - ETA: 0s - loss: 0.0258 - accuracy: 0.9910INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0258 - accuracy: 0.9909 - val_loss: 0.0676 - val_accuracy: 0.9812 Epoch 10/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0227 - accuracy: 0.9924INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp42o1puju/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0227 - accuracy: 0.9924 - val_loss: 0.0829 - val_accuracy: 0.9813 <keras.callbacks.History at 0x7fc1000ebeb0>
%ls {model_checkpoint_callback.filepath}
assets/ fingerprint.pb keras_metadata.pb saved_model.pb variables/
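The listing shows a SavedModel directory (assets/, saved_model.pb, variables/): with save_weights_only left at its default of False and a filepath without an extension, the Keras version used here wrote the whole model in SavedModel format, and because filepath contains no placeholder, each epoch overwrote the same location. To keep one checkpoint per epoch and later load a specific saved state, you can put an {epoch} placeholder in filepath. A minimal sketch, assuming hypothetical random data and file names, and using the .weights.h5 suffix that recent Keras versions require for weights-only checkpoints:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Same architecture every time, so saved weights can be loaded back in.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model


# Hypothetical random data shaped like MNIST.
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(64,))

per_epoch_dir = tempfile.mkdtemp()
per_epoch = tf.keras.callbacks.ModelCheckpoint(
    # {epoch:02d} is filled with the 1-based epoch number: ckpt-01, ckpt-02, ...
    filepath=os.path.join(per_epoch_dir, "ckpt-{epoch:02d}.weights.h5"),
    save_weights_only=True,
)

trained = build_model()
trained.fit(x, y, epochs=3, batch_size=32, callbacks=[per_epoch], verbose=0)
epoch_files = sorted(os.listdir(per_epoch_dir))
print(epoch_files)

# Load the state saved after the third epoch into a fresh model.
restored = build_model()
restored.load_weights(os.path.join(per_epoch_dir, "ckpt-03.weights.h5"))
```

Any of the per-epoch files can be loaded the same way, which is one way to roll back to an earlier saved state.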
Next steps
Learn more about checkpointing in:
- API docs: tf.keras.callbacks.ModelCheckpoint
- Tutorial: Save and load models (the Save checkpoints during training section)
- Guide: Save and load Keras models (the TF Checkpoint format section)
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Fault tolerance migration guide: tf.keras.callbacks.BackupAndRestore for Model.fit, or tf.train.Checkpoint and tf.train.CheckpointManager APIs for a custom training loop
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide
- The SessionRunHook to Keras callbacks guide