适用于 Jax 和 PyTorch 的 TFDS

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

TFDS 一直都独立于框架。例如,您可以轻松地加载 NumPy 格式的数据集以在 Jax 和 PyTorch 中使用。

TensorFlow 及其数据加载解决方案 (tf.data) 按照设计是我们 API 中的一等公民。

我们扩展了 TFDS 以支持仅使用 NumPy 而无需 TensorFlow 的数据加载。这对于在 Jax 和 PyTorch 等机器学习框架中使用非常方便。事实上,对于后者的用户来说,TensorFlow:

  • 会保留 GPU/TPU 内存;
  • 会在 CI/CD 中增加构建时间;
  • 在运行时需要花费时间导入。

TensorFlow 不再是读取数据集的依赖项。

机器学习流水线需要一个数据加载器来加载样本,将其解码并呈现给模型。数据加载器使用“源/采样器/加载器”范式:

 TFDS dataset       ┌────────────────┐
   on disk                          
        ┌──────────►│      Data      
|..|...      |          source     ├─┐
├──┼────┴─────┤                      
12image12        └────────────────┘     ┌────────────────┐
├──┼──────────┤                                            
13image13                           ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                                 loader     
14image14        ┌────────────────┐                     
├──┼──────────┤                          └────────────────┘
|..|...       |          Index      ├─┘
                        sampler     
                                    
                    └────────────────┘
  • 数据源负责实时访问和解码来自 TFDS 数据集的样本。
  • 索引采样器负责确定记录处理的顺序。在读取任何记录之前,实现全局转换(例如全局重排、分片、重复多个周期)非常重要。
  • 数据加载器通过利用数据源和索引采样器来编排加载。它可以实现性能优化(例如,预提取、多进程或多线程)。

速览

tfds.data_source 是一个用于创建数据源的 API:

  1. 用于纯 Python 流水线的快速原型设计;
  2. 用于大规模管理数据密集型机器学习流水线。

安装

让我们安装并导入所需依赖项:

!pip install array_record
!pip install tfds-nightly

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

数据源

数据源基本上是 Python 序列。因此,它们需要实现以下协议:

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, record_key: int) -> Sequence[Any]:
    """Retrieves records for the given record_keys."""

警告:该 API 仍在积极开发中。特别是,__getitem__ 目前在输入中必须支持 intlist[int]。将来,按照标准,它可能仅支持 int

底层文件格式需要支持高效的随机访问。目前,TFDS 依赖于 array_record

array_record 是一种衍生自 Riegeli 的新文件格式,实现了 IO 效率的新前沿。特别是,ArrayRecord 支持按记录索引并行读取、写入和随机访问。ArrayRecord 建立在 Riegeli 之上,并支持相同的压缩算法。

fashion_mnist 是一个常见的计算机视觉数据集。要使用 TFDS 检索基于 ArrayRecord 的数据源,只需使用以下命令:

ds = tfds.data_source('fashion_mnist')

tfds.data_source 是一个方便的包装器。它等同于:

builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()

这将输出一个数据源字典:

{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}

一旦 download_and_prepare 运行并在您生成记录文件后,我们就不再需要 TensorFlow 了。一切都将在 Python/NumPy 中完成!

让我们通过卸载 TensorFlow 并在另一个子进程中重新加载数据源对此进行检查:

pip uninstall -y tensorflow
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
python no_tensorflow.py

在未来的版本中,我们还将使数据集准备不再依赖 TensorFlow。

数据源的长度为:

len(ds['train'])

访问数据集的第一个元素:

%%timeit
ds['train'][0]

…开销与访问任何其他元素一样低。下面是随机访问的定义:

%%timeit
ds['train'][1000]

特征现在使用 NumPy DType(而不是 TensorFlow DType)。您可以使用以下命令检查特征:

features = tfds.builder('fashion_mnist').info.features

您可以在我们的文档中找到有关特征的更多信息。在这里,我们特别可以检索图像的形状和类别数量:

shape = features['image'].shape
num_classes = features['label'].num_classes

在纯 Python 中使用

您可以通过迭代来使用 Python 中的数据源:

for example in ds['train']:
  print(example)
  break

如果您检查元素,还会注意到所有特征都已使用 NumPy 解码。在幕后,我们默认使用 OpenCV,因为它很快。如果您没有安装 OpenCV,我们将默认使用 Pillow 来提供轻量级和快速的图像解码。

{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}

:目前,该功能仅适用于 TensorImageScalar 特征。AudioVideo 特征即将推出。敬请关注!

与 PyTorch 结合使用

PyTorch 使用源/采样器/加载器范式。在 Torch 中,“数据源”称为“数据集”。torch.utils.data 包含构建高效输入流水线所需的所有详细信息。

TFDS 数据源可以像常规的映射样式数据集一样使用。

首先,我们安装并导入Torch:

!pip install torch

from tqdm import tqdm
import torch

我们已经为训练和测试分别定义了数据源(分别是 ds['train']ds['test'])。现在,我们可以定义采样器和加载器:

batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)

使用 PyTorch,我们在第一个样本上进行训练,并评估简单的逻辑回归:

class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')

即将推出:与 JAX 结合使用

我们正在与 Grain 密切合作。Grain 是适用于 Python 的开源、快速和确定性数据加载器。敬请关注!

阅读更多内容

有关详情,请参阅 tfds.data_source API 文档。