Estimator


This document introduces tf.estimator, a high-level TensorFlow API. Estimators encapsulate the following actions:

  • Training
  • Evaluation
  • Prediction
  • Export for serving
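As a rough illustration of the last item, export for serving, the sketch below trains a tiny LinearClassifier on made-up data and writes a SavedModel with export_saved_model. It assumes a TensorFlow 2.x release that still ships tf.estimator; the feature name "x", the toy data, and the step count are hypothetical choices for illustration only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy data: one numeric feature "x" and a binary label.
x = np.random.rand(64).astype(np.float32)
y = (x > 0.5).astype(np.int32)

def input_fn():
    # Yields (feature_dict, label) batches, repeated for training.
    ds = tf.data.Dataset.from_tensor_slices(({"x": x}, y))
    return ds.shuffle(64).batch(16).repeat()

feature_columns = [tf.feature_column.numeric_column("x")]
estimator = tf.estimator.LinearClassifier(
    feature_columns=feature_columns, model_dir=tempfile.mkdtemp())
estimator.train(input_fn=input_fn, steps=10)

# Serving input: parse serialized tf.train.Example protos using a
# parsing spec derived from the same feature columns.
feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
    feature_spec)

# export_saved_model writes a timestamped subdirectory and returns its path as bytes.
export_base = tempfile.mkdtemp()
export_path = estimator.export_saved_model(export_base, serving_input_fn)
print(export_path.decode())
```

This signature expects serialized tf.train.Example protos; tf.estimator.export.build_raw_serving_input_receiver_fn is an alternative when clients send raw tensors.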

You can either use the premade Estimators we provide or write your own custom Estimators. All Estimators, premade or custom, are classes based on the tf.estimator.Estimator class.
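To see this class relationship concretely, the following check confirms that a few premade Estimators are subclasses of tf.estimator.Estimator. It assumes a TensorFlow 2.x release that still ships tf.estimator; the selection of classes is arbitrary.

```python
import tensorflow as tf

# A few premade Estimator classes, chosen arbitrarily for illustration.
premade_classes = [
    tf.estimator.LinearClassifier,
    tf.estimator.DNNClassifier,
    tf.estimator.DNNLinearCombinedClassifier,
]

# Each premade class should derive from the common base class.
is_estimator = {cls.__name__: issubclass(cls, tf.estimator.Estimator)
                for cls in premade_classes}
print(is_estimator)
```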

For a simple example, see the Estimator tutorials. For an overview of the API design, see the white paper.

Advantages

tf.keras.Model 类似,estimator 是模型级别的抽象。tf.estimator 提供了一些目前仍在为 tf.keras 开发中的功能。包括:

  • 基于参数服务器的训练
  • 完整的 TFX 集成

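Estimator capabilities

Estimators provide the following benefits:

  • You can run Estimator-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
  • Estimators provide a safe distributed training loop that controls how and when to:
    • load data
    • handle exceptions
    • create checkpoint files and recover from failures
    • save summaries for TensorBoard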
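The checkpoint-and-recover behavior in the list above can be observed with a small experiment. The sketch below configures checkpointing through tf.estimator.RunConfig, trains for 5 steps, then builds a fresh Estimator on the same model_dir, as a restarted job would, and trains 5 more steps. The second run resumes from the saved checkpoint instead of starting over. It assumes a TensorFlow 2.x release with tf.estimator; the data, the feature name "x", and the intervals are hypothetical.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy data.
x = np.random.rand(64).astype(np.float32)
y = (x > 0.5).astype(np.int32)

def input_fn():
    return tf.data.Dataset.from_tensor_slices(({"x": x}, y)).batch(16).repeat()

model_dir = tempfile.mkdtemp()
config = tf.estimator.RunConfig(
    model_dir=model_dir,
    save_checkpoints_steps=5,  # write a checkpoint every 5 steps
    save_summary_steps=5,      # write TensorBoard summaries every 5 steps
    keep_checkpoint_max=2)     # keep only the 2 newest checkpoints

def make_estimator():
    return tf.estimator.LinearClassifier(
        feature_columns=[tf.feature_column.numeric_column("x")], config=config)

make_estimator().train(input_fn=input_fn, steps=5)

# A new Estimator object on the same model_dir restores the latest
# checkpoint, so global_step continues from 5 rather than 0.
restarted = make_estimator()
restarted.train(input_fn=input_fn, steps=5)

global_step = restarted.get_variable_value("global_step")
latest = tf.train.latest_checkpoint(model_dir)
print(global_step, latest)
```

Note that steps counts additional steps per call; max_steps would instead set an absolute upper bound on global_step.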

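When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different datasets.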

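Premade Estimators

Premade Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions, since Estimators handle all the "plumbing" for you. Furthermore, premade Estimators let you experiment with different model architectures by making only minimal code changes. For example, tf.estimator.DNNClassifier is a premade Estimator class that trains classification models based on dense, feed-forward neural networks.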

Structure of a premade Estimator program

A TensorFlow program relying on a premade Estimator typically consists of the following four steps:

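1. Write one or more dataset importing functions.

For example, you might create one function to import the training set and another function to import the test set. Each dataset importing function must return two objects:

  • a dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
  • a Tensor containing one or more labels

For example, the following code illustrates the basic skeleton for an input function: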

def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label

For more information, see the data guide.
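A concrete input function might look like the sketch below. It builds the feature-dictionary-plus-label structure from hypothetical in-memory NumPy arrays using tf.data. Estimators also accept an input function that returns a tf.data.Dataset yielding (features, label) tuples instead of the two objects directly, which is what this sketch does. The feature names match the feature columns defined in step 2, but the values are made up.

```python
import numpy as np
import tensorflow as tf

# Hypothetical in-memory data keyed by feature name.
RAW_FEATURES = {
    "population": np.array([1200.0, 5300.0, 870.0, 4100.0], dtype=np.float32),
    "crime_rate": np.array([0.02, 0.07, 0.01, 0.05], dtype=np.float32),
    "median_education": np.array([12.0, 14.5, 11.0, 16.0], dtype=np.float32),
}
RAW_LABELS = np.array([0, 1, 0, 1], dtype=np.int32)

def train_input_fn(batch_size=2):
    """Returns a Dataset that yields (feature_dict, label) batches."""
    ds = tf.data.Dataset.from_tensor_slices((RAW_FEATURES, RAW_LABELS))
    return ds.shuffle(len(RAW_LABELS)).batch(batch_size).repeat()

# Inspect one batch eagerly.
features, labels = next(iter(train_input_fn()))
print({name: t.numpy() for name, t in features.items()}, labels.numpy())
```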

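2. Define the feature columns.

Each tf.feature_column identifies a feature name, its type, and any input preprocessing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two feature columns simply identify the feature's name and type. The third feature column also specifies a lambda the program will invoke to scale the raw data: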

import tensorflow as tf

# Define three numeric feature columns.
# (global_education_mean is assumed to have been computed beforehand.)
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column(
  'median_education',
  normalizer_fn=lambda x: x - global_education_mean)

For more information, see the feature columns tutorial.
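To see what normalizer_fn does, you can call it directly on a tensor of raw values. This is only an inspection sketch: global_education_mean is given a hypothetical constant here, whereas in practice you would compute it from your training data.

```python
import tensorflow as tf

# Hypothetical value; normally computed from the training data.
global_education_mean = 13.0

median_education = tf.feature_column.numeric_column(
    "median_education",
    normalizer_fn=lambda x: x - global_education_mean)

# The normalizer is applied to the raw feature values before they reach the model.
raw = tf.constant([12.0, 14.5, 11.0, 16.0])
centered = median_education.normalizer_fn(raw)
print(median_education.key, median_education.shape, centered.numpy())
```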

3. Instantiate the relevant premade Estimator.

For example, here is a sample instantiation of a premade Estimator named LinearClassifier:

# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
  feature_columns=[population, crime_rate, median_education])

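For more information, see the linear classifier tutorial.

4. Call a training, evaluation, or inference method.

For example, all Estimators provide a train method, which trains a model.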

# `my_training_set` is the input function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

You can see an example of this below.
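Putting the four steps together, here is a small end-to-end sketch that trains, evaluates, and predicts with LinearClassifier. It assumes a TensorFlow 2.x release that still includes tf.estimator; the synthetic data, the labelling rule, the helper name my_eval_set, and the step counts are all hypothetical, chosen only so the example runs quickly.

```python
import numpy as np
import tensorflow as tf

# Step 1: hypothetical synthetic data and input functions.
rng = np.random.default_rng(0)
n = 200
data = {
    "population": rng.uniform(0.0, 1.0, n).astype(np.float32),
    "crime_rate": rng.uniform(0.0, 1.0, n).astype(np.float32),
    "median_education": rng.uniform(10.0, 18.0, n).astype(np.float32),
}
labels = (data["crime_rate"] > 0.5).astype(np.int32)  # made-up labelling rule
global_education_mean = float(data["median_education"].mean())

def my_training_set():
    # Shuffled, repeated batches for training.
    ds = tf.data.Dataset.from_tensor_slices((data, labels))
    return ds.shuffle(n).batch(32).repeat()

def my_eval_set():
    # A single ordered pass over the data for evaluation and prediction.
    return tf.data.Dataset.from_tensor_slices((data, labels)).batch(32)

# Step 2: feature columns.
population = tf.feature_column.numeric_column("population")
crime_rate = tf.feature_column.numeric_column("crime_rate")
median_education = tf.feature_column.numeric_column(
    "median_education",
    normalizer_fn=lambda x: x - global_education_mean)

# Step 3: instantiate the premade Estimator.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education])

# Step 4: train, evaluate, and predict.
estimator.train(input_fn=my_training_set, steps=200)
metrics = estimator.evaluate(input_fn=my_eval_set)
# predict() uses only the features from each (features, labels) element.
predictions = list(estimator.predict(input_fn=my_eval_set))
print(metrics["accuracy"], predictions[0]["class_ids"])
```

Evaluating on the training data keeps the sketch short; a real program would use a separate test input function, as step 1 suggests.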

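Benefits of premade Estimators

Premade Estimators encode best practices, providing the following benefits:

  • Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
  • Best practices for event (summary) writing and universally useful summaries.

If you don't use premade Estimators, you must implement these features yourself.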

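Custom Estimators

The heart of every Estimator, whether premade or custom, is its model function, model_fn, which is a method that builds graphs for training, evaluation, and prediction. When you are using a premade Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.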
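As a minimal sketch of what a model function involves, the hypothetical linear_model_fn below fits a single Dense unit to made-up data. It follows the usual model_fn contract: it receives features, labels, and mode, and returns a tf.estimator.EstimatorSpec whose contents depend on whether the Estimator is training, evaluating, or predicting. The Keras layer, the tf.compat.v1 optimizer, the data, and the assumption of a TensorFlow 2.x release with tf.estimator are illustrative choices, not the only way to write a model function.

```python
import numpy as np
import tensorflow as tf

def linear_model_fn(features, labels, mode):
    """Hypothetical model_fn: one Dense unit, i.e. linear regression on "x"."""
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    predictions = model(features["x"])

    # PREDICT: only the predictions are needed; labels is None in this mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"y": predictions})

    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)

    # EVAL: the Estimator reports the loss (plus any eval_metric_ops).
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    # TRAIN: an op that updates the weights and increments global_step.
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# Hypothetical data: y = 2x + 1.
n = 128
x = np.random.rand(n, 1).astype(np.float32)
y = 2.0 * x + 1.0

def train_input():
    ds = tf.data.Dataset.from_tensor_slices(({"x": x}, y))
    return ds.shuffle(n).batch(32).repeat()

def eval_input():
    return tf.data.Dataset.from_tensor_slices(({"x": x}, y)).batch(32)

# Passing the model function to tf.estimator.Estimator makes a custom Estimator.
custom_estimator = tf.estimator.Estimator(model_fn=linear_model_fn)
custom_estimator.train(input_fn=train_input, steps=200)
metrics = custom_estimator.evaluate(input_fn=eval_input)
preds = [p["y"] for p in custom_estimator.predict(input_fn=eval_input)]
print(metrics)
```

Once wrapped in tf.estimator.Estimator, the custom model is trained, evaluated, and queried exactly like a premade one. A fuller model_fn would typically also pass eval_metric_ops (for example, tf.compat.v1.metrics.mean_absolute_error) in EVAL mode.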

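Recommended workflow

  1. Assuming a suitable premade Estimator exists, use it to build your first model and use its results to establish a baseline.
  2. Build and test your overall pipeline, including the integrity and reliability of your data, with this premade Estimator.
  3. If suitable alternative premade Estimators are available, run experiments to determine which premade Estimator produces the best results.
  4. Possibly, further improve your model by building your own custom Estimator.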
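Steps 1 and 3 of this workflow might look like the following sketch, which trains two premade Estimators on the same input function and compares their accuracy. It assumes a TensorFlow 2.x release with tf.estimator. The dataset (points labelled by whether they lie outside the unit circle), hidden_units, and step counts are hypothetical, and for brevity the evaluation reuses the training data; a real comparison should evaluate on held-out data.

```python
import numpy as np
import tensorflow as tf

# Hypothetical non-linear data: label 1 when the point lies outside the unit circle.
rng = np.random.default_rng(42)
n = 400
x1 = rng.normal(size=n).astype(np.float32)
x2 = rng.normal(size=n).astype(np.float32)
y = (x1 ** 2 + x2 ** 2 > 1.0).astype(np.int32)
features = {"x1": x1, "x2": x2}

def train_fn():
    ds = tf.data.Dataset.from_tensor_slices((features, y))
    return ds.shuffle(n).batch(32).repeat()

def eval_fn():
    return tf.data.Dataset.from_tensor_slices((features, y)).batch(32)

# The same feature columns and input functions serve every candidate.
columns = [tf.feature_column.numeric_column(k) for k in ("x1", "x2")]
candidates = {
    "linear": tf.estimator.LinearClassifier(feature_columns=columns),
    "dnn": tf.estimator.DNNClassifier(feature_columns=columns, hidden_units=[16, 16]),
}

results = {}
for name, est in candidates.items():
    est.train(input_fn=train_fn, steps=300)
    results[name] = est.evaluate(input_fn=eval_fn)["accuracy"]
print(results)
```

Because the input pipeline is separate from the model, swapping the Estimator is the only change needed between experiments.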
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

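Create an Estimator from a Keras model

You can convert existing Keras models to Estimators with tf.keras.estimator.model_to_estimator. This enables your Keras model to access Estimator's strengths, such as distributed training.

Instantiate a Keras MobileNet V2 model and compile it with the optimizer, loss, and metrics to train with: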

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step

Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmprx_rffua
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:450: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmprx_rffua', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

You can treat the derived Estimator just as you would any other Estimator.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data
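To check what preprocess produces without downloading cats_vs_dogs, you can feed it a synthetic image. The sketch below repeats the function from above and applies it to a hypothetical 300x400 half-black, half-white image. The output is resized to 160x160 and its pixel values are rescaled from [0, 255] to [-1, 1], the same range that tf.keras.applications.mobilenet_v2.preprocess_input produces.

```python
import numpy as np
import tensorflow as tf

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1  # map [0, 255] to [-1, 1]
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

# Hypothetical fake image: left half black (0), right half white (255).
fake = np.zeros((300, 400, 3), dtype=np.uint8)
fake[:, 200:, :] = 255

image, label = preprocess(tf.constant(fake), tf.constant(1))
print(image.shape, float(tf.reduce_min(image)), float(tf.reduce_max(image)))
```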

To train, call the Estimator's train function:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmprx_rffua/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmprx_rffua/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmprx_rffua/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.data-00001-of-00002
INFO:tensorflow:0
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.index
INFO:tensorflow:0
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.meta
INFO:tensorflow:1100
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.data-00000-of-00002
INFO:tensorflow:10100
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70208776, step = 0
INFO:tensorflow:global_step/sec: 48.5969
INFO:tensorflow:loss = 0.65885234, step = 100 (2.059 sec)
INFO:tensorflow:global_step/sec: 55.4617
INFO:tensorflow:loss = 0.7369333, step = 200 (1.803 sec)
Corrupt JPEG data: 99 extraneous bytes before marker 0xd9
Warning: unknown JFIF revision number 0.00
Corrupt JPEG data: 396 extraneous bytes before marker 0xd9
INFO:tensorflow:global_step/sec: 54.9075
Corrupt JPEG data: 162 extraneous bytes before marker 0xd9
INFO:tensorflow:loss = 0.6614636, step = 300 (1.821 sec)
Corrupt JPEG data: 252 extraneous bytes before marker 0xd9
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1403 extraneous bytes before marker 0xd9
INFO:tensorflow:global_step/sec: 54.9522
INFO:tensorflow:loss = 0.5951805, step = 400 (1.820 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmprx_rffua/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.data-00000-of-00002
INFO:tensorflow:9000
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.index
INFO:tensorflow:9000
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.meta
INFO:tensorflow:10100
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.data-00001-of-00002
INFO:tensorflow:10100
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.61634666.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7ff5b00aac70>

Similarly, to evaluate, call the Estimator's evaluate function:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2045: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T18:33:47
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmprx_rffua/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.40658s
INFO:tensorflow:Finished evaluation at 2022-06-03-18:33:49
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.596875, global_step = 500, loss = 0.64985955
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmprx_rffua/model.ckpt-500
{'accuracy': 0.596875, 'loss': 0.64985955, 'global_step': 500}

For more details, see the tf.keras.estimator.model_to_estimator documentation.
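For a quicker experiment than MobileNetV2 on cats_vs_dogs, the same conversion can be tried on a tiny model. The sketch below assumes a TensorFlow 2.x release whose tf.keras still provides tf.keras.estimator.model_to_estimator; it converts a hypothetical two-layer Keras model and runs train, evaluate, and predict on random data. Because the model has a single input, the input function here yields plain tensors rather than a feature dictionary.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical tiny Keras model standing in for MobileNetV2.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
keras_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"])

est = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model, model_dir=tempfile.mkdtemp())

# Hypothetical random data with a made-up labelling rule.
x = np.random.rand(64, 4).astype(np.float32)
y = (x.sum(axis=1, keepdims=True) > 2.0).astype(np.float32)

def input_fn():
    # One pass over (features, label) batches.
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

est.train(input_fn=lambda: input_fn().repeat(), steps=20)
metrics = est.evaluate(input_fn=input_fn)
# predict() uses only the first element (the features) of each batch.
predictions = list(est.predict(input_fn=input_fn))
print(metrics)
```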