This document introduces tf.estimator, a high-level TensorFlow API. Estimators encapsulate the following actions:
- Training
- Evaluation
- Prediction
- Export for serving
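You can either use the pre-made Estimators we provide or write your own custom Estimators. All Estimators, whether pre-made or custom, are classes based on the tf.estimator.Estimator class.
For a quick example, try the Estimator tutorials. For an overview of the API design, see the white paper.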
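Advantages
Like a tf.keras.Model, an estimator is a model-level abstraction. tf.estimator provides some capabilities that are still under development for tf.keras. These are:
- Parameter server based training
- Full TFX integration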
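Estimator capabilities
Estimators provide the following benefits:
- You can run Estimator-based models on a local host or in a distributed multi-server environment without changing your model. You can also run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
- Estimators provide a safe distributed training loop that controls how and when to:
  - Load data
  - Handle exceptions
  - Create checkpoint files and recover from failures
  - Save summaries for TensorBoard
When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different datasets.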
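The separation of input pipeline and model can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not Estimator code: `make_input_fn`, `mean_squared_error`, and both toy datasets are hypothetical names introduced for illustration. It shows one unchanged "model" being scored against two datasets simply by swapping the input function.

```python
from typing import Callable, Dict, List, Tuple

Features = Dict[str, List[float]]
InputFn = Callable[[], Tuple[Features, List[float]]]


def make_input_fn(rows: List[Tuple[float, float]]) -> InputFn:
    """Build an input function over (x, y) pairs (hypothetical helper)."""
    def input_fn() -> Tuple[Features, List[float]]:
        features = {"x": [x for x, _ in rows]}
        labels = [y for _, y in rows]
        return features, labels
    return input_fn


def mean_squared_error(input_fn: InputFn, weight: float) -> float:
    """Score the toy model y = weight * x on whatever data input_fn supplies."""
    features, labels = input_fn()
    errors = [(weight * x - y) ** 2 for x, y in zip(features["x"], labels)]
    return sum(errors) / len(errors)


# Two hypothetical datasets; the model code above does not change between them.
dataset_a = make_input_fn([(1.0, 2.0), (2.0, 4.0)])
dataset_b = make_input_fn([(1.0, 3.0), (2.0, 5.0)])

mse_a = mean_squared_error(dataset_a, weight=2.0)
mse_b = mean_squared_error(dataset_b, weight=2.0)
print(mse_a, mse_b)
```

Because the model only ever calls `input_fn()`, trying a new dataset means writing a new input function and nothing else.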
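Pre-made Estimators
Pre-made Estimators let you work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions, since Estimators handle all the "plumbing" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes. tf.estimator.DNNClassifier, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.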
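Structure of a pre-made Estimator program
A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:
1. Write one or more dataset importing functions.
For example, you might create one function to import the training set and another function to import the test set. Each dataset importing function must return two objects:
- a dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
- a Tensor containing one or more labels
For example, the following code illustrates the basic skeleton of an input function: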
def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label
See the data guide for details.
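To make the return contract of step 1 concrete, here is a minimal sketch that uses plain Python lists in place of Tensors. The records and the feature names are made up for illustration; a real input function would return Tensors or a dataset that yields them.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical raw records: (population, crime_rate, label).
RECORDS = [
    (1200.0, 0.03, 1),
    (560.0, 0.10, 0),
    (9800.0, 0.07, 1),
]


def input_fn(records: Sequence[Tuple[float, float, int]]
             ) -> Tuple[Dict[str, List[float]], List[int]]:
    """Turn row-oriented records into (feature dict, labels)."""
    # Keys are feature names; each value holds that feature for every example.
    feature_dict = {
        "population": [r[0] for r in records],
        "crime_rate": [r[1] for r in records],
    }
    # One label per example, in the same order as the feature values.
    label = [r[2] for r in records]
    return feature_dict, label


feature_dict, label = input_fn(RECORDS)
print(feature_dict, label)
```

The point of the dictionary is that each feature is addressed by name, which is how the feature columns in the next step find their data.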
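2. Define the feature columns.
Each tf.feature_column identifies a feature name, its type, and any input pre-processing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two feature columns simply identify the feature's name and type. The third feature column also specifies a lambda that the program invokes to normalize the raw data, here by subtracting a precomputed mean: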
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column(
    'median_education',
    normalizer_fn=lambda x: x - global_education_mean)
For further information, see the feature columns tutorial.
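The snippet in step 2 uses `global_education_mean` without defining it. As a rough sketch of what the lambda does, assuming the mean is computed once over the training data, the following pure-Python example applies the same function to a few made-up values. The sample values are hypothetical.

```python
from statistics import mean

# Hypothetical raw values of the 'median_education' feature.
median_education_values = [10.0, 12.0, 14.0, 16.0]

# Assumption: the mean is computed up front, before the column is defined.
global_education_mean = mean(median_education_values)

# Same shape as the lambda passed as normalizer_fn above.
normalizer_fn = lambda x: x - global_education_mean

normalized = [normalizer_fn(v) for v in median_education_values]
print(global_education_mean, normalized)
```

Note that this particular lambda centers the data around zero but does not change its spread; dividing by a standard deviation as well would be a separate choice.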
3. Instantiate the relevant pre-made Estimator.
For example, here is a sample instantiation of a pre-made Estimator named LinearClassifier:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education])
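For further information, see the linear classifier tutorial.
4. Call a training, evaluation, or inference method.
For example, all Estimators provide a train method, which trains a model.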
# `my_training_set` is the input function created in Step 1.
estimator.train(input_fn=my_training_set, steps=2000)
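You can see an example of this below.
Benefits of pre-made Estimators
Pre-made Estimators encode best practices, providing the following benefits:
- Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
- Best practices for event (summary) writing and universally useful summaries.
If you don't use pre-made Estimators, you must implement the preceding features yourself.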
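Custom Estimators
The heart of every Estimator, whether pre-made or custom, is its model function, a method that builds graphs for training, evaluation, and prediction. When you use a pre-made Estimator, someone else has already implemented the model function. When you rely on a custom Estimator, you must write the model function yourself.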
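The idea of one model function serving training, evaluation, and prediction can be sketched in plain Python. This is an illustration under stated assumptions, not the tf.estimator interface: the mode strings, the `params` keys, and the returned dictionaries are hypothetical stand-ins (a real model function returns a tf.estimator.EstimatorSpec), and the "model" is a single-weight linear regression.

```python
from typing import Dict, List, Optional

# Hypothetical mode constants for this sketch.
TRAIN, EVAL, PREDICT = "train", "eval", "predict"


def model_fn(features: Dict[str, List[float]],
             labels: Optional[List[float]],
             mode: str,
             params: Dict[str, float]) -> Dict[str, object]:
    """Toy model y = weight * x whose behavior depends on the mode."""
    weight = params["weight"]
    predictions = [weight * x for x in features["x"]]

    # Prediction needs no labels and computes no loss.
    if mode == PREDICT:
        return {"predictions": predictions}

    # Both evaluation and training need the mean squared error.
    n = len(labels)
    loss = sum((p - y) ** 2 for p, y in zip(predictions, labels)) / n
    if mode == EVAL:
        return {"loss": loss}

    # Training additionally takes one gradient-descent step on the weight.
    grad = sum(2 * (p - y) * x
               for p, y, x in zip(predictions, labels, features["x"])) / n
    return {"loss": loss, "weight": weight - params["learning_rate"] * grad}


data = {"x": [1.0, 2.0]}
targets = [2.0, 4.0]
train_out = model_fn(data, targets, TRAIN, {"weight": 0.0, "learning_rate": 0.1})
eval_out = model_fn(data, targets, EVAL, {"weight": 2.0})
pred_out = model_fn({"x": [3.0]}, None, PREDICT, {"weight": 2.0})
print(train_out, eval_out, pred_out)
```

Keeping the shared forward computation at the top and branching only on what each mode adds is the design that lets a single function cover all three uses.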
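Recommended workflow
- Assuming a suitable pre-made Estimator exists, use it to build your first model and use its results to establish a baseline.
- Build and test your overall pipeline, including the integrity and reliability of your data, with this pre-made Estimator.
- If suitable alternative pre-made Estimators are available, run experiments to determine which pre-made Estimator produces the best results.
- Possibly, further improve your model by building your own custom Estimator.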
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
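Create an Estimator from a Keras model
You can convert existing Keras models to Estimators with tf.keras.estimator.model_to_estimator. Doing so enables your Keras model to access Estimator's strengths, such as distributed training.
Instantiate a Keras MobileNet V2 model and compile it with the optimizer, loss, and metrics to train with: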
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step
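The model above ends in Dense(1) and is compiled with BinaryCrossentropy(from_logits=True), so its single output is a raw logit rather than a probability. As a minimal sketch of what that loss measures for one example, the pure-Python function below evaluates sigmoid cross-entropy directly from a logit in a numerically stable form. `binary_crossentropy_from_logit` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
import math


def binary_crossentropy_from_logit(logit: float, label: float) -> float:
    """Sigmoid cross-entropy for one example, computed from the raw logit.

    Uses max(z, 0) - z * y + log(1 + exp(-|z|)), which avoids overflow
    for logits of large magnitude.
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# A logit of 0 means "no opinion": the loss equals ln 2 for either label.
untrained = binary_crossentropy_from_logit(0.0, 1.0)

# A confident prediction is cheap when right and expensive when wrong.
confident_right = binary_crossentropy_from_logit(4.0, 1.0)
confident_wrong = binary_crossentropy_from_logit(4.0, 0.0)

print(untrained, confident_right, confident_wrong)
```

Since ln 2 is roughly 0.693, a loss near that value at the start of training is what one would expect from a freshly initialized output layer on a two-class problem.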
Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmprx_rffua
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:450: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmprx_rffua', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Treat the derived Estimator as you would any other Estimator.
IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

def train_input_fn(batch_size):
    data = tfds.load('cats_vs_dogs', as_supervised=True)
    train_data = data['train']
    train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
    return train_data
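The line `image = (image/127.5) - 1` in preprocess rescales 8-bit pixel values from the range [0, 255] to [-1, 1], the input range MobileNet V2 is usually fed. A small worked check of that arithmetic, using a hypothetical `scale_pixel` helper on single values instead of image tensors:

```python
def scale_pixel(value: float) -> float:
    """Apply the same rescaling as preprocess() to a single pixel value."""
    return value / 127.5 - 1


# The darkest, middle, and brightest values of an 8-bit channel.
scaled = [scale_pixel(v) for v in (0, 127.5, 255)]
print(scaled)  # [-1.0, 0.0, 1.0]
```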
To train, call the Estimator's train function:
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmprx_rffua/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmprx_rffua/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmprx_rffua/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.data-00001-of-00002
INFO:tensorflow:0
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.index
INFO:tensorflow:0
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.meta
INFO:tensorflow:1100
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-0.data-00000-of-00002
INFO:tensorflow:10100
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70208776, step = 0
INFO:tensorflow:global_step/sec: 48.5969
INFO:tensorflow:loss = 0.65885234, step = 100 (2.059 sec)
INFO:tensorflow:global_step/sec: 55.4617
INFO:tensorflow:loss = 0.7369333, step = 200 (1.803 sec)
Corrupt JPEG data: 99 extraneous bytes before marker 0xd9
Warning: unknown JFIF revision number 0.00
Corrupt JPEG data: 396 extraneous bytes before marker 0xd9
INFO:tensorflow:global_step/sec: 54.9075
Corrupt JPEG data: 162 extraneous bytes before marker 0xd9
INFO:tensorflow:loss = 0.6614636, step = 300 (1.821 sec)
Corrupt JPEG data: 252 extraneous bytes before marker 0xd9
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1403 extraneous bytes before marker 0xd9
INFO:tensorflow:global_step/sec: 54.9522
INFO:tensorflow:loss = 0.5951805, step = 400 (1.820 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmprx_rffua/model.ckpt.
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.data-00000-of-00002
INFO:tensorflow:9000
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.index
INFO:tensorflow:9000
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.meta
INFO:tensorflow:10100
INFO:tensorflow:/tmpfs/tmp/tmprx_rffua/model.ckpt-500.data-00001-of-00002
INFO:tensorflow:10100
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.61634666.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7ff5b00aac70>
Similarly, to evaluate, call the Estimator's evaluate function:
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2045: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-06-03T18:33:47
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmprx_rffua/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.40658s
INFO:tensorflow:Finished evaluation at 2022-06-03-18:33:49
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.596875, global_step = 500, loss = 0.64985955
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmprx_rffua/model.ckpt-500
{'accuracy': 0.596875, 'loss': 0.64985955, 'global_step': 500}
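The reported accuracy can be read back as a count of examples. The short calculation below assumes that all 10 evaluation steps used full batches of 32 and that accuracy is the plain fraction of correct predictions across them; neither assumption is stated in the output itself.

```python
batch_size = 32       # from train_input_fn(32)
steps = 10            # from evaluate(..., steps=10)
accuracy = 0.596875   # from the evaluation result above

# Assumption: every step consumed one full batch.
examples_seen = batch_size * steps

# Round to the nearest integer, since a count of examples must be whole.
correct = round(accuracy * examples_seen)

print(examples_seen, correct)
```

Under those assumptions the figure corresponds to 191 of 320 examples. Because the call reuses train_input_fn, these examples come from the shuffled training split, so the number describes fit to training data and not performance on held-out data.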
For more details, please refer to the documentation for tf.keras.estimator.model_to_estimator.