TensorFlow Datasets

TFDS 提供了一组现成的数据集,适合与 TensorFlow、Jax 和其他机器学习框架配合使用。

它可以确定地处理下载和准备数据并构造 tf.data.Dataset(或 np.array)。

注:不要将 TFDS(此库)与 tf.data(用于构建高效数据流水线的 TensorFlow API)混淆。TFDS 是 tf.data 的高级封装容器。如果您不熟悉此 API,建议您先阅读官方 tf.data 指南

版权所有 2018 TensorFlow 数据集作者,以 Apache License, Version 2.0 授权

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

安装

TFDS 存在于两个软件包中:

  • pip install tensorflow-datasets:稳定版,数月发行一次。
  • pip install tfds-nightly:每天发行,包含最近版本的数据集。

此 colab 使用 tfds-nightly

pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

查找可用的数据集

所有数据集构建工具都是 tfds.core.DatasetBuilder 的子类。要获取可用构建工具的列表,请使用 tfds.list_builders() 或查看我们的目录

tfds.list_builders()

加载数据集

tfds.load

加载数据集最简单的方法是 tfds.load。它将执行以下操作:

  1. 下载数据并将其存储为 tfrecord 文件。
  2. 加载 tfrecord 并创建 tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)

一些常见的参数:

  • split=:要读取的拆分(例如 'train'['train', 'test']'train[80%:]'…)。请参阅我们的拆分 API 指南
  • shuffle_files=:控制是否重排每个周期间的文件顺序(TFDS 以多个较小的文件存储大数据集)
  • data_dir=:数据集存储的位置(默认为 ~/tensorflow_datasets/
  • with_info=True:返回包含数据集元数据的 tfds.core.DatasetInfo
  • download=False:停用下载

tfds.builder

tfds.loadtfds.core.DatasetBuilder 的瘦封装容器。您可以使用 tfds.core.DatasetBuilder API 获得相同的输出:

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)

tfds build CLI

如果您希望生成一个特定的数据集,可以使用 tfds 命令行。例如:

tfds build mnist

请参阅文档查看可用标志。

迭代数据集

作为字典

默认情况下,tf.data.Dataset 对象包含 tf.Tensordict

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)

要找出 dict 键名和结构,请查看我们目录中的数据集文档。例如:mnist 文档

作为元组(as_supervised=True

使用 as_supervised=True,您可以获取 (features, label) 元组作为替代的监督数据集。

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)

作为 numpy(tfds.as_numpy

使用 tfds.as_numpy 进行以下转换:

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)

作为 batched tf.Tensor(batch_size=-1

使用 batch_size=-1,您可以在单个批次中加载完整的数据集。

这可与 as_supervised=Truetfds.as_numpy 结合使用以获取 (np.array, np.array) 形式的数据:

image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))

print(type(image), image.shape)

请注意,您的数据集可以放入内存,并且所有样本都具有相同的形状。

对您的数据集进行基准分析

对数据集进行基准分析是对任何可迭代对象(例如 tf.data.Datasettfds.as_numpy…)的简单 tfds.benchmark 调用。

ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)

tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32)  # Second epoch much faster due to auto-caching
  • 不要忘记使用 batch_size= kwarg 对每个批次大小的结果进行归一化。
  • 总之,第一个预热批次与其他预热批次分开以捕获 tf.data.Dataset 额外的设置时间(例如缓冲区初始化…)。
  • 请注意,由于 TFDS 自动缓存功能,第二次迭代的速度要快得多。
  • tfds.benchmark 会返回 tfds.core.BenchmarkResult ,可以检查它以进行进一步分析。

构建端到端流水线

要想深入一点,您可以查看:

呈现

tfds.as_dataframe

使用 tfds.as_dataframe,可以将 tf.data.Dataset 对象转换为 pandas.DataFrame 以在 Colab 上呈现。

  • 添加 tfds.core.DatasetInfo 作为 tfds.as_dataframe 的第二个参数以呈现图像、音频、文本、视频…
  • 使用 ds.take(x) 仅显示前 x 个样本。pandas.DataFrame 将在内存中加载完整数据集,并且显示开销可能非常高。
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

tfds.show_examples

tfds.show_examples 返回 matplotlib.figure.Figure(现在只支持图像数据集):

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

访问数据集元数据

所有构建工具都包括一个包含数据集元数据的 tfds.core.DatasetInfo 对象。

可以通过以下方式访问:

ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info

数据集信息包含有关数据集的附加信息(版本、引用、首页、描述…)。

print(info)

特征元数据(标签名称、图像形状…)

访问 tfds.features.FeatureDict

info.features

类、标签名的数量:

print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))

形状、数据类型:

print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)

拆分元数据(例如拆分名称、样本数量…)

访问 tfds.core.SplitDict

print(info.splits)

可用拆分:

print(list(info.splits.keys()))

获取有关个别拆分的信息:

print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)

它也适用于 subsplit API:

print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)

问题排查

手动下载(如果下载失败)

如果由于某种原因下载失败(例如离线…),那么您始终可以自己手动下载数据并将其放置在 manual_dir 中(默认为 ~/tensorflow_datasets/download/manual/)。

要找到下载网址,请查看:

修正 NonMatchingChecksumError

TFDS 通过验证下载网址的校验和来确保确定性。如果引发 NonMatchingChecksumError,则可能表示:

  • 网站可能宕机(如 503 status code)。请检查网址。
  • 对于 Google 云端硬盘网址,请稍后再试。当很多人访问同一网址时云端硬盘有时拒绝下载。请参阅错误
  • 原始数据集文件可能已更新。在这种情况下,应当更新 TFDS 数据集构建工具。请打开一个新的 Github 议题或拉取请求:
    • 使用 tfds build --register_checksums 注册新的校验和
    • 逐步更新数据集生成代码。
    • 更新数据集 VERSION
    • 更新数据集 RELEASE_NOTES:是什么导致校验和发生变化?一些样本发生了改变吗?
    • 确保数据集仍能够构建。
    • 向我们发送拉取请求

注:您也可以检查 ~/tensorflow_datasets/download/ 中的下载文件。

引用

如果您在论文中使用 tensorflow-datasets,除了特定于所用数据集(可以在数据集目录中找到)的任何引用之外,请包含以下引用。

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://tensorflow.google.cn/datasets} },
}