
Premade Estimators


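This tutorial shows you how to solve the Iris classification problem in TensorFlow using Estimators. An Estimator is TensorFlow's high-level representation of a complete model, and it has been designed for easy scaling and asynchronous training. For more details see Estimators.

Note that in TensorFlow 2.0, the Keras API can accomplish many of these same tasks, and is believed to be an easier API to learn. If you are just starting out, we recommend you start with Keras. For more information about the available high-level APIs in TensorFlow 2.0, see Standardizing on Keras.

First things first

In order to get started, you will first import TensorFlow and a number of libraries you will need.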

from __future__ import absolute_import, division, print_function, unicode_literals


import tensorflow as tf

import numpy as np  # used by the hand-written input function below
import pandas as pd

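The dataset

The sample program in this document builds and tests a model that classifies Iris flowers into three different species based on the size of their sepals and petals.

You will train a model using the Iris dataset. The Iris dataset contains four features and one label. The four features identify the following botanical characteristics of individual Iris flowers:

  • sepal length
  • sepal width
  • petal length
  • petal width

Based on this information, you can define a few helpful constants for parsing the data: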

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
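The integer labels in the CSV files index into SPECIES, which is how the prediction loop at the end of this tutorial turns a class id back into a name. A minimal sketch of that mapping, assuming the encoding 0 = Setosa, 1 = Versicolor, 2 = Virginica implied by the order of SPECIES; FEATURE_NAMES, species_name and label_id are hypothetical helpers, not part of the tutorial:

```python
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

# Hypothetical helper list: every CSV column except the label.
FEATURE_NAMES = [name for name in CSV_COLUMN_NAMES if name != 'Species']

def species_name(class_id):
    """Map an integer class id (0, 1 or 2) to its species name."""
    return SPECIES[class_id]

def label_id(species):
    """Map a species name back to its integer class id."""
    return SPECIES.index(species)

print(FEATURE_NAMES)
print(species_name(2), label_id('Versicolor'))
```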

Next, download and parse the Iris dataset using Keras and Pandas. Note that you keep distinct datasets for training and testing.

train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)

Inspecting the data, you can see that there are four float feature columns and one integer label column.

train.head()

For each of the datasets, split out the labels, which the model will be trained to predict.

train_y = train.pop('Species')
test_y = test.pop('Species')

# The label column has now been removed from the features.
train.head()

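Overview of programming with Estimators

Now that you have the data set up, you can define a model using a TensorFlow Estimator. An Estimator is any class derived from tf.estimator.Estimator. TensorFlow provides a collection of premade Estimators in tf.estimator (for example, LinearRegressor) that implement common ML algorithms. Beyond those, you may write your own custom Estimators. We recommend using premade Estimators when just getting started.

To write a TensorFlow program based on premade Estimators, you must perform the following tasks:

  • Create one or more input functions.
  • Define the model's feature columns.
  • Instantiate an Estimator, specifying the feature columns and various hyperparameters.
  • Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data.

Let's see how those tasks are implemented for Iris classification.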

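Create input functions

You must create input functions to supply data for training, evaluating, and prediction.

An input function is a function that returns a tf.data.Dataset object which outputs the following two-element tuple:

  • features: a Python dictionary in which:
    • Each key is the name of a feature.
    • Each value is an array containing all of that feature's values.
  • label: an array containing the values of the label for every example.

Just to demonstrate the format of the input function, here's a simple implementation (it returns the tuple directly rather than wrapping it in a Dataset):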

def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth':  np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth':  np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels
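Conceptually, tf.data.Dataset.from_tensor_slices (used in the input function further down) turns this column-oriented (features, labels) pair into one (features, label) record per example. The pure-Python sketch below only imitates that slicing with plain lists so the shape of the data is easy to see; slice_examples is a hypothetical helper and this is not how tf.data is implemented:

```python
def slice_examples(features, labels):
    """Yield one ({feature_name: value}, label) pair per example.

    Imitates the per-example slicing that from_tensor_slices performs
    on a dict of equal-length columns plus a label column.
    """
    lengths = {len(column) for column in features.values()}
    if lengths != {len(labels)}:
        raise ValueError('every feature column must have one value per label')
    for i, label in enumerate(labels):
        yield {name: column[i] for name, column in features.items()}, label

features = {'SepalLength': [6.4, 5.0],
            'SepalWidth':  [2.8, 2.3],
            'PetalLength': [5.6, 3.3],
            'PetalWidth':  [2.2, 1.0]}
labels = [2, 1]

for example, label in slice_examples(features, labels):
    print(example, label)
```

The length check mirrors the requirement that every feature array and the label array describe the same number of examples.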

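Your input function may generate the features dictionary and label list any way you like. However, we recommend using TensorFlow's Dataset API, which can parse all sorts of data.

The Dataset API can handle a lot of common cases for you. For example, using the Dataset API, you can easily read in records from a large collection of files in parallel and join them into a single stream.

To keep things simple in this example you are going to load the data with pandas, and build an input pipeline from this in-memory data.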

def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)
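To make the shuffle(1000).repeat() and batch(batch_size) calls above more concrete, here is a pure-Python sketch of the same order of operations. It assumes the documented tf.data behaviour: shuffle keeps a buffer of buffer_size elements and emits a randomly chosen one while refilling from the input, and batch groups consecutive elements and by default keeps a smaller final batch. The generator names are hypothetical and the integers stand in for (features, label) records:

```python
import itertools
import random

def shuffle(elements, buffer_size, rng):
    """Buffered shuffle: hold up to buffer_size elements, emit a random one."""
    buffer = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == buffer_size:
            # Swap a randomly chosen element to the end and emit it.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer

def repeat(make_epoch):
    """Repeat forever, starting a fresh (re-shuffled) pass each epoch."""
    while True:
        yield from make_epoch()

def batch(elements, batch_size):
    """Group consecutive elements; the last batch may be smaller."""
    iterator = iter(elements)
    while True:
        chunk = list(itertools.islice(iterator, batch_size))
        if not chunk:
            return
        yield chunk

rng = random.Random(0)
examples = list(range(10))  # stand-ins for 10 (features, label) records

# training=True: shuffle, then repeat, then batch -> an endless stream of full batches.
train_stream = batch(repeat(lambda: shuffle(examples, 4, rng)), 4)
first_three = list(itertools.islice(train_stream, 3))

# training=False: a single ordered pass; the final batch keeps the remainder.
eval_batches = list(batch(examples, 4))
print(first_three)
print(eval_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Because shuffle comes before repeat, each epoch is a complete permutation of the data before the next one starts, while batch, applied last, may straddle an epoch boundary.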

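Define the feature columns

A feature column is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. The tf.feature_column module provides many options for representing data to the model.

For Iris, the 4 raw features are numeric values, so we'll build a list of feature columns to tell the Estimator model to represent each of the four features as 32-bit floating-point values. Therefore, the code to create the feature columns is:

# Feature columns describe how to use the input.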
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
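The numeric columns above represent each feature as a 32-bit float, while pandas parses the CSV values as 64-bit floats; this is what the "casting an input tensor from dtype float64 to the layer's dtype of float32" warnings in the training log refer to. The stdlib sketch below rounds a value through IEEE 754 single precision with struct to show how small that loss of precision is. It only illustrates the precision difference and does not call TensorFlow; to_float32 is a hypothetical helper:

```python
import struct

def to_float32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float and back."""
    return struct.unpack('f', struct.pack('f', x))[0]

# A few raw feature values from the Iris data.
for value in [6.4, 2.8, 5.6, 2.2]:
    as_f32 = to_float32(value)
    print(value, '->', repr(as_f32), 'error:', abs(as_f32 - value))
```

For measurements recorded to one decimal place, an error on the order of 1e-7 is far below the resolution of the data.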

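Feature columns can be far more sophisticated than those shown here. You can read more about feature columns in the guide.

Now that you have the description of how you want the model to represent the raw features, you can build the Estimator.

Instantiate an Estimator

The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several premade classifier Estimators, including:

  • tf.estimator.DNNClassifier for deep models that perform multi-class classification.
  • tf.estimator.DNNLinearCombinedClassifier for wide & deep models.
  • tf.estimator.LinearClassifier for classifiers based on linear models.

For the Iris problem, tf.estimator.DNNClassifier seems like the best choice. Here's how you instantiate this Estimator: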

# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4pxho8i6
INFO:tensorflow:Using config: {'_eval_distribute': None, '_task_id': 0, '_num_ps_replicas': 0, '_evaluation_master': '', '_model_dir': '/tmp/tmp4pxho8i6', '_experimental_max_worker_delay_secs': None, '_num_worker_replicas': 1, '_train_distribute': None, '_save_checkpoints_secs': 600, '_save_summary_steps': 100, '_experimental_distribute': None, '_device_fn': None, '_tf_random_seed': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_protocol': None, '_keep_checkpoint_every_n_hours': 10000, '_is_chief': True, '_log_step_count_steps': 100, '_master': '', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f7430d61f28>, '_service': None, '_save_checkpoints_steps': None, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0}
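To get a rough sense of the model's size: a fully connected network with the hidden_units=[30, 10] and n_classes=3 configured above, fed 4 numeric features, has one weight matrix and one bias vector per layer. The sketch below counts those trainable parameters under that assumption (dense layers with biases and a 3-unit logits layer). It does not inspect the actual Estimator, and dense_param_count is a hypothetical helper:

```python
def dense_param_count(n_inputs, hidden_units, n_classes):
    """Count weights + biases in a fully connected net ending in a logits layer."""
    sizes = [n_inputs] + list(hidden_units) + [n_classes]
    total = 0
    for fan_in, fan_out in zip(sizes, sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total

# 4 features -> 30 -> 10 -> 3 logits: (4*30 + 30) + (30*10 + 10) + (10*3 + 3)
print(dense_param_count(4, [30, 10], 3))  # 493
```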

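Train, evaluate, and predict

Now that you have an Estimator object, you can call methods to do the following:

  • Train the model.
  • Evaluate the trained model.
  • Use the trained model to make predictions.

Train the model

Train the model by calling the Estimator's train method as follows: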

# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/head/base_head.py:550: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/keras/optimizer_v2/adagrad.py:108: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/model_fn.py:337: scalar (from tensorflow.python.framework.tensor_shape) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.TensorShape([]).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/array_ops.py:1486: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4pxho8i6/model.ckpt.
INFO:tensorflow:loss = 0.9816213, step = 0
INFO:tensorflow:global_step/sec: 255.604
INFO:tensorflow:loss = 0.8584726, step = 100 (0.393 sec)
INFO:tensorflow:global_step/sec: 324.335
INFO:tensorflow:loss = 0.79295284, step = 200 (0.308 sec)
INFO:tensorflow:global_step/sec: 322.827
INFO:tensorflow:loss = 0.73671246, step = 300 (0.310 sec)
INFO:tensorflow:global_step/sec: 325.968
INFO:tensorflow:loss = 0.7055951, step = 400 (0.307 sec)
INFO:tensorflow:global_step/sec: 320.857
INFO:tensorflow:loss = 0.67782617, step = 500 (0.312 sec)
INFO:tensorflow:global_step/sec: 323.031
INFO:tensorflow:loss = 0.656973, step = 600 (0.309 sec)
INFO:tensorflow:global_step/sec: 320.544
INFO:tensorflow:loss = 0.6344602, step = 700 (0.312 sec)
INFO:tensorflow:global_step/sec: 314.102
INFO:tensorflow:loss = 0.62019354, step = 800 (0.318 sec)
INFO:tensorflow:global_step/sec: 318.913
INFO:tensorflow:loss = 0.60511863, step = 900 (0.313 sec)
INFO:tensorflow:global_step/sec: 322.083
INFO:tensorflow:loss = 0.58901495, step = 1000 (0.310 sec)
INFO:tensorflow:global_step/sec: 348.311
INFO:tensorflow:loss = 0.5784603, step = 1100 (0.287 sec)
INFO:tensorflow:global_step/sec: 342.855
INFO:tensorflow:loss = 0.5622357, step = 1200 (0.292 sec)
INFO:tensorflow:global_step/sec: 343.381
INFO:tensorflow:loss = 0.55213696, step = 1300 (0.291 sec)
INFO:tensorflow:global_step/sec: 265.745
INFO:tensorflow:loss = 0.5365027, step = 1400 (0.378 sec)
INFO:tensorflow:global_step/sec: 302.678
INFO:tensorflow:loss = 0.5304105, step = 1500 (0.329 sec)
INFO:tensorflow:global_step/sec: 345.949
INFO:tensorflow:loss = 0.5283208, step = 1600 (0.289 sec)
INFO:tensorflow:global_step/sec: 350.7
INFO:tensorflow:loss = 0.502436, step = 1700 (0.285 sec)
INFO:tensorflow:global_step/sec: 340.195
INFO:tensorflow:loss = 0.507177, step = 1800 (0.294 sec)
INFO:tensorflow:global_step/sec: 345.932
INFO:tensorflow:loss = 0.4918634, step = 1900 (0.289 sec)
INFO:tensorflow:global_step/sec: 351.064
INFO:tensorflow:loss = 0.4812191, step = 2000 (0.285 sec)
INFO:tensorflow:global_step/sec: 352.017
INFO:tensorflow:loss = 0.4819128, step = 2100 (0.284 sec)
INFO:tensorflow:global_step/sec: 340.139
INFO:tensorflow:loss = 0.46524256, step = 2200 (0.294 sec)
INFO:tensorflow:global_step/sec: 350.892
INFO:tensorflow:loss = 0.47413245, step = 2300 (0.285 sec)
INFO:tensorflow:global_step/sec: 349.785
INFO:tensorflow:loss = 0.45543683, step = 2400 (0.286 sec)
INFO:tensorflow:global_step/sec: 361.24
INFO:tensorflow:loss = 0.44656873, step = 2500 (0.277 sec)
INFO:tensorflow:global_step/sec: 363.05
INFO:tensorflow:loss = 0.4391844, step = 2600 (0.275 sec)
INFO:tensorflow:global_step/sec: 357.382
INFO:tensorflow:loss = 0.43884304, step = 2700 (0.280 sec)
INFO:tensorflow:global_step/sec: 358.162
INFO:tensorflow:loss = 0.42847294, step = 2800 (0.279 sec)
INFO:tensorflow:global_step/sec: 358.435
INFO:tensorflow:loss = 0.4295861, step = 2900 (0.279 sec)
INFO:tensorflow:global_step/sec: 349.005
INFO:tensorflow:loss = 0.4188161, step = 3000 (0.286 sec)
INFO:tensorflow:global_step/sec: 344.775
INFO:tensorflow:loss = 0.40605238, step = 3100 (0.290 sec)
INFO:tensorflow:global_step/sec: 347.83
INFO:tensorflow:loss = 0.41728777, step = 3200 (0.287 sec)
INFO:tensorflow:global_step/sec: 349.773
INFO:tensorflow:loss = 0.41107363, step = 3300 (0.286 sec)
INFO:tensorflow:global_step/sec: 350.759
INFO:tensorflow:loss = 0.3960948, step = 3400 (0.285 sec)
INFO:tensorflow:global_step/sec: 354.544
INFO:tensorflow:loss = 0.40280253, step = 3500 (0.282 sec)
INFO:tensorflow:global_step/sec: 348.359
INFO:tensorflow:loss = 0.4021863, step = 3600 (0.287 sec)
INFO:tensorflow:global_step/sec: 348.412
INFO:tensorflow:loss = 0.38774556, step = 3700 (0.287 sec)
INFO:tensorflow:global_step/sec: 351.681
INFO:tensorflow:loss = 0.3878502, step = 3800 (0.284 sec)
INFO:tensorflow:global_step/sec: 350.65
INFO:tensorflow:loss = 0.37870076, step = 3900 (0.285 sec)
INFO:tensorflow:global_step/sec: 352.788
INFO:tensorflow:loss = 0.37454197, step = 4000 (0.284 sec)
INFO:tensorflow:global_step/sec: 357.071
INFO:tensorflow:loss = 0.36461756, step = 4100 (0.280 sec)
INFO:tensorflow:global_step/sec: 345.057
INFO:tensorflow:loss = 0.37231374, step = 4200 (0.290 sec)
INFO:tensorflow:global_step/sec: 350.562
INFO:tensorflow:loss = 0.36564863, step = 4300 (0.285 sec)
INFO:tensorflow:global_step/sec: 352.628
INFO:tensorflow:loss = 0.360517, step = 4400 (0.284 sec)
INFO:tensorflow:global_step/sec: 354.354
INFO:tensorflow:loss = 0.35699844, step = 4500 (0.282 sec)
INFO:tensorflow:global_step/sec: 351.717
INFO:tensorflow:loss = 0.36425847, step = 4600 (0.284 sec)
INFO:tensorflow:global_step/sec: 350.83
INFO:tensorflow:loss = 0.35635468, step = 4700 (0.285 sec)
INFO:tensorflow:global_step/sec: 347.447
INFO:tensorflow:loss = 0.3530446, step = 4800 (0.288 sec)
INFO:tensorflow:global_step/sec: 359.806
INFO:tensorflow:loss = 0.33091962, step = 4900 (0.278 sec)
INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmp4pxho8i6/model.ckpt.
INFO:tensorflow:Loss for final step: 0.3456031.

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f7430d61b70>

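Note that you wrap up your input_fn call in a lambda to capture the arguments while providing an input function that takes no arguments, as expected by the Estimator. The steps argument tells the method to stop training after a number of training steps.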
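The plain-Python sketch below shows the closure pattern with a hypothetical stand-in for Estimator.train (fake_train) that calls its input function without arguments, and shows functools.partial as an equivalent way to build that zero-argument callable. It also works out how much data 5000 steps represent, assuming the 120-example iris_training.csv and the default batch_size=256 of input_fn; the feature rows here are placeholders:

```python
import functools

def input_fn(features, labels, training=True, batch_size=256):
    """Stand-in for the tutorial's input_fn: just reports what it was given."""
    return {'n_examples': len(labels), 'training': training, 'batch_size': batch_size}

def fake_train(input_fn, steps):
    """Hypothetical stand-in for Estimator.train: calls input_fn with no args."""
    config = input_fn()
    return config, steps

# Placeholder data with the size of the Iris training set.
train_x, train_y = [[5.1, 3.3, 1.7, 0.5]] * 120, [0] * 120

# Both callables take no arguments but remember train_x and train_y.
via_lambda = fake_train(lambda: input_fn(train_x, train_y, training=True), steps=5000)
via_partial = fake_train(functools.partial(input_fn, train_x, train_y, training=True), steps=5000)

# The training dataset repeats, so every step sees a full batch of 256 rows.
examples_seen = 5000 * 256
epochs = examples_seen / 120
print(via_lambda == via_partial, examples_seen, round(epochs, 1))
```

In other words, 5000 steps replay the small training set roughly ten thousand times, which is only possible because the training pipeline calls repeat().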

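Evaluate the trained model

Now that the model has been trained, you can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data: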

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-09-28T05:05:48Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp4pxho8i6/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2019-09-28-05:05:48
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.93333334, average_loss = 0.39961192, global_step = 5000, loss = 0.39961192
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmp4pxho8i6/model.ckpt-5000

Test set accuracy: 0.933

Unlike the call to the train method, you did not pass the steps argument to evaluate. The input_fn for evaluation only yields a single epoch of data.
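Because input_fn(test, test_y, training=False) skips repeat(), evaluation stops once the single pass over the test set is exhausted. Assuming the 30-example iris_test.csv and the default batch_size=256, that pass is one partial batch, as this small calculation shows (batches_per_epoch is a hypothetical helper):

```python
import math

def batches_per_epoch(n_examples, batch_size):
    """Number of batches in one pass when the last partial batch is kept."""
    return math.ceil(n_examples / batch_size)

print(batches_per_epoch(30, 256))   # 1: the whole test set fits in one batch
print(batches_per_epoch(120, 256))  # 1: so would one epoch of the training set
print(batches_per_epoch(30, 8))     # 4: batches of 8, 8, 8 and 6
```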

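The eval_result dictionary also contains the average_loss (mean loss per sample), the loss (mean loss per mini-batch) and the value of the Estimator's global_step (the number of training iterations it underwent).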
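The difference between average_loss and loss only shows up when evaluation spans several batches of unequal size. The sketch below assumes, as described above, that average_loss is the mean over all examples and loss is the mean of the per-batch means, and it uses made-up per-example losses. With a single batch, as with the 30-example test set and batch_size=256, the two coincide, which is consistent with the identical average_loss and loss values (0.39961192) in the evaluation log. eval_losses is a hypothetical helper:

```python
def eval_losses(batches):
    """Return (average_loss, loss) for lists of per-example losses per batch.

    average_loss: mean over every example.
    loss:         mean of each batch's mean loss.
    """
    all_losses = [x for batch in batches for x in batch]
    average_loss = sum(all_losses) / len(all_losses)
    loss = sum(sum(b) / len(b) for b in batches) / len(batches)
    return average_loss, loss

one_batch = [[0.25, 0.5, 0.75]]
two_batches = [[0.25, 0.5, 0.75, 1.0], [2.0]]

print(eval_losses(one_batch))    # (0.5, 0.5): one batch, identical values
print(eval_losses(two_batches))  # (0.9, 1.3125): the lone example in batch 2 weighs more in loss
```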

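Making predictions (inferring) from the trained model

You now have a trained model that produces good evaluation results. You can now use the trained model to predict the species of an Iris flower based on some unlabeled measurements. As with training and evaluation, you make predictions using a single function call:

# Generate predictions from the model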
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

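def predict_input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels. A distinct name keeps
    # the training/evaluation input_fn defined earlier from being overwritten.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: predict_input_fn(predict_x))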

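The predict method returns a Python iterable, yielding a dictionary of prediction results for each example. The following code prints a few predictions and their probabilities: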

for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp4pxho8i6/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "Setosa" (85.4%), expected "Setosa"
Prediction is "Versicolor" (55.6%), expected "Versicolor"
Prediction is "Virginica" (65.6%), expected "Virginica"
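Each dictionary yielded by predict carries, among other keys, the 'class_ids' and 'probabilities' entries that the loop above reads. For a multi-class classifier like this one, the probabilities are the softmax of the model's logits and the class id is the index of the largest probability. The sketch below reproduces that post-processing in plain Python; the logits are made-up numbers, not outputs of the trained model, and make_prediction is a hypothetical helper:

```python
import math

SPECIES = ['Setosa', 'Versicolor', 'Virginica']

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def make_prediction(logits):
    """Build a dict shaped like the fields the tutorial's loop reads."""
    probabilities = softmax(logits)
    class_id = max(range(len(probabilities)), key=probabilities.__getitem__)
    return {'logits': logits, 'probabilities': probabilities, 'class_ids': [class_id]}

pred_dict = make_prediction([2.0, 0.5, -1.0])  # made-up logits
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%)'.format(SPECIES[class_id], 100 * probability))
```

A percentage such as the 55.6% reported for the Versicolor example above simply means the softmax spread more of its mass over the other two classes.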