TFDS and determinism

• TFDS's guarantees on determinism
• The order in which TFDS reads examples
• Various caveats and gotchas

Installation
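The install cell is omitted in this excerpt; TFDS tutorials typically install the nightly package (a sketch, adjust to your environment):

```shell
pip install -q tfds-nightly tensorflow
```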

Dataset

```python
import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
```
```
imagenet has 1024 shards (1281167 examples)
```
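`print_ex_ids`, used throughout below, is a small notebook helper that is not defined in this excerpt: it reads the first `take` examples of a split, decodes each example's `tfds_id` into a global integer index, and prints the list. A minimal sketch of the decoding step (assuming examples are read with `tfds.ReadConfig(add_tfds_id=True)`, which attaches ids of the form `<name>-<split>.tfrecord-<shard>-of-<num_shards>__<ex_in_shard>`; `tfds_id_to_int` is an illustrative name):

```python
import re

def tfds_id_to_int(tfds_id: str, shard_lengths: list) -> int:
    """Map a tfds_id like 'imagenet2012-train.tfrecord-00001-of-01024__3'
    to a global example index across shards."""
    shard_id, ex_id = re.match(r'.*-(\d+)-of-\d+__(\d+)$', tfds_id).groups()
    # Examples of earlier shards come first in the global numbering.
    return sum(shard_lengths[:int(shard_id)]) + int(ex_id)
```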

Determinism when reading

With `shuffle_files=False` (default)

```python
# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
```
```
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
```

```python
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
```
```
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
```

With `shuffle_files=True`

```python
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
```
```
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]
```

Determinism caveat: interleave args

TFDS relies on tf.data.Dataset.interleave to load only a few shards at a time, which improves performance and reduces memory usage.

• `cycle_length=16`, `block_length=16` (default, same as above):
```python
print_ex_ids(imagenet, split='train', take=20)
```
```
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
```
• `cycle_length=3`, `block_length=2`:
```python
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
```
```
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]
```
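The orders above can be reproduced with a small pure-Python simulation of the interleave pattern (a sketch: the shard size of 1251 comes from the outputs above, and shard exhaustion/refill is ignored for these short prefixes; `interleave_ids` is an illustrative name):

```python
from itertools import islice

def interleave_ids(shard_ranges, cycle_length, block_length, take):
    """First `take` global example ids in tf.data interleave order:
    `block_length` consecutive examples from each of the first
    `cycle_length` shards, in round-robin order."""
    iterators = [iter(r) for r in shard_ranges[:cycle_length]]
    out = []
    while len(out) < take:
        for it in iterators:
            out.extend(islice(it, block_length))
    return out[:take]

# Three shards of 1251 examples each, as in the imagenet outputs above
shards = [range(i * 1251, (i + 1) * 1251) for i in range(3)]
print(interleave_ids(shards, cycle_length=3, block_length=2, take=20))
# [0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]
```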

Subsplit and example order

Taking the first `x` examples is not the same as reading the `train[:x]` subsplit: subsplits are resolved into per-shard file instructions read sequentially, while `take` goes through the shard interleave.

```python
print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
```
```
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
```

Recipes

Getting deterministic file shuffling

1. Set `shuffle_seed`. Note: this requires changing the seed at each epoch, otherwise shards will be read in the same order across epochs.
```python
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
```
```
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
```
2. Use `experimental_interleave_sort_fn`: this gives full control over which shards are read and in which order, rather than relying on the `ds.shuffle` order.
```python
def _reverse_order(file_instructions):
    return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
```
```
[1279916, 1279917, 1279918, 1279919, 1279920]
```

Getting a deterministic preemptable pipeline

1. Without `ds.shuffle` and with deterministic file shuffling, it should in theory be possible to count the examples which have been read, and to infer which examples have been read within each shard (as a function of `cycle_length`, `block_length` and the shard order). Then the `skip`/`take` for each shard could be injected through `experimental_interleave_sort_fn`.

2. With `ds.shuffle`, this is likely impossible without replaying the full training pipeline: the `ds.shuffle` buffer state would have to be saved to infer which examples have been read, and those examples can be non-continuous (e.g. `shard5_ex2` and `shard5_ex4` read, but not `shard5_ex3`).

3. With `ds.shuffle`, one way would be to save all the shard_ids/example_ids that have been read (deduced from `tfds_id`), then to deduce the file instructions from that.
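Option 3 can be sketched with toy ids (the dataset shape and the set of already-read ids here are purely illustrative):

```python
# Toy dataset: 4 shards of 5 examples, identified as (shard_id, ex_id) pairs
all_ids = {(shard, ex) for shard in range(4) for ex in range(5)}

# Ids already read before preemption, as recovered from each example's tfds_id
read_ids = {(0, 0), (0, 2), (1, 4)}

# The file instructions for the restart only need to cover the difference
remaining = sorted(all_ids - read_ids)
print(len(remaining))  # 17
```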

The simplest case for `1.` is to have `.skip(x).take(y)` match `train[x:x+y]`. It requires:

• Setting `cycle_length=1` (so shards are read sequentially)
• Setting `shuffle_files=False`
• Not using `ds.shuffle`

```python
read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job gets preempted, the subsplit API will skip at most `len(shard0)` examples
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
```
```
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
```
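The equivalence, and why it needs `cycle_length=1`, can be checked with a pure-Python simulation of the read order (a sketch; `read_order` is an illustrative name, the shard sizes are made up, and shard exhaustion is ignored):

```python
from itertools import islice

def read_order(shard_ranges, cycle_length, block_length, take):
    """First `take` global example ids in tf.data interleave order."""
    iterators = [iter(r) for r in shard_ranges[:cycle_length]]
    out = []
    while len(out) < take:
        for it in iterators:
            out.extend(islice(it, block_length))
    return out[:take]

shards = [range(0, 100), range(100, 200)]  # two illustrative shards
x, y = 40, 10

# cycle_length=1: shards are read back to back, so skip(x).take(y) == train[x:x+y]
seq = read_order(shards, cycle_length=1, block_length=16, take=60)
assert seq[x:x + y] == list(range(x, x + y))

# cycle_length=2: shards are interleaved, so skip/take no longer matches the subsplit
mixed = read_order(shards, cycle_length=2, block_length=16, take=60)
assert mixed[x:x + y] != list(range(x, x + y))
```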

Find which shards/examples are read for a given subsplit

```python
imagenet.info.splits['train[44%:45%]'].file_instructions
```
```
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]
```
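Summing the `num_examples` fields of the instructions above confirms the subsplit size (a quick sanity check; `skip`/`take` are already folded into each `num_examples`):

```python
# num_examples per FileInstruction, copied from the output above
num_examples = [551, 1251, 1251, 1251, 1252, 1251, 1251, 1251, 1251, 1251, 1001]
total = sum(num_examples)
print(total)  # 12812, i.e. ~1% of the 1281167 train examples, as expected for train[44%:45%]
```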