TFDS 和确定性

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

本文将说明:

  • TFDS 对于确定性的保证
  • TFDS 读取样本的顺序
  • 各种注意事项和陷阱

安装

数据集

理解 TFDS 如何读取数据需要一些上下文。

在生成过程中,TFDS 会将原始数据写入标准化 .tfrecord 文件。对于大数据集,会创建多个 .tfrecord 文件,每个文件都包含多个样本。我们将每个 .tfrecord 文件称为分片

本指南使用了具有 1024 个分片的 ImageNet:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

查找数据集样本 ID

如果您只想了解确定性,那么可以跳到下一部分。

每个数据集样本都由一个 id(例如 'imagenet2012-train.tfrecord-01023-of-01024__32')唯一标识。您可以通过传递 read_config.add_tfds_id = True 来恢复此 id,这将在 tf.data.Dataset 的字典中添加一个 'tfds_id' 键。

在本教程中,我们定义了一个用于打印数据集的样本 ID(转换为整数以更加易读)的小工具:

读取时的确定性

本部分将解释 tfds.load 的确定性保证。

使用 shuffle_files=False(默认)

默认情况下,TFDS 会以确定性的方式产生样本 (shuffle_files=False)

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

为了提高性能,TFDS 会使用 tf.data.Dataset.interleave 同时读取多个分片。我们在这个示例中看到,TFDS 在读取了 16 个样本后切换到了分片 2 (..., 14, 15, 1251, 1252, ...)。下文的交错部分提供了更多信息。

同样,subsplit API 也是确定性的:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

如果您要训练多个周期,那么不建议使用以上设置,因为所有周期都会以相同的顺序读取分片(因此随机性仅限于 ds = ds.shuffle(buffer) 缓冲区大小)。

使用 shuffle_files=True

使用 shuffle_files=True 时,每个周期都会打乱分片的顺序,因此读取不再具有确定性。

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

注:设置 shuffle_files=True 还会停用 tf.data.Options 中的 deterministic 以提高性能。因此,即使只有一个分片的小型数据集(如 MNIST)也会变得不确定。

请参阅下方诀窍部分以获得确定性的文件打乱。

确定性注意事项:交错参数

更改 read_config.interleave_cycle_length, read_config.interleave_block_length 将更改样本顺序。

TFDS 会依赖 tf.data.Dataset.interleave 以每次仅加载几个分片,从而提高性能并减少内存使用。

仅在交错参数为固定值的情况下,才能保证样本顺序相同。请参阅交错文档以了解 cycle_lengthblock_length 对应的内容。

  • cycle_length=16, block_length=16(默认,同上):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3, block_length=2
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

在第二个示例中,我们看到数据集读取了一个分片中的 2 个 (block_length=2) 个样本,然后切换到下一个分片。每隔 2 * 3 (cycle_length=3) 个样本,它就会回到第一个分片 (shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,...)。

子拆分和样本顺序

每个样本都有一个 ID 0, 1, ..., num_examples-1subsplit API 会选择一个样本切片(例如 train[:x] 会选择 0, 1, ..., x-1)。

但是在子拆分中,不会按照 ID 递增顺序读取样本(由于分片和交错)。

更具体地说,ds.take(x) 并不等同于 split='train[:x]'

在上面的交错示例中可以很容易地看出这点,其中的样本来自不同的分片。

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

在 16 个(block_length)样本之后,.take(25) 切换到下一个分片,而 train[:25] 却在继续读取第一个分片中的样本。

诀窍

获得确定性的文件顺序打乱

可以通过两种方式进行确定性的顺序打乱:

  1. 设置 shuffle_seed。注:这需要在每个周期更改种子,否则在各个周期之间将以相同的顺序读取分片。
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. 使用 experimental_interleave_sort_fn:这可以完全控制读取哪些分片以及以什么顺序读取,而不依赖于 ds.shuffle 顺序。
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

获得确定性的可抢占流水线

此问题更加复杂。没有简单且令人满意的解决方案。

  1. 在不使用 ds.shuffle 且使用确定性顺序打乱的情况下,理论上应该可以计算已读取的样本并推断每个分片中已读取了哪些样本(作为 cycle_lengthblock_length 和分片顺序的函数)。然后,可以通过 experimental_interleave_sort_fn 注入每个分片的 skiptake

  2. 使用 ds.shuffle 时,通常不可能不重播完整的训练流水线。它需要保存 ds.shuffle 缓冲区状态以推断已读取哪些样本。样本可能是不连续的(例如读取了 shard5_ex2, shard5_ex4,但未读取 shard5_ex3)。

  3. 使用 ds.shuffle 时,一种方式是保存所有读取的 shards_ids/example_ids(从 tfds_id 推断),然后从中推断出文件指令。

1. 项的最简单的情况是让 .skip(x).take(y) 匹配 train[x:x+y]。这需要:

  • 设置 cycle_length=1(因此会按顺序读取分片)
  • 设置 shuffle_files=False
  • 不使用 ds.shuffle

它只能用于训练只有 1 个周期的大型数据集。将以默认的随机顺序读取样本。

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

查找针对给定子拆分读取哪些分片/样本

使用 tfds.core.DatasetInfo,您可以直接访问读取指令。

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]