Save and load models

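Model progress can be saved during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:

  • The code used to create the model
  • The trained weights, or parameters, of the model

Sharing this data helps others understand how the model works and try it themselves with new data.

Caution: TensorFlow models are code, and it is important to be careful with untrusted code. See Using TensorFlow Securely for details.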

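Options

There are different ways to save TensorFlow models depending on the API you are using. This guide uses tf.keras, a high-level API to build and train models in TensorFlow. For other approaches, refer to the Using the SavedModel format guide and the Save and load Keras models guide.

Setup

Installs and imports

Install and import TensorFlow and dependencies: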

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.10.0-rc3

Get an example dataset

To demonstrate how to save and load weights, you will use the MNIST dataset. To speed up these runs, use only the first 1000 examples:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
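
To make the preprocessing above concrete, here is a minimal sketch on synthetic data. The random `fake_images` array is an assumption that stands in for MNIST so nothing has to be downloaded. It shows how `reshape(-1, 28 * 28)` flattens each 28x28 image into a 784-element vector and how dividing by 255.0 scales the pixel values into the range [0, 1].

```python
import numpy as np

# Hypothetical stand-in for MNIST: 1000 grayscale images of 28x28 uint8 pixels.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)

# Flatten each image to a vector and scale pixels from [0, 255] to [0, 1].
flat_images = fake_images.reshape(-1, 28 * 28) / 255.0

print(flat_images.shape)   # (1000, 784)
print(flat_images.dtype)   # float64
print(flat_images.min() >= 0.0, flat_images.max() <= 1.0)
```

The flattened width of 784 is what the first Dense layer of the model expects as its input shape.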

Define a model

Start by building a simple sequential model:

# Define a simple sequential model
def create_model():
  model = tf.keras.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
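
The parameter counts in the summary can be checked by hand. The sketch below uses a small helper, `dense_param_count`, which is a hypothetical name introduced only for this illustration: a fully connected layer has one weight per input/unit pair plus one bias per unit, and Dropout adds no parameters.

```python
def dense_param_count(n_inputs: int, n_units: int) -> int:
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

hidden = dense_param_count(784, 512)   # first Dense layer
output = dense_param_count(512, 10)    # final Dense layer
dropout = 0                            # Dropout has no trainable parameters

print(hidden, output, hidden + dropout + output)  # 401920 5130 407050
```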

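Save checkpoints during training

You can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted. The tf.keras.callbacks.ModelCheckpoint callback lets you continually save the model both during and at the end of training.

Checkpoint callback usage

Create a tf.keras.callbacks.ModelCheckpoint callback that saves only the weights during training: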

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
25/32 [======================>.......] - ETA: 0s - loss: 1.2915 - sparse_categorical_accuracy: 0.6325 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 1s 15ms/step - loss: 1.1539 - sparse_categorical_accuracy: 0.6700 - val_loss: 0.7146 - val_sparse_categorical_accuracy: 0.7820
Epoch 2/10
26/32 [=======================>......] - ETA: 0s - loss: 0.4076 - sparse_categorical_accuracy: 0.8906
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.4042 - sparse_categorical_accuracy: 0.8890 - val_loss: 0.5518 - val_sparse_categorical_accuracy: 0.8270
Epoch 3/10
27/32 [========================>.....] - ETA: 0s - loss: 0.2788 - sparse_categorical_accuracy: 0.9294
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.2816 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.5021 - val_sparse_categorical_accuracy: 0.8470
Epoch 4/10
25/32 [======================>.......] - ETA: 0s - loss: 0.2144 - sparse_categorical_accuracy: 0.9538
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2066 - sparse_categorical_accuracy: 0.9570 - val_loss: 0.4283 - val_sparse_categorical_accuracy: 0.8670
Epoch 5/10
25/32 [======================>.......] - ETA: 0s - loss: 0.1561 - sparse_categorical_accuracy: 0.9650
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1588 - sparse_categorical_accuracy: 0.9680 - val_loss: 0.4300 - val_sparse_categorical_accuracy: 0.8650
Epoch 6/10
27/32 [========================>.....] - ETA: 0s - loss: 0.1186 - sparse_categorical_accuracy: 0.9815
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1193 - sparse_categorical_accuracy: 0.9800 - val_loss: 0.4233 - val_sparse_categorical_accuracy: 0.8690
Epoch 7/10
25/32 [======================>.......] - ETA: 0s - loss: 0.0804 - sparse_categorical_accuracy: 0.9887
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0819 - sparse_categorical_accuracy: 0.9880 - val_loss: 0.4121 - val_sparse_categorical_accuracy: 0.8610
Epoch 8/10
26/32 [=======================>......] - ETA: 0s - loss: 0.0626 - sparse_categorical_accuracy: 0.9940
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0608 - sparse_categorical_accuracy: 0.9940 - val_loss: 0.4123 - val_sparse_categorical_accuracy: 0.8670
Epoch 9/10
26/32 [=======================>......] - ETA: 0s - loss: 0.0549 - sparse_categorical_accuracy: 0.9964
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0539 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4397 - val_sparse_categorical_accuracy: 0.8580
Epoch 10/10
26/32 [=======================>......] - ETA: 0s - loss: 0.0480 - sparse_categorical_accuracy: 0.9976
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4327 - val_sparse_categorical_accuracy: 0.8640
<keras.callbacks.History at 0x7f56ee7a3b50>

This creates a single collection of TensorFlow checkpoint files that are updated at the end of each epoch:

os.listdir(checkpoint_dir)
['cp.ckpt.data-00000-of-00001', 'cp.ckpt.index', 'checkpoint']

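As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.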
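
As a minimal sketch of weight sharing between two models with identical architecture, the example below copies weights in memory with `get_weights` and `set_weights` instead of going through a checkpoint file. The tiny `build_small_model` architecture is an assumption made for brevity and is not the model used in this guide.

```python
import numpy as np
import tensorflow as tf

def build_small_model():
    # Hypothetical tiny architecture, used only for this illustration.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(3),
    ])

source = build_small_model()
target = build_small_model()   # same architecture, different random weights

# Copy every weight array from one model to the other.
target.set_weights(source.get_weights())

# With identical weights, both models produce the same outputs.
x = np.ones((2, 4), dtype='float32')
same = np.allclose(source(x).numpy(), target(x).numpy())
print(same)
```

If the architectures differed, for example in the number of units, `set_weights` would raise an error because the shapes of the weight arrays would not line up.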

Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (about 10% accuracy):

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.3038 - sparse_categorical_accuracy: 0.0760 - 158ms/epoch - 5ms/step
Untrained model, accuracy:  7.60%
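
The "chance level" figures can be derived directly. The short calculation below assumes a model that assigns equal probability to each of the 10 digit classes. Its expected accuracy is 1/10, and its cross-entropy loss is ln(10), which is close to the loss of about 2.30 reported for the untrained model above.

```python
import math

num_classes = 10

# A model that guesses uniformly is right 1 time in num_classes.
chance_accuracy = 1.0 / num_classes

# Cross-entropy of a uniform prediction: -log(1 / num_classes) = log(num_classes).
chance_loss = math.log(num_classes)

print(f"{chance_accuracy:.2%}")   # 10.00%
print(f"{chance_loss:.4f}")       # 2.3026
```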

Then load the weights from the checkpoint and re-evaluate:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4327 - sparse_categorical_accuracy: 0.8640 - 72ms/epoch - 2ms/step
Restored model, accuracy: 86.40%

Checkpoint callback options

The callback provides several options to give checkpoints unique names and to adjust the checkpointing frequency.
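
Both options can be previewed in plain Python before any training happens. The sketch below formats a file-name template such as `cp-{epoch:04d}.ckpt` with `str.format`, and works out how many batches make up five epochs, since an integer `save_freq` is counted in batches. The sample count and batch size are the values used in this guide; the variable names are chosen for illustration.

```python
import math

checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"

# `{epoch:04d}` zero-pads the epoch number to four digits.
print(checkpoint_path.format(epoch=0))    # training_2/cp-0000.ckpt
print(checkpoint_path.format(epoch=15))   # training_2/cp-0015.ckpt

# An integer `save_freq` is counted in batches, so "every 5 epochs"
# means 5 * (batches per epoch).
num_samples = 1000
batch_size = 32
batches_per_epoch = math.ceil(num_samples / batch_size)
print(batches_per_epoch, 5 * batches_per_epoch)   # 32 160
```

With 1000 samples and a batch size of 32, an epoch has 32 batches, so 160 batches correspond to five epochs.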

Train a new model, and save uniquely named checkpoints once every five epochs:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

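batch_size = 32

# `save_freq` counts batches, not samples, so compute the batches per epoch
# (ceiling division; 1000 samples / 32 per batch -> 32 batches).
n_batches = -(-len(train_images) // batch_size)

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5*n_batches)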

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.callbacks.History at 0x7f56e00d5ee0>

Now, review the resulting checkpoints and choose the latest one:

os.listdir(checkpoint_dir)
['cp-0020.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.data-00000-of-00001',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0040.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.index',
 'cp-0020.ckpt.index',
 'cp-0045.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0030.ckpt.index',
 'cp-0000.ckpt.index',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0040.ckpt.index',
 'cp-0050.ckpt.data-00000-of-00001',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.data-00000-of-00001',
 'checkpoint',
 'cp-0010.ckpt.index',
 'cp-0015.ckpt.index']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
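
For intuition only, the sketch below picks the newest checkpoint prefix from a directory listing by parsing the epoch number in each file name. This is not how `tf.train.latest_checkpoint` works internally (it consults the `checkpoint` state file in the directory); `newest_checkpoint_prefix` is a hypothetical helper that assumes the `cp-NNNN.ckpt` naming scheme used above.

```python
import re

def newest_checkpoint_prefix(filenames):
    """Return the checkpoint prefix with the highest epoch number, or None."""
    pattern = re.compile(r"^(cp-(\d{4})\.ckpt)\.index$")
    best = None
    for name in filenames:
        match = pattern.match(name)
        if match:
            prefix, epoch = match.group(1), int(match.group(2))
            if best is None or epoch > best[1]:
                best = (prefix, epoch)
    return best[0] if best else None

listing = ['cp-0005.ckpt.index', 'cp-0050.ckpt.index', 'checkpoint',
           'cp-0050.ckpt.data-00000-of-00001', 'cp-0010.ckpt.index']
print(newest_checkpoint_prefix(listing))   # cp-0050.ckpt
```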

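Note: The default TensorFlow format is described as keeping only the 5 most recent checkpoints. The listing above, however, still contains all 11 checkpoints from this run (cp-0000 through cp-0050), so check the directory rather than assuming that older files were pruned.

To test, reset the model and load the latest checkpoint: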

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4798 - sparse_categorical_accuracy: 0.8800 - 156ms/epoch - 5ms/step
Restored model, accuracy: 88.00%

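What are these files?

The above code stores the weights in a collection of checkpoint-formatted files that contain only the trained weights in a binary format. Checkpoints contain:

  • One or more shards that contain your model's weights.
  • An index file that indicates which weights are stored in which shard.

If you are training a model on a single machine, you will get one shard with the suffix: .data-00000-of-00001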
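
The file layout described above can be illustrated by grouping a directory listing into checkpoints. The sketch below is a hypothetical helper, `group_checkpoint_files`, that assumes the `.index` and `.data-XXXXX-of-YYYYY` suffixes shown in this guide; it only parses file names and does not read any checkpoint contents.

```python
import re
from collections import defaultdict

def group_checkpoint_files(filenames):
    """Map each checkpoint prefix to its index flag and (shard, total) pairs."""
    shard_re = re.compile(r"^(.*)\.data-(\d{5})-of-(\d{5})$")
    groups = defaultdict(lambda: {"has_index": False, "shards": []})
    for name in filenames:
        shard = shard_re.match(name)
        if shard:
            prefix = shard.group(1)
            groups[prefix]["shards"].append((int(shard.group(2)), int(shard.group(3))))
        elif name.endswith(".index"):
            groups[name[:-len(".index")]]["has_index"] = True
    return dict(groups)

files = ['cp.ckpt.data-00000-of-00001', 'cp.ckpt.index', 'checkpoint']
print(group_checkpoint_files(files))
# {'cp.ckpt': {'has_index': True, 'shards': [(0, 1)]}}
```

The prefix `cp.ckpt` (without any suffix) is what gets passed to `load_weights`, which is why the code in this guide never names the individual shard or index files.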

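Manually save weights

To save weights manually, use tf.keras.Model.save_weights. By default, tf.keras (and the Model.save_weights method in particular) uses the TensorFlow checkpoint format with a .ckpt extension. To save in the HDF5 format with a .h5 extension, refer to the Save and load models guide.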

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore for details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
32/32 - 0s - loss: 0.4798 - sparse_categorical_accuracy: 0.8800 - 155ms/epoch - 5ms/step
Restored model, accuracy: 88.00%
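
As a complement to the checkpoint-format example above, here is a minimal sketch of a weights-only round trip in HDF5. The tiny model and the file name `demo.weights.h5` are assumptions made for illustration; the `.weights.h5` ending is used because newer Keras releases expect it, while the `.h5` suffix is what selects HDF5. Saving to HDF5 requires the h5py package installed at the start of this guide.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_small_model():
    # Hypothetical tiny architecture, used only for this illustration.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(3),
    ])

model_a = build_small_model()
model_b = build_small_model()   # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    # Assumed file name; the `.h5` suffix selects the HDF5 format.
    weights_path = os.path.join(tmp_dir, "demo.weights.h5")
    model_a.save_weights(weights_path)
    model_b.load_weights(weights_path)

# After loading, both models should produce the same outputs.
x = np.ones((1, 4), dtype="float32")
weights_match = np.allclose(model_a(x).numpy(), model_b(x).numpy())
print(weights_match)
```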

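Save the entire model

Call tf.keras.Model.save to save a model's architecture, weights, and training configuration in a single file or folder. This allows you to export a model so it can be used without access to the original Python code*. Because the optimizer state is saved and restored as well, you can resume training from exactly where you left off.

An entire model can be saved in two different file formats (SavedModel and HDF5). The TensorFlow SavedModel format is the default file format in TF2.x. However, models can also be saved in HDF5 format. Saving an entire model in each of the two formats is described in more detail below.

Saving a fully functional model is very useful: you can load it in TensorFlow.js (SavedModel, HDF5) and then train and run it in a web browser, or convert it to run on mobile devices using TensorFlow Lite (SavedModel, HDF5).

*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the Saving custom objects section below.

SavedModel format

The SavedModel format is another way to serialize models. Models saved in this format can be restored using tf.keras.models.load_model and are compatible with TensorFlow Serving. The SavedModel guide goes into detail about how to serve and inspect a SavedModel. The section below illustrates the steps to save and restore the model.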

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1801 - sparse_categorical_accuracy: 0.6520
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4198 - sparse_categorical_accuracy: 0.8870
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2929 - sparse_categorical_accuracy: 0.9250
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2222 - sparse_categorical_accuracy: 0.9470
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1728 - sparse_categorical_accuracy: 0.9590
INFO:tensorflow:Assets written to: saved_model/my_model/assets

The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  keras_metadata.pb  saved_model.pb  variables
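
The same inspection can be done from Python. The sketch below defines `looks_like_saved_model`, a hypothetical helper (not a TensorFlow API) that checks a directory for the `saved_model.pb` file and the `variables` folder listed above. It is exercised against an empty mock layout so that it runs without training a model.

```python
import os
import tempfile

def looks_like_saved_model(path):
    """Hypothetical helper: check for the protobuf file and variables folder."""
    has_graph = os.path.isfile(os.path.join(path, "saved_model.pb"))
    has_variables = os.path.isdir(os.path.join(path, "variables"))
    return has_graph and has_variables

# Build an empty mock layout so the check can run without training a model.
with tempfile.TemporaryDirectory() as tmp_dir:
    mock_dir = os.path.join(tmp_dir, "my_model")
    os.makedirs(os.path.join(mock_dir, "variables"))
    open(os.path.join(mock_dir, "saved_model.pb"), "wb").close()
    mock_ok = looks_like_saved_model(mock_dir)
    empty_ok = looks_like_saved_model(tmp_dir)

print(mock_ok, empty_ok)   # True False
```

A check like this only confirms the layout; it says nothing about whether the contents are valid or trustworthy.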

Reload a fresh Keras model from the saved model:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

The restored model is compiled with the same arguments as the original model. Try running evaluate and predict with the loaded model:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4184 - sparse_categorical_accuracy: 0.8590 - 164ms/epoch - 5ms/step
Restored model, accuracy: 85.90%
32/32 [==============================] - 0s 1ms/step
(1000, 10)
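
The prediction shape (1000, 10) reflects one row of 10 logits per test image, because the final Dense layer has no activation. The sketch below uses made-up logits and labels (assumptions for illustration, not model output) to show how such logits turn into class predictions and into the sparse categorical accuracy reported by evaluate.

```python
import numpy as np

# Hypothetical logits for 4 samples and 3 classes, plus integer labels.
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.5, 1.0, 0.0],
                   [-0.5, 2.5, 0.3]])
labels = np.array([0, 2, 1, 1])

# The predicted class is the index of the largest logit in each row.
predicted = logits.argmax(axis=1)

# Sparse categorical accuracy: fraction of rows where prediction equals label.
accuracy = float((predicted == labels).mean())

print(predicted.tolist())   # [0, 2, 0, 1]
print(accuracy)             # 0.75
```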

HDF5 format

Keras provides a basic save format using the HDF5 standard.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1226 - sparse_categorical_accuracy: 0.6750
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4029 - sparse_categorical_accuracy: 0.8850
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2816 - sparse_categorical_accuracy: 0.9310
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2025 - sparse_categorical_accuracy: 0.9560
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1622 - sparse_categorical_accuracy: 0.9640

Now, recreate the model from that file:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Check its accuracy:

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4360 - sparse_categorical_accuracy: 0.8600 - 157ms/epoch - 5ms/step
Restored model, accuracy: 86.00%
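
Accuracy is an indirect check. A more direct way to confirm that a reloaded model matches the original is to compare their outputs on the same input. The sketch below does this for a small HDF5 round trip; the tiny model and the file name `demo_model.h5` are assumptions for illustration, and the model is left uncompiled, so only the architecture and weights are exercised here.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical tiny model, left uncompiled for this illustration.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(3),
])

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "demo_model.h5")
    model.save(model_path)                             # '.h5' selects HDF5
    restored = tf.keras.models.load_model(model_path)  # no create_model() call needed

# The restored model should reproduce the original outputs.
x = np.ones((2, 4), dtype="float32")
outputs_match = np.allclose(model(x).numpy(), restored(x).numpy())
print(outputs_match)
```

Depending on the installed version, saving or loading an uncompiled model in this format may print a warning; that does not affect the comparison.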

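Keras saves models by inspecting their architectures. This technique saves everything:

  • The weight values
  • The model's architecture
  • The model's training configuration (what you pass to the .compile() method)
  • The optimizer and its state, if any (this enables you to restart training where you left off)

Keras is not able to save the v1.x optimizers (from tf.compat.v1.train) because they are not compatible with checkpoints. For v1.x optimizers, you need to re-compile the model after loading, which loses the state of the optimizer.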

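Saving custom objects

If you are using the SavedModel format, you can skip this section. The key difference between HDF5 and SavedModel is that HDF5 uses object configs to save the model architecture, while SavedModel saves the execution graph. Thus, SavedModels are able to save custom objects such as subclassed models and custom layers without requiring the original code.

To save custom objects to HDF5, you must do the following:

  1. Define a get_config method in your object, and optionally a from_config classmethod.
    • get_config(self) returns a JSON-serializable dictionary of the parameters needed to recreate the object.
    • from_config(cls, config) uses the config returned by get_config to create a new object. By default, this function uses the config as initialization kwargs (return cls(**config)).
  2. Pass the object to the custom_objects argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class. For example, tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})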
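
The contract described in the steps above can be sketched without Keras. The `ScaledOffset` class below is a plain-Python stand-in for a custom layer (an assumption for illustration, not a Keras class), and the JSON record format is likewise made up. It only shows how get_config, from_config, and a name-to-class dictionary fit together.

```python
import json

class ScaledOffset:
    """Plain-Python stand-in for a custom layer (not a Keras class)."""

    def __init__(self, scale=1.0, offset=0.0):
        self.scale = scale
        self.offset = offset

    def get_config(self):
        # Only JSON-serializable constructor arguments go in the config.
        return {"scale": self.scale, "offset": self.offset}

    @classmethod
    def from_config(cls, config):
        # Default behavior described in the text: pass the config as kwargs.
        return cls(**config)

# Serialize: store the class name next to its config, as a JSON string.
original = ScaledOffset(scale=2.0, offset=0.5)
saved = json.dumps({"class_name": "ScaledOffset", "config": original.get_config()})

# Deserialize: the loader looks the class up by its string name.
custom_objects = {"ScaledOffset": ScaledOffset}
record = json.loads(saved)
rebuilt = custom_objects[record["class_name"]].from_config(record["config"])

print(rebuilt.scale, rebuilt.offset)   # 2.0 0.5
```

The lookup by string name is the reason the loader needs the custom_objects dictionary: a config file can store the name of a class, but not the class itself.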

See the Writing layers and models from scratch tutorial for examples of custom objects and get_config.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.