Save and restore models

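Model progress can be saved during and after training. This means a model can resume training where it left off, which avoids long training runs. Saving also means you can share your model so that others can reproduce and build on your work. When publishing research models and techniques, most machine learning practitioners share:

  • the code used to create the model, and
  • the trained weights, or parameters, of the model

Sharing this data helps others understand how the model works and try it themselves with new data.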

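Options

There are several ways to save TensorFlow models, depending on the API you use. This guide uses tf.keras, a high-level API for building and training models in TensorFlow. For other approaches, see the TensorFlow Save and Restore guide or Saving in eager.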

Setup

Installs and imports

Install and import TensorFlow and its dependencies:

!pip install -q h5py pyyaml

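Get an example dataset

We'll train the model on the MNIST dataset to demonstrate saving weights. To speed up these demonstration runs, use only the first 1000 examples: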

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__
'1.12.0-rc1'
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
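
The preprocessing step above flattens each 28x28 image into a 784-element vector and scales pixel values into the range [0, 1]. A minimal sketch of that transformation follows, using a synthetic array in place of the real MNIST images. The `fake_images` array is an assumption made for illustration, and the sketch assumes NumPy is installed:

```python
import numpy as np

# Synthetic stand-in for MNIST: 1000 grayscale images, 28x28, uint8 pixels.
# (Hypothetical data; the tutorial loads the real images from tf.keras.datasets.)
rng = np.random.RandomState(0)
fake_images = rng.randint(0, 256, size=(1000, 28, 28)).astype(np.uint8)

# Same transformation as the tutorial: flatten each image and scale to [0, 1].
flat = fake_images.reshape(-1, 28 * 28) / 255.0

print(flat.shape)   # (1000, 784)
print(flat.dtype)   # float64
print(flat.min() >= 0.0, flat.max() <= 1.0)  # True True
```

The flattened width of 784 is what the `input_shape=(784,)` argument of the first layer in the model below expects.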

Define a model

Let's build a simple model to demonstrate saving and loading weights.

# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy'])

  return model

# Create a basic model instance
model = create_model()
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
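
The parameter counts in the summary above can be checked by hand: a Dense layer has one weight per input-unit pair plus one bias per unit. The following sketch works through that arithmetic; `dense_param_count` is a hypothetical helper written for this illustration, not part of tf.keras:

```python
def dense_param_count(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

hidden = dense_param_count(784, 512)   # first Dense layer
output = dense_param_count(512, 10)    # final Dense layer
dropout = 0                            # Dropout has no trainable parameters

print(hidden)                     # 401920
print(output)                     # 5130
print(hidden + dropout + output)  # 407050
```

These are the values that a weights-only checkpoint of this model has to store.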

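Save checkpoints during training

The primary use case is to save checkpoints automatically during training and at the end of training. This way you can use a trained model without retraining it, or resume training where you left off if the training process is interrupted.

tf.keras.callbacks.ModelCheckpoint is the callback that performs this task. It takes several arguments that configure checkpointing.

Checkpoint callback usage

Train the model and pass it the ModelCheckpoint callback: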

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()

model.fit(train_images, train_labels, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # pass callback to training
Train on 1000 samples, validate on 1000 samples
Epoch 1/10
 960/1000 [===========================>..] - ETA: 0s - loss: 1.2667 - acc: 0.6531
Epoch 00001: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 421us/step - loss: 1.2358 - acc: 0.6620 - val_loss: 0.7326 - val_acc: 0.7760
Epoch 2/10
 928/1000 [==========================>...] - ETA: 0s - loss: 0.4499 - acc: 0.8696
Epoch 00002: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 165us/step - loss: 0.4324 - acc: 0.8760 - val_loss: 0.5149 - val_acc: 0.8430
Epoch 3/10
 832/1000 [=======================>......] - ETA: 0s - loss: 0.2870 - acc: 0.9255
Epoch 00003: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 177us/step - loss: 0.2847 - acc: 0.9240 - val_loss: 0.4527 - val_acc: 0.8570
Epoch 4/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.2065 - acc: 0.9572
Epoch 00004: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 174us/step - loss: 0.2011 - acc: 0.9560 - val_loss: 0.4250 - val_acc: 0.8680
Epoch 5/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1503 - acc: 0.9676
Epoch 00005: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 178us/step - loss: 0.1601 - acc: 0.9620 - val_loss: 0.4080 - val_acc: 0.8660
Epoch 6/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1068 - acc: 0.9861
Epoch 00006: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 175us/step - loss: 0.1076 - acc: 0.9850 - val_loss: 0.4124 - val_acc: 0.8650
Epoch 7/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0856 - acc: 0.9861
Epoch 00007: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 174us/step - loss: 0.0837 - acc: 0.9880 - val_loss: 0.4060 - val_acc: 0.8680
Epoch 8/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0626 - acc: 0.9965
Epoch 00008: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 172us/step - loss: 0.0616 - acc: 0.9970 - val_loss: 0.4055 - val_acc: 0.8720
Epoch 9/10
 832/1000 [=======================>......] - ETA: 0s - loss: 0.0475 - acc: 1.0000
Epoch 00009: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 176us/step - loss: 0.0455 - acc: 1.0000 - val_loss: 0.3973 - val_acc: 0.8730
Epoch 10/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0394 - acc: 0.9965
Epoch 00010: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb8028c2630>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 176us/step - loss: 0.0389 - acc: 0.9960 - val_loss: 0.4114 - val_acc: 0.8720

The code above creates a collection of TensorFlow checkpoint files that are updated at the end of each epoch:

!ls {checkpoint_dir}
checkpoint  cp.ckpt.data-00000-of-00001  cp.ckpt.index

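To restore a model from weights alone, you need a model with the same architecture as the original. Because the architecture is the same, the weights can be shared even though it is a different model instance.

Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (about 10% accuracy for the 10 digit classes):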

model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 118us/step
Untrained model, accuracy: 11.10%

Then load the weights from the checkpoint and re-evaluate:

model.load_weights(checkpoint_path)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 35us/step
Restored model, accuracy: 87.20%
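
Restoring from weights alone works here because the new model has exactly the same weight shapes as the one that was saved. The sketch below illustrates that requirement with plain shape tuples. `dense_stack_shapes` is a hypothetical helper, and the comparison is an illustration only; it does not reproduce how Keras matches weights when loading:

```python
# Weight shapes of the tutorial's model: kernel and bias for each Dense layer.
saved_shapes = [(784, 512), (512,), (512, 10), (10,)]

def dense_stack_shapes(n_inputs, layer_units):
    """Hypothetical helper: weight shapes for a stack of Dense layers."""
    shapes = []
    for units in layer_units:
        shapes.append((n_inputs, units))  # kernel
        shapes.append((units,))           # bias
        n_inputs = units
    return shapes

same = dense_stack_shapes(784, [512, 10])       # same architecture
different = dense_stack_shapes(784, [256, 10])  # narrower hidden layer

print(same == saved_shapes)       # True
print(different == saved_shapes)  # False
```

A model built with a different hidden-layer width would have no place to put the saved (784, 512) kernel, which is why the tutorial rebuilds the model with the same `create_model` function before loading.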

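Checkpoint callback options

The callback provides several options for giving the resulting checkpoints unique names and for adjusting how often checkpoints are written.

Train a new model and save a uniquely named checkpoint once every 5 epochs: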

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
          epochs=50, callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

Epoch 00005: saving model to training_2/cp-0005.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00010: saving model to training_2/cp-0010.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00015: saving model to training_2/cp-0015.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00020: saving model to training_2/cp-0020.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00025: saving model to training_2/cp-0025.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00030: saving model to training_2/cp-0030.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00035: saving model to training_2/cp-0035.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00040: saving model to training_2/cp-0040.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00045: saving model to training_2/cp-0045.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.

Epoch 00050: saving model to training_2/cp-0050.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb7ff0032e8>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
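
The file names in the log above come from expanding the `{epoch:04d}` placeholder with `str.format`. The sketch below reproduces the names for a 50-epoch run saved every 5 epochs. `saved_checkpoints` is a hypothetical helper, and its modulo rule is an assumption chosen to match the log output, not a description of the callback's internals:

```python
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"

def saved_checkpoints(total_epochs, period):
    """Hypothetical helper: names produced when saving every `period` epochs."""
    return [checkpoint_path.format(epoch=epoch)
            for epoch in range(1, total_epochs + 1)
            if epoch % period == 0]

names = saved_checkpoints(total_epochs=50, period=5)
print(len(names))   # 10
print(names[0])     # training_2/cp-0005.ckpt
print(names[-1])    # training_2/cp-0050.ckpt
```

Zero-padding the epoch number to four digits keeps the names in chronological order when they are sorted alphabetically.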

Now look at the resulting checkpoints and choose the latest one:

! ls {checkpoint_dir}
checkpoint            cp-0030.ckpt.data-00000-of-00001
cp-0005.ckpt.data-00000-of-00001  cp-0030.ckpt.index
cp-0005.ckpt.index        cp-0035.ckpt.data-00000-of-00001
cp-0010.ckpt.data-00000-of-00001  cp-0035.ckpt.index
cp-0010.ckpt.index        cp-0040.ckpt.data-00000-of-00001
cp-0015.ckpt.data-00000-of-00001  cp-0040.ckpt.index
cp-0015.ckpt.index        cp-0045.ckpt.data-00000-of-00001
cp-0020.ckpt.data-00000-of-00001  cp-0045.ckpt.index
cp-0020.ckpt.index        cp-0050.ckpt.data-00000-of-00001
cp-0025.ckpt.data-00000-of-00001  cp-0050.ckpt.index
cp-0025.ckpt.index
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
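
Note that the returned value is a checkpoint prefix, not the name of a single file. tf.train.latest_checkpoint finds it through the `checkpoint` bookkeeping file in the directory. As a rough illustration of what "latest" means here, the sketch below derives the same prefix from file names alone. `latest_by_name` is a hypothetical helper that assumes the `cp-NNNN.ckpt` naming used in this guide; it is not how TensorFlow resolves the latest checkpoint:

```python
import re

# File names copied from the directory listing above (abbreviated).
listing = [
    "checkpoint",
    "cp-0005.ckpt.data-00000-of-00001", "cp-0005.ckpt.index",
    "cp-0045.ckpt.data-00000-of-00001", "cp-0045.ckpt.index",
    "cp-0050.ckpt.data-00000-of-00001", "cp-0050.ckpt.index",
]

def latest_by_name(directory, filenames):
    """Hypothetical helper: pick the highest-epoch prefix from .index files."""
    pattern = re.compile(r"^(cp-(\d+)\.ckpt)\.index$")
    matches = [pattern.match(name) for name in filenames]
    prefixes = [(int(m.group(2)), m.group(1)) for m in matches if m]
    if not prefixes:
        return None
    _, prefix = max(prefixes)
    return "{}/{}".format(directory, prefix)

print(latest_by_name("training_2", listing))  # training_2/cp-0050.ckpt
```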

To test, reset the model and load the latest checkpoint:

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 89us/step
Restored model, accuracy: 87.60%

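What are these files?

The code above stores the weights in a collection of checkpoint-formatted files that contain only the trained weights, in a binary format. A checkpoint contains:

  • one or more shards that hold the model's weights, and
  • an index file that indicates which weights are stored in which shard.

If you train a model on a single machine, you will have one shard with the suffix .data-00000-of-00001.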
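
The shard suffix encodes the shard's index and the total number of shards. The sketch below parses that suffix from a file name. `parse_shard` and `SHARD_RE` are hypothetical names introduced for illustration, and the five-digit pattern is an assumption based on the file names shown in this guide:

```python
import re

# Matches names such as "cp.ckpt.data-00000-of-00001".
SHARD_RE = re.compile(
    r"^(?P<prefix>.+)\.data-(?P<index>\d{5})-of-(?P<count>\d{5})$")

def parse_shard(filename):
    """Return (prefix, shard_index, shard_count), or None for non-shard files."""
    m = SHARD_RE.match(filename)
    if m is None:
        return None
    return m.group("prefix"), int(m.group("index")), int(m.group("count"))

print(parse_shard("cp.ckpt.data-00000-of-00001"))  # ('cp.ckpt', 0, 1)
print(parse_shard("cp.ckpt.index"))                # None
```

The prefix recovered here (`cp.ckpt`) is the part of the name that is passed to `load_weights`, together with its directory.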

Manually save weights

Above you saw how to load weights into a model.

Saving weights manually is just as simple: use the Model.save_weights method.

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x7fb86c37a198>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from tf.train.
1000/1000 [==============================] - 0s 90us/step
Restored model, accuracy: 87.60%

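Save the entire model

The entire model can be saved to a single file that contains the weight values, the model's configuration, and even the optimizer's configuration. This lets you checkpoint a model and resume training later from exactly the same state, without access to the original code.

Saving a fully functional model in Keras is very useful: you can load it in TensorFlow.js and then train and run it in a web browser.

Keras provides a basic save format based on the HDF5 standard. For our purposes, the saved model can be treated as a single binary blob.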

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')
Epoch 1/5
1000/1000 [==============================] - 0s 342us/step - loss: 1.1782 - acc: 0.6630
Epoch 2/5
1000/1000 [==============================] - 0s 131us/step - loss: 0.4267 - acc: 0.8730
Epoch 3/5
1000/1000 [==============================] - 0s 130us/step - loss: 0.2926 - acc: 0.9250
Epoch 4/5
1000/1000 [==============================] - 0s 132us/step - loss: 0.2049 - acc: 0.9480
Epoch 5/5
1000/1000 [==============================] - 0s 131us/step - loss: 0.1533 - acc: 0.9640

Now recreate the model from that file:

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_12 (Dense)             (None, 512)               401920
_________________________________________________________________
dropout_6 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_13 (Dense)             (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Check its accuracy:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
1000/1000 [==============================] - 0s 109us/step
Restored model, accuracy: 86.70%

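This technique saves everything:

  • The weight values
  • The model's configuration (architecture)
  • The optimizer configuration

Keras saves models by inspecting their architecture. Currently, it cannot save TensorFlow optimizers (from tf.train). When you use one of those, you need to re-compile the model after loading it, and the optimizer's state is lost.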

What's next

That was a quick guide to saving and loading models with tf.keras.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.