TFDS 提供了一组现成的数据集,适合与 TensorFlow、Jax 和其他机器学习框架配合使用。
它可以确定地处理下载和准备数据并构造 tf.data.Dataset
(或 np.array
)。
注:不要将 TFDS(此库)与 tf.data
(用于构建高效数据流水线的 TensorFlow API)混淆。TFDS 是 tf.data
的高级封装容器。如果您不熟悉此 API,建议您先阅读官方 tf.data 指南。
版权所有 2018 TensorFlow 数据集作者,以 Apache License, Version 2.0 授权
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
安装
TFDS 存在于两个软件包中:
pip install tensorflow-datasets
:稳定版,数月发行一次。pip install tfds-nightly
:每天发行,包含最近版本的数据集。
此 colab 使用 tfds-nightly
:
pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
查找可用的数据集
所有数据集构建工具都是 tfds.core.DatasetBuilder
的子类。要获取可用构建工具的列表,请使用 tfds.list_builders()
或查看我们的目录。
tfds.list_builders()
加载数据集
tfds.load
加载数据集最简单的方法是 tfds.load
。它将执行以下操作:
- 下载数据并将其存储为
tfrecord
文件。 - 加载
tfrecord
并创建tf.data.Dataset
。
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
一些常见的参数:
split=
:要读取的拆分(例如'train'
、['train', 'test']
、'train[80%:]'
…)。请参阅我们的拆分 API 指南。shuffle_files=
:控制是否重排每个周期间的文件顺序(TFDS 以多个较小的文件存储大数据集)data_dir=
:数据集存储的位置(默认为~/tensorflow_datasets/
)with_info=True
:返回包含数据集元数据的tfds.core.DatasetInfo
download=False
:停用下载
tfds.builder
tfds.load
是 tfds.core.DatasetBuilder
的瘦封装容器。您可以使用 tfds.core.DatasetBuilder
API 获得相同的输出:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
tfds build
CLI
如果您希望生成一个特定的数据集,可以使用 tfds
命令行。例如:
tfds build mnist
请参阅文档查看可用标志。
迭代数据集
作为字典
默认情况下,tf.data.Dataset
对象包含 tf.Tensor
的 dict
:
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds: # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
print(list(example.keys()))
image = example["image"]
label = example["label"]
print(image.shape, label)
要找出 dict
键名和结构,请查看我们目录中的数据集文档。例如:mnist 文档。
作为元组(as_supervised=True
)
使用 as_supervised=True
,您可以获取 (features, label)
元组作为替代的监督数据集。
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds: # example is (image, label)
print(image.shape, label)
作为 numpy(tfds.as_numpy
)
使用 tfds.as_numpy
进行以下转换:
tf.Tensor
->np.array
tf.data.Dataset
->Iterator[Tree[np.array]]
(Tree
可能是任意嵌套的Dict
、Tuple
)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
print(type(image), type(label), label)
作为 batched tf.Tensor(batch_size=-1
)
使用 batch_size=-1
,您可以在单个批次中加载完整的数据集。
这可与 as_supervised=True
和 tfds.as_numpy
结合使用以获取 (np.array, np.array)
形式的数据:
image, label = tfds.as_numpy(tfds.load(
'mnist',
split='test',
batch_size=-1,
as_supervised=True,
))
print(type(image), image.shape)
请注意,您的数据集可以放入内存,并且所有样本都具有相同的形状。
对您的数据集进行基准分析
对数据集进行基准分析是对任何可迭代对象(例如 tf.data.Dataset
、tfds.as_numpy
…)的简单 tfds.benchmark
调用。
ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)
tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32) # Second epoch much faster due to auto-caching
- 不要忘记使用
batch_size=
kwarg 对每个批次大小的结果进行归一化。 - 总之,第一个预热批次与其他预热批次分开以捕获
tf.data.Dataset
额外的设置时间(例如缓冲区初始化…)。 - 请注意,由于 TFDS 自动缓存功能,第二次迭代的速度要快得多。
tfds.benchmark
会返回tfds.core.BenchmarkResult
,可以检查它以进行进一步分析。
构建端到端流水线
要想深入一点,您可以查看:
- 我们的端到端 Keras 示例来了解完整的训练流水线(包括批处理、重排…)。
- 有助于提高流水线速度的性能指南(提示:使用
tfds.benchmark(ds)
对数据集进行基准分析)。
呈现
tfds.as_dataframe
使用 tfds.as_dataframe
,可以将 tf.data.Dataset
对象转换为 pandas.DataFrame
以在 Colab 上呈现。
- 添加
tfds.core.DatasetInfo
作为tfds.as_dataframe
的第二个参数以呈现图像、音频、文本、视频… - 使用
ds.take(x)
仅显示前x
个样本。pandas.DataFrame
将在内存中加载完整数据集,并且显示开销可能非常高。
ds, info = tfds.load('mnist', split='train', with_info=True)
tfds.as_dataframe(ds.take(4), info)
tfds.show_examples
tfds.show_examples
返回 matplotlib.figure.Figure
(现在只支持图像数据集):
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
访问数据集元数据
所有构建工具都包括一个包含数据集元数据的 tfds.core.DatasetInfo
对象。
可以通过以下方式访问:
tfds.load
API:
ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info
数据集信息包含有关数据集的附加信息(版本、引用、首页、描述…)。
print(info)
特征元数据(标签名称、图像形状…)
访问 tfds.features.FeatureDict
:
info.features
类、标签名的数量:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7)) # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
形状、数据类型:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
拆分元数据(例如拆分名称、样本数量…)
print(info.splits)
可用拆分:
print(list(info.splits.keys()))
获取有关个别拆分的信息:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
它也适用于 subsplit API:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
问题排查
手动下载(如果下载失败)
如果由于某种原因下载失败(例如离线…),那么您始终可以自己手动下载数据并将其放置在 manual_dir
中(默认为 ~/tensorflow_datasets/download/manual/
)。
要找到下载网址,请查看:
对于新数据集(作为文件夹实现):
tensorflow_datasets/
<type>/<dataset_name>/checksums.tsv
。例如:tensorflow_datasets/datasets/bool_q/checksums.tsv
。您可以在我们的目录中找到数据集的源位置。
修正 NonMatchingChecksumError
TFDS 通过验证下载网址的校验和来确保确定性。如果引发 NonMatchingChecksumError
,则可能表示:
- 网站可能宕机(如
503 status code
)。请检查网址。 - 对于 Google 云端硬盘网址,请稍后再试。当很多人访问同一网址时云端硬盘有时拒绝下载。请参阅错误
- 原始数据集文件可能已更新。在这种情况下,应当更新 TFDS 数据集构建工具。请打开一个新的 Github 议题或拉取请求:
- 使用
tfds build --register_checksums
注册新的校验和 - 逐步更新数据集生成代码。
- 更新数据集
VERSION
- 更新数据集
RELEASE_NOTES
:是什么导致校验和发生变化?一些样本发生了改变吗? - 确保数据集仍能够构建。
- 向我们发送拉取请求
- 使用
注:您也可以检查 ~/tensorflow_datasets/download/
中的下载文件。
引用
如果您在论文中使用 tensorflow-datasets
,除了特定于所用数据集(可以在数据集目录中找到)的任何引用之外,请包含以下引用。
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://tensorflow.google.cn/datasets} },
}