数据集快速入门

tf.data 模块包含一系列类,可让您轻松地加载数据、操作数据并通过管道将数据传送到模型中。本文档通过两个简单的示例来介绍该 API:

  • 从 Numpy 数组中读取内存中的数据。
  • 从 csv 文件中读取行。

基本输入

要开始使用 tf.data,最简单的方法是从数组中提取切片。

预创建的 Estimator 一章介绍了 iris_data.py 中的以下 train_input_fn,它可以通过管道将数据传输到 Estimator 中:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Build the Iterator, and return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

我们来详细了解一下。

参数

此函数需要三个参数。要求所赋值为“数组”的参数能够接受可通过 numpy.array 转换成数组的几乎任何值。其中存在一个例外,即对 Datasets 有特殊意义的 tuple

  • features:包含原始输入特征的 {'feature_name':array} 字典(或 DataFrame)。
  • labels:包含每个样本的标签的数组。
  • batch_size:表示所需批次大小的整数。

premade_estimator.py 中,我们使用 iris_data.load_data() 函数检索了 Iris 数据。您可以运行该函数并解压结果,如下所示:

import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train

然后,我们使用类似以下内容的行将此数据传递给了输入函数:

batch_size=100
iris_data.train_input_fn(features, labels, batch_size)

下面我们详细介绍一下 train_input_fn()

切片

在最简单的情况下,tf.data.Dataset.from_tensor_slices 函数接受一个数组并返回表示该数组切片的 tf.data.Dataset。例如,一个包含 mnist 训练数据的数组的形状为 (60000, 28, 28)。将该数组传递给 from_tensor_slices,会返回一个包含 60000 个切片的 Dataset 对象,其中每个切片都是一个 28x28 的图像。

返回此 Dataset 的代码如下所示:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

这段代码将输出以下行,显示数据集中条目的形状类型。请注意,数据集不知道自己包含多少条目。

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

上面的数据集表示了一组简单的数组,但实际的数据集要比这复杂得多。数据集以透明方式处理字典或元组的任何嵌套组合。例如,确保 features 是标准字典,然后您就可以将数组字典转换为字典 Dataset,如下所示:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

在这里,我们可以看到,如果 Dataset 包含结构化元素,则 Datasetshapestypes 将采用同一结构。此数据集包含所有类型为 tf.float64标量的字典。

train_input_fn 的第一行使用相同的功能,但添加了另一层结构。它会创建一个包含 (features, labels) 对的数据集。

以下代码显示标签是类型为 int64 的标量:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操作

目前,Dataset 会按固定顺序迭代数据一次,并且一次仅生成一个元素。它需要进一步处理才可用于训练。幸运的是,tf.data.Dataset 类提供了方法来更好地准备用于训练的数据。输入函数的下一行就利用了其中的几种方法:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle 方法使用一个固定大小的缓冲区,在条目经过时执行随机化处理。将 buffer_size 设置为大于 Dataset 中样本数的值,可确保数据完全被随机化处理。Iris 数据集仅包含 150 个样本。

repeat 方法会在结束时重启 Dataset。要限制周期数量,请设置 count 参数。

batch 方法会收集大量样本并将它们堆叠起来以创建批次。这为批次的形状增加了一个维度。新的维度将添加为第一个维度。以下代码对之前的 MNIST Dataset 使用 batch 方法。这样会产生一个包含表示 (28,28) 图像堆叠的三维数组的 Dataset

print(mnist_ds.batch(100))
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

请注意,该数据集的批次大小是未知的,因为最后一个批次具有的元素数量会减少。

train_input_fn 中,经过批处理之后,Dataset 包含元素的一维向量,其中每个标量之前如下所示:

print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回

每个 Estimator 的 trainevaluatepredict 方法都需要输入函数返回包含 TensorFlow 张量(features, label) 对。train_input_fn 使用以下行将数据集转换为所需格式:

# Build the Iterator, and return the read end of the pipeline.
features_result, labels_result = dataset.make_one_shot_iterator().get_next()

结果会生成与 Dataset 中条目的布局相匹配的 TensorFlow 张量结构。要简要了解这些对象及其使用方式,请参阅简介

print((features_result, labels_result))
({
    'SepalLength': <tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=float64>,
    'PetalWidth': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=float64>,
    'PetalLength': <tf.Tensor 'IteratorGetNext:0' shape=(?,) dtype=float64>,
    'SepalWidth': <tf.Tensor 'IteratorGetNext:3' shape=(?,) dtype=float64>},
Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64))

读取 CSV 文件

Dataset 类最常见的实际用例是流式传输磁盘上文件中的数据。tf.data 模块包含各种文件读取器。我们来看看如何使用 Dataset 解析 csv 文件中的 Iris 数据集。

iris_data.maybe_download 函数的以下调用会根据需要下载数据,并返回所生成文件的路径名:

import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn 函数包含使用 Dataset 解析 csv 文件的备用实现。

我们来了解一下如何构建从本地文件读取数据且兼容 Estimator 的输入函数。

构建 Dataset

我们先构建一个 TextLineDataset 对象来实现一次读取文件中的一行数据。然后,我们调用 skip 方法来跳过文件的第一行,此行包含标题,而非样本:

ds = tf.data.TextLineDataset(train_path).skip(1)

构建 csv 行解析器

最终,我们需要解析数据集中的每一行,以生成必要的 (features, label) 对。

我们先构建一个函数来解析单行。

以下 iris_data.parse_line 函数使用 tf.decode_csv 函数和一些简单的 Python 代码完成此任务:

为了生成必要的 (features, label) 对,我们必须解析数据集中的每一行。以下 _parse_line 函数会调用 tf.decode_csv,以将单行解析为特征和标签两个部分。由于 Estimator 需要将特征表示为字典,因此我们依靠 Python 的内置 dictzip 函数来构建此字典。特征名称是该字典的键。然后,我们调用字典的 pop 方法以从特征字典中移除标签字段:

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS,fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label

解析行

数据集提供很多用于在通过管道将数据传送到模型的过程中处理数据的方法。最常用的方法是 map,它会对 Dataset 的每个元素应用转换。

map 方法会接受 map_func 参数,此参数描述了应该如何转换 Dataset 中的每个条目。

map 方法运用“map_func”来转换 Dataset 中的每个条目。

因此,为了在从 csv 文件中流式传出行时对行进行解析,我们将 _parse_line 函数传递给 map 方法:

ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

现在,数据集包含 (features, label) 对,而不是简单的标量字符串。

iris_data.csv_input_fn 函数的剩余部分与 iris_data.train_input_fn 函数完全相同,后者在基本输入部分中进行了介绍。

试试看

此函数可用于替换 iris_data.train_input_fn。可使用此函数馈送 Estimator,如下所示:

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
                                    n_classes=3)
# Train the estimator
batch_size = 100
est.train(
    steps=1000,
    input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))

Estimator 要求 input_fn 不接受任何参数。为了不受此限制约束,我们使用 lambda 来获取参数并提供所需的接口。

总结

tf.data 模块提供一系列类和函数,可用于轻松从各种来源读取数据。此外,tf.data 还提供简单而又强大的方法,用于应用各种标准和自定义转换。

现在,您已经基本了解了如何高效地将数据加载到 Estimator 中。接下来,请查看下列文档: