此页面由 Cloud Translation API 翻译。
Switch to English

tf.train.Checkpoint

TensorFlow 1版 GitHub上查看源代码

组可跟踪对象,保存和恢复它们。

用在笔记本电脑

使用的指南使用教程

Checkpoint的构造器接受关键字参数,其值是包含可跟踪状态类型,如tf.keras.optimizers.Optimizer实现中, tf.Variable S, tf.data.Dataset迭代器, tf.keras.Layer实现中,或tf.keras.Model实现。它节省了这些值与检查点,并保持save_counter的编号检查点。

实例:

 import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
 

Checkpoint.save()Checkpoint.restore()写入和读取基于对象的检查点,而相比之下,TensorFlow 1.x中的tf.compat.v1.train.Saver其写入和读取variable.name基于关卡。基于对象的检查点保存Python对象之间的依赖关系(的图表Layer S, Optimizer S, Variable与名为边缘S,等),并且该图用于恢复的检查点时,以匹配变量。它可以是更稳健的Python程序的变化,并帮助支持恢复上创建变量。

Checkpoint对象对作为关键字参数传递给它们的构造函数的对象依赖关系,以及每个依赖被赋予一个名称,等同于为它创建的关键字参数的名称。 TensorFlow类,如Layer S和Optimizer旨意自己的变量(例如“核心”和“偏见”自动添加依赖tf.keras.layers.Dense )。继承tf.keras.Model使得在用户定义的类容易管理依赖,因为Model钩入属性分配。例如:

 class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...
 

Model有一个名为其对“input_transform”依赖Dense层,而这又取决于其变量。其结果是,在保存的情况下Regress使用tf.train.Checkpoint也将保存所有被创建的变量Dense层。

当变量被分配到多个工人,每个工人写自己的检查点的部分。然后将这些部分合并/重新索引表现为一个单一的检查点。这避免了复制所有变量来一个工人,但要求所有工人看到一个共同的文件系统。

虽然tf.keras.Model.save_weightstf.train.Checkpoint.save以相同的格式保存,注意所得到的检查点的根是保存方法附着到的对象。这意味着节约了tf.keras.Model使用save_weights和装载到tf.train.CheckpointModel附接(或反之亦然)将不匹配Model的变量。请参阅指导训练关卡的细节。身高tf.train.Checkpoint超过tf.keras.Model.save_weights训练关卡。

**kwargs 关键字参数被设置为这个对象的属性,并保存在检查点。值必须是可跟踪的对象。

ValueError 如果对象kwargs不追踪。

save_counter 当增量save()被调用。用于数字关卡。

方法

read

查看源代码

阅读书面培训检查点write

读此Checkpoint ,它依赖于任何对象。

这种方法就像restore()但并不指望save_counter在检查点变量。它只恢复检查点已所依赖的对象。

该方法主要用于通过使用更高级别的检查点管理实用程序,使用write() ,而不是save()有自己的机制,以数量和跟踪检查站。

实例:

 # Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')

# Later, load the checkpoint with read()
# With restore() assert_consumed() would have failed.
checkpoint.read(path).assert_consumed()

# You can also pass options to restore(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.read(path, options=options)
 

ARGS
save_path 检查点的路径被返回的write
options 可选tf.train.CheckpointOptions对象。

返回
负载状态对象,其可用于做出关于检查点恢复的状态的断言。见restore的详细信息。

restore

查看源代码

恢复训练关卡。

恢复此Checkpoint ,它依赖于任何对象。

此方法的目的是用来通过创建负载检查点save() 。通过创建检查点write()使用read()不期望该方法save_counter通过添加变量save()

restore()或者赋值立即如果要恢复已经被创建的变量,或推迟恢复,直到创建的变量。如果它们具有在检查点的相应对象(恢复请求将在任何可跟踪对象排队等候要添加的预期依赖性)这个调用后添加的依赖性将被匹配。

为了确保加载完成后并没有更多的分配将发生,使用assert_consumed()返回的状态对象的方法, restore()

 checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path).assert_consumed()

# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()
 

如果任何的Python的依赖关系图对象在检查站均未发现,或者如果有检查点的值没有匹配的Python对象的,将引发异常。

基于名称的tf.compat.v1.train.Saver从TensorFlow 1.x的检查站可以使用这种方法来加载。名称是用来匹配变量。基于名称重新编码检查站使用tf.train.Checkpoint.save尽快。

ARGS
save_path 检查点的路径,通过返回savetf.train.latest_checkpoint 。如果检查点被写了基于域名的tf.compat.v1.train.Saver ,名称是用来匹配变量。
options 可选tf.train.CheckpointOptions对象。

返回
负载状态对象,其可用于做出关于检查点恢复的状态的断言。

返回的状态对象有以下方法:

  • assert_consumed()如果任何变量是无与伦比的引发一个例外:其不具有匹配Python对象或Python依赖图与在检查点没有值的对象或者检查点值。此方法返回状态对象,并且因此可以与其他的断言被链接。

  • assert_existing_objects_matched()如果任何现有的Python依赖图中是不匹配的对象引发一个例外。不像assert_consumed ,这种说法会通过,如果在检查点值都没有相应的Python对象。例如,一个tf.keras.Layer还未建成,所以并没有创造任何变量对象,将通过这一说法,但未能assert_consumed 。有用的,当一个大检查站进入一个新的Python程序的装载部分,具有如训练关卡tf.compat.v1.train.Optimizer被保存,但只为推断所需要的状态正在装载。此方法返回状态对象,并且因此可以与其他的断言被链接。

  • assert_nontrivial_match()断言从根对象的东西放在一边被匹配。这是一个非常薄弱的​​说法,但对于在库中的代码,其中对象可能还没有在Python和一些Python对象被创建可能没有一个检查点值检查点存在完整性检查是有用的。

  • expect_partial()不完整的检查点恢复沉默警告。警告否则打印的检查点文件或对象的未使用的部分,当Checkpoint对象被删除(经常程序关机)。

save

查看源代码

节省了培训的检查点,并提供基本的检查站的管理。

保存的检查点包括由该对象创建的变量和任何可追踪对象它的时候取决于Checkpoint.save()被调用。

save是围绕一个基本便利的包装write方法,依次编号检查站使用save_counter和更新所使用的元数据tf.train.latest_checkpoint 。更先进的检查站的管理,例如垃圾收集和自定义编号,可以通过其他工具还提供包writeread 。 ( tf.train.CheckpointManager例如)。

 step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.save("/tmp/ckpt")

# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt").assert_consumed()

# You can also pass options to save() and restore(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)

# Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
 

ARGS
file_prefix 前缀用于检查点的文件名(/路径/到/目录/ and_a_prefix)。名称是基于这个前缀和产生Checkpoint.save_counter
options 可选tf.train.CheckpointOptions对象。

返回
的完整路径检查点。

write

查看源代码

写一个训练关卡。

检查点包括由该对象创建的变量和任何可跟踪对象在它的时间取决于Checkpoint.write()被调用。

write不号检查站,增加save_counter ,或更新所使用的元数据tf.train.latest_checkpoint 。它主要用于由更高级别的检查站管理工具的使用。 save提供了这些功能非常基本实现。

书面检查点write必须读取read

实例:

 step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.write("/tmp/ckpt")

# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt").assert_consumed()

# You can also pass options to write() and read(). For example this
# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.write("/tmp/ckpt", options=options)

# Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
 

ARGS
file_prefix 前缀用于检查点的文件名(/路径/到/目录/ and_a_prefix)。
options 可选tf.train.CheckpointOptions对象。

返回
的完整路径检查点(即file_prefix )。