训练检查点

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

“保存 TensorFlow 模型”这一短语通常表示保存以下两种元素之一:

  1. 检查点,或
  2. SavedModel。

检查点可以捕获模型使用的所有参数(tf.Variable 对象)的确切值。检查点不包含对模型所定义计算的任何描述,因此通常仅在将使用保存参数值的源代码可用时才有用。

另一方面,除了参数值(检查点)之外,SavedModel 格式还包括对模型所定义计算的序列化描述。这种格式的模型独立于创建模型的源代码。因此,它们适合通过 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或者使用其他编程语言(C、C++、Java、Go、Rust、C# 等 TensorFlow API)编写的程序进行部署。

本文介绍用于编写和读取检查点的 API。

设置

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras 训练 API 保存

请参阅 tf.keras 保存和恢复指南

tf.keras.Model.save_weights 可以保存一个 TensorFlow 检查点。

net.save_weights('easy_checkpoint')

编写检查点

TensorFlow 模型的持久状态存储在 tf.Variable 对象中。这些对象可以直接构造,但通常会通过像 tf.keras.layerstf.keras.Model 这样的高级 API 创建。

管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。

tf.train.Checkpointtf.keras.layers.Layertf.keras.Model 的子类会自动跟踪分配给其特性的变量。下面的示例构造了一个简单的线性模型,然后编写检查点,其中包含该模型所有变量的值。

您可以使用 Model.save_weights 轻松保存模型检查点。

手动创建检查点

设置

为了帮助演示 tf.train.Checkpoint 的所有功能, 下面定义了一个小数据集和优化步骤:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

创建检查点对象

要手动创建检查点,您将需要 tf.train.Checkpoint 对象。您想要为对象设置检查点的位置将设置为此对象的特性。

tf.train.CheckpointManager 也有助于管理多个检查点。

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

训练模型并为模型设置检查点

以下训练循环可创建模型和优化器的实例,然后将它们收集到 tf.train.Checkpoint 对象中。它在每批数据上循环调用训练步骤,并定期将检查点写入磁盘。

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.12
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.53
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.97
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.49
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.85

恢复和继续训练

在第一次设置检查点后,您可以传递新的模型和管理器,但需要从您离开的地方开始训练:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.51
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.77
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.74
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.42
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.23

tf.train.CheckpointManager 对象会删除旧的检查点。上面配置为仅保留最近的三个检查点。

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

这些路径(如 './tf_ckpts/ckpt-10')不是磁盘上的文件,而是一个 index 文件和一个或多个包含变量值的数据文件的前缀。这些前缀被分组到一个单独的 checkpoint 文件 ('./tf_ckpts/checkpoint') 中,其中 CheckpointManager 保存其状态。

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

加载机制

TensorFlow 通过从加载的对象开始遍历带命名边的有向计算图来将变量与检查点值匹配。边名称通常来自对象中的特性名称,例如 self.l1 = tf.keras.layers.Dense(5) 中的 "l1"tf.train.Checkpoint 使用其关键字参数名称,如 tf.train.Checkpoint(step=...) 中的 "step"

上面示例中的依赖图如下所示:

Visualization of the dependency graph for the example training loop

优化器为红色,常规变量为蓝色,优化器插槽变量为橙色。其他节点(例如代表 tf.train.Checkpoint 的节点)为黑色。

插槽变量是优化器状态的一部分,但是是为特定变量创建的。例如,上面的 'm' 边对应于动量,Adam 优化器会针对每个变量跟踪该动量。只有在同时保存变量和优化器时,才会将插槽变量保存到检查点中,并因此保存虚线边。

tf.train.Checkpoint 对象上调用 restore() 会对请求的恢复进行排队,一旦有来自 Checkpoint 对象的匹配路径,就会立即恢复变量值。例如,通过网络和层重构一个指向上面所定义模型的路径后,我们便可以仅加载该模型中的偏差。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.4010847 2.9552095 2.5278587 3.8726504 4.8468237]

这些新对象的依赖图是我们在上面编写的较大检查点的更小子图。它只包含偏差和 tf.train.Checkpoint 用来计算检查点数量的保存计数器。

Visualization of a subgraph for the bias variable

restore() 返回一个状态对象,该对象具有可选的断言。我们在新 Checkpoint 中创建的所有对象均已恢复,因此 status.assert_existing_objects_matched() 传递。

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7ff9f85f36a0>

检查点中有许多尚未匹配的对象,包括层的内核和优化器的变量。status.assert_consumed() 仅在检查点和程序完全匹配时传递,并在此处引发异常。

延迟恢复

当输入形状可用时,TensorFlow 中的 Layer 对象可能会将变量创建延迟到变量的首次调用。例如,Dense 层内核的形状取决于该层的输入和输出形状,因此,作为构造函数参数所需的输出形状没有足够的信息来单独创建变量。由于调用 Layer 还会读取变量的值,必须在变量的创建与其首次使用之间进行恢复。

为支持这种习惯用法,tf.train.Checkpoint 会对尚不具有匹配变量的恢复进行排队。

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.502296  4.691994  4.9034066 4.8172593 4.8092093]]

手动检查检查点

tf.train.list_variables 可以列出检查点键和检查点中变量的形状。检查点键是上面显示的计算图中的路径。

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

列表和字典跟踪

对于像 self.l1 = tf.keras.layers.Dense(5) 一样的直接特性赋值,将列表和字典分配给特性会跟踪其内容。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

您可能会注意到列表和字典的包装器对象。这些包装器是可设置检查点版本的基础数据结构。就像基于特性的加载一样,这些包装器会在将变量添加到容器后立即恢复它的值。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

相同的跟踪会自动应用于 tf.keras.Model 的子类,并且可用于跟踪层列表等用途。

使用 Estimator 保存基于对象的检查点

请参阅 Estimator 指南

默认情况下,Estimator 使用变量名而不是前面几部分中介绍的对象计算图来保存检查点。tf.train.Checkpoint 将接受基于名称的检查点,但是在将模型的一部分移到 Estimator 的 model_fn 外部时,变量名称可能会更改。保存基于对象的检查点可以更轻松地在 Estimator 内训练模型,然后在外部使用。

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.737563, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 40.91763.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7ffa3dc7dba8>

随后,tf.train.Checkpoint 可以从其 model_dir 加载 Estimator 的检查点。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

总结

TensorFlow 对象提供了一种简单的自动机制来保存和恢复它们所使用变量的值。