在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
TFDS 一直都独立于框架。例如,您可以轻松地加载 NumPy 格式的数据集以在 Jax 和 PyTorch 中使用。
TensorFlow 及其数据加载解决方案 (tf.data
) 按照设计是我们 API 中的一等公民。
我们扩展了 TFDS 以支持仅使用 NumPy 而无需 TensorFlow 的数据加载。这对于在 Jax 和 PyTorch 等机器学习框架中使用非常方便。事实上,对于后者的用户来说,TensorFlow:
- 会保留 GPU/TPU 内存;
- 会在 CI/CD 中增加构建时间;
- 在运行时需要花费时间导入。
TensorFlow 不再是读取数据集的依赖项。
机器学习流水线需要一个数据加载器来加载样本,将其解码并呈现给模型。数据加载器使用“源/采样器/加载器”范式:
TFDS dataset ┌────────────────┐
on disk │ │
┌──────────►│ Data │
|..|... │ | │ source ├─┐
├──┼────┴─────┤ │ │ │
│12│image12 │ └────────────────┘ │ ┌────────────────┐
├──┼──────────┤ │ │ │
│13│image13 │ ├───►│ Data ├───► ML pipeline
├──┼──────────┤ │ │ loader │
│14│image14 │ ┌────────────────┐ │ │ │
├──┼──────────┤ │ │ │ └────────────────┘
|..|... | │ Index ├─┘
│ sampler │
│ │
└────────────────┘
- 数据源负责实时访问和解码来自 TFDS 数据集的样本。
- 索引采样器负责确定记录处理的顺序。在读取任何记录之前,实现全局转换(例如全局重排、分片、重复多个周期)非常重要。
- 数据加载器通过利用数据源和索引采样器来编排加载。它可以实现性能优化(例如,预提取、多进程或多线程)。
速览
tfds.data_source
是一个用于创建数据源的 API:
- 用于纯 Python 流水线的快速原型设计;
- 用于大规模管理数据密集型机器学习流水线。
安装
让我们安装并导入所需依赖项:
!pip install array_record
!pip install tfds-nightly
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds
数据源
数据源基本上是 Python 序列。因此,它们需要实现以下协议:
class RandomAccessDataSource(Protocol):
"""Interface for datasources where storage supports efficient random access."""
def __len__(self) -> int:
"""Number of records in the dataset."""
def __getitem__(self, record_key: int) -> Sequence[Any]:
"""Retrieves records for the given record_keys."""
警告:该 API 仍在积极开发中。特别是,__getitem__
目前在输入中必须支持 int
和 list[int]
。将来,按照标准,它可能仅支持 int
。
底层文件格式需要支持高效的随机访问。目前,TFDS 依赖于 array_record
。
array_record
是一种衍生自 Riegeli 的新文件格式,实现了 IO 效率的新前沿。特别是,ArrayRecord 支持按记录索引并行读取、写入和随机访问。ArrayRecord 建立在 Riegeli 之上,并支持相同的压缩算法。
fashion_mnist
是一个常见的计算机视觉数据集。要使用 TFDS 检索基于 ArrayRecord 的数据源,只需使用以下命令:
ds = tfds.data_source('fashion_mnist')
tfds.data_source
是一个方便的包装器。它等同于:
builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()
这将输出一个数据源字典:
{
'train': DataSource(name=fashion_mnist, split='train', decoders=None),
'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}
一旦 download_and_prepare
运行并在您生成记录文件后,我们就不再需要 TensorFlow 了。一切都将在 Python/NumPy 中完成!
让我们通过卸载 TensorFlow 并在另一个子进程中重新加载数据源对此进行检查:
pip uninstall -y tensorflow
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds
try:
import tensorflow as tf
except ImportError:
print('No TensorFlow found...')
ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
python no_tensorflow.py
在未来的版本中,我们还将使数据集准备不再依赖 TensorFlow。
数据源的长度为:
len(ds['train'])
访问数据集的第一个元素:
%%timeit
ds['train'][0]
…开销与访问任何其他元素一样低。下面是随机访问的定义:
%%timeit
ds['train'][1000]
特征现在使用 NumPy DType(而不是 TensorFlow DType)。您可以使用以下命令检查特征:
features = tfds.builder('fashion_mnist').info.features
您可以在我们的文档中找到有关特征的更多信息。在这里,我们特别可以检索图像的形状和类别数量:
shape = features['image'].shape
num_classes = features['label'].num_classes
在纯 Python 中使用
您可以通过迭代来使用 Python 中的数据源:
for example in ds['train']:
print(example)
break
如果您检查元素,还会注意到所有特征都已使用 NumPy 解码。在幕后,我们默认使用 OpenCV,因为它很快。如果您没有安装 OpenCV,我们将默认使用 Pillow 来提供轻量级和快速的图像解码。
{
'image': array([[[0], [0], ..., [0]],
[[0], [0], ..., [0]]], dtype=uint8),
'label': 2,
}
注:目前,该功能仅适用于 Tensor
、Image
和 Scalar
特征。Audio
和 Video
特征即将推出。敬请关注!
与 PyTorch 结合使用
PyTorch 使用源/采样器/加载器范式。在 Torch 中,“数据源”称为“数据集”。torch.utils.data
包含构建高效输入流水线所需的所有详细信息。
TFDS 数据源可以像常规的映射样式数据集一样使用。
首先,我们安装并导入Torch:
!pip install torch
from tqdm import tqdm
import torch
我们已经为训练和测试分别定义了数据源(分别是 ds['train']
和 ds['test']
)。现在,我们可以定义采样器和加载器:
batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
ds['train'],
sampler=train_sampler,
batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
ds['test'],
sampler=None,
batch_size=batch_size,
)
使用 PyTorch,我们在第一个样本上进行训练,并评估简单的逻辑回归:
class LinearClassifier(torch.nn.Module):
def __init__(self, shape, num_classes):
super(LinearClassifier, self).__init__()
height, width, channels = shape
self.classifier = torch.nn.Linear(height * width * channels, num_classes)
def forward(self, image):
image = image.view(image.size()[0], -1).to(torch.float32)
return self.classifier(image)
model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()
print('Training...')
model.train()
for example in tqdm(train_loader):
image, label = example['image'], example['label']
prediction = model(image)
loss = loss_function(prediction, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
image, label = example['image'], example['label']
prediction = model(image)
num_examples += image.shape[0]
predicted_label = prediction.argmax(dim=1)
true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
即将推出:与 JAX 结合使用
我们正在与 Grain 密切合作。Grain 是适用于 Python 的开源、快速和确定性数据加载器。敬请关注!
阅读更多内容
有关详情,请参阅 tfds.data_source
API 文档。