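This document introduces tf.estimator, a high-level TensorFlow API. Estimators encapsulate the following actions:
- Training
- Evaluation
- Prediction
- Export for serving
You can use the pre-made Estimators we provide or write your own custom Estimators. All Estimators, whether pre-made or custom, are classes based on the tf.estimator.Estimator class.
For a quick example, see the Estimator tutorials. For an overview of the API design, see the white paper.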
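Advantages
Like tf.keras.Model, an Estimator is a model-level abstraction. tf.estimator provides some capabilities that are still under development for tf.keras. These include:
- Parameter server based training
- Full TFX integration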
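Estimator capabilities
Estimators provide the following benefits:
- You can run Estimator-based models on a local host or in a distributed multi-server environment without changing your model. You can also run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
- Estimators provide a safe distributed training loop that controls how and when to:
  - Load data
  - Handle exceptions
  - Create checkpoint files and recover from failures
  - Save summaries for TensorBoard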
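To make the "create checkpoint files and recover from failures" item concrete, here is a toy sketch in plain Python. It is not how Estimators are implemented; the function `run_training`, the JSON checkpoint file, and the simulated failure are all hypothetical stand-ins that only illustrate the idea of resuming a loop from the last saved state.

```python
import json
import os
import tempfile


def run_training(model_dir, total_steps, save_every, fail_at=None):
    """Toy training loop that resumes from the last checkpoint in model_dir."""
    ckpt_path = os.path.join(model_dir, "checkpoint.json")  # hypothetical file name
    state = {"step": 0, "weight": 0.0}
    if os.path.exists(ckpt_path):
        # Recover: continue from the saved step instead of starting over.
        with open(ckpt_path) as f:
            state = json.load(f)
    while state["step"] < total_steps:
        if fail_at is not None and state["step"] == fail_at:
            raise RuntimeError("simulated worker failure")
        state["weight"] += 0.5  # stand-in for one optimizer update
        state["step"] += 1
        if state["step"] % save_every == 0:
            with open(ckpt_path, "w") as f:
                json.dump(state, f)
    return state


model_dir = tempfile.mkdtemp()
try:
    run_training(model_dir, total_steps=10, save_every=4, fail_at=6)
except RuntimeError:
    pass  # the run died at step 6; the last checkpoint was written at step 4

# The second run picks up at step 4 and finishes the remaining steps.
final_state = run_training(model_dir, total_steps=10, save_every=4)
```

The second call repeats steps 5 and 6, because only the work saved in the last checkpoint survives a failure.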
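When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different datasets.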
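Pre-made Estimators
Pre-made Estimators let you work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions, because Estimators handle all the "plumbing" for you. Pre-made Estimators also let you experiment with different model architectures by making only minimal code changes. For example, tf.estimator.DNNClassifier is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.
Structure of a pre-made Estimator program
A TensorFlow program that relies on a pre-made Estimator typically consists of the following four steps: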
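1. Write one or more dataset importing functions.
For example, you might create one function to import the training set and another function to import the test set. Each dataset importing function must return two objects:
- A dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
- A Tensor containing one or more labels
For example, the following code shows the basic skeleton of an input function:
def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label
For details, see the data guide.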
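The return contract described above (a feature dictionary plus labels) can be sketched in plain Python. This is only an illustration of the shape of the return value: ordinary lists stand in for Tensors, and the `rows` data and the `label_key` parameter are hypothetical, not part of the tf.estimator API.

```python
def input_fn(rows, label_key):
    """Split row dictionaries into a feature dict of columns and a label list."""
    feature_dict = {}
    labels = []
    for row in rows:
        for name, value in row.items():
            if name == label_key:
                labels.append(value)
            else:
                # One entry per feature name, holding that feature's values.
                feature_dict.setdefault(name, []).append(value)
    return feature_dict, labels


# Hypothetical in-memory data; a real program would read from files or tf.data.
rows = [
    {"population": 1200.0, "crime_rate": 0.02, "label": 1},
    {"population": 560.0, "crime_rate": 0.05, "label": 0},
]
features, labels = input_fn(rows, label_key="label")
```

In a real Estimator program the dictionary values and the labels would be Tensors, and the keys would match the names used when defining feature columns.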
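2. Define the feature columns.
Each tf.feature_column identifies a feature name, its type, and any input pre-processing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two feature columns simply identify the feature's name and type. The third feature column also specifies a lambda that the program will invoke to scale the raw data:
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column(
    'median_education',
    normalizer_fn=lambda x: x - global_education_mean)
For details, see the feature columns tutorial.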
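The snippet above uses `global_education_mean` without defining it. As a rough illustration of what such a normalizer does, the following plain-Python sketch computes a mean from made-up values and subtracts it from each raw value. The sample numbers are hypothetical, and plain floats stand in for the Tensors a real normalizer_fn would receive.

```python
# Hypothetical raw values for the 'median_education' feature.
education_years = [10.0, 12.0, 14.0, 16.0]

# Assumed to be computed once, ahead of time, over the training data.
global_education_mean = sum(education_years) / len(education_years)


def normalizer_fn(x):
    """Center a raw value by subtracting the precomputed mean."""
    return x - global_education_mean


normalized = [normalizer_fn(x) for x in education_years]
```

After centering, the values sum to zero, which is the effect the lambda in the feature column has on that feature.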
3. Instantiate the relevant pre-made Estimator.
For example, here is a sample instantiation of a pre-made Estimator named LinearClassifier:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education])
For details, see the linear classifier tutorial.
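4. Call a training, evaluation, or inference method.
For example, all Estimators provide a train method, which trains a model.
# `my_training_set` is an input function like the one created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)
You can see an example of this below.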
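Benefits of pre-made Estimators
Pre-made Estimators encode best practices, providing the following benefits:
- Best practices for determining where different parts of the computational graph should run, and for implementing strategies on a single machine or on a cluster.
- Best practices for event (summary) writing and universally useful summaries.
If you don't use pre-made Estimators, you must implement the preceding features yourself.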
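Custom Estimators
The heart of every Estimator, whether pre-made or custom, is its model function, which is a method that builds graphs for training, evaluation, and prediction. When you use a pre-made Estimator, someone else has already implemented the model function. When you rely on a custom Estimator, you must write the model function yourself.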
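To give a feel for "one function that serves training, evaluation, and prediction", here is a toy sketch in plain Python. It does not use the tf.estimator API: the mode constants, the `params` dictionary, and the returned dictionaries are hypothetical stand-ins, and the model is a single-weight linear fit rather than a graph.

```python
# Hypothetical mode names for this sketch only.
TRAIN, EVAL, PREDICT = "train", "eval", "predict"


def model_fn(features, labels, mode, params):
    """Toy linear model y = w * x, written as one function for all three modes."""
    w = params["w"]
    predictions = [w * x for x in features["x"]]
    if mode == PREDICT:
        return {"predictions": predictions}
    # Mean squared error is needed for both evaluation and training.
    errors = [p - y for p, y in zip(predictions, labels)]
    loss = sum(e * e for e in errors) / len(errors)
    if mode == EVAL:
        return {"loss": loss}
    # TRAIN: take one gradient-descent step on w.
    grad = sum(2 * e * x for e, x in zip(errors, features["x"])) / len(errors)
    return {"loss": loss, "w": w - params["learning_rate"] * grad}


features = {"x": [1.0, 2.0]}
labels = [2.0, 4.0]
params = {"w": 0.0, "learning_rate": 0.1}

pred_out = model_fn(features, None, PREDICT, params)
eval_out = model_fn(features, labels, EVAL, params)
train_out = model_fn(features, labels, TRAIN, params)
```

The point of the sketch is the branching: prediction needs no labels, evaluation adds a loss, and training adds an update on top of the loss.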
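Recommended workflow
- Assuming a suitable pre-made Estimator exists, use it to build your first model and use its results to establish a baseline.
- Build and test your overall pipeline, including the integrity and reliability of your data, with this pre-made Estimator.
- If other suitable pre-made Estimators are available, run experiments to determine which one produces the best results.
- If possible, further improve your model by building your own custom Estimator.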
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
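Create an Estimator from a Keras model
You can convert an existing Keras model to an Estimator with tf.keras.estimator.model_to_estimator. Doing so lets your Keras model use the strengths of Estimators, such as distributed training.
Instantiate a Keras MobileNet V2 model and compile it with the optimizer, loss, and metrics to train with: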
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
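The model above ends in Dense(1) with no activation, so it outputs a single raw logit, and from_logits=True tells the loss to apply the sigmoid itself. As a worked example of what that loss computes, the sketch below evaluates binary cross-entropy for one label and one logit in plain Python, using the common numerically stable form. The helper name is hypothetical and this is an illustration, not the Keras implementation.

```python
import math


def binary_crossentropy_from_logit(label, logit):
    """Cross-entropy between a 0/1 label and sigmoid(logit), computed stably."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# A logit of 0 means "50/50", so the loss is ln 2 regardless of the label.
loss_at_zero = binary_crossentropy_from_logit(1.0, 0.0)
# A large positive logit is cheap when the label is 1 and costly when it is 0.
confident_right = binary_crossentropy_from_logit(1.0, 4.0)
confident_wrong = binary_crossentropy_from_logit(0.0, 4.0)
```

Since ln 2 is about 0.693, an untrained classification head with logits near zero should start with a loss close to that value, which is consistent with the loss of 0.68437946 reported at step 0 in the training log later in this guide.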
Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpf37wepof
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpf37wepof', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
You can treat the derived Estimator as you would any other Estimator.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label
def train_input_fn(batch_size):
    data = tfds.load('cats_vs_dogs', as_supervised=True)
    train_data = data['train']
    train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
    return train_data
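The expression `(image/127.5) - 1` in preprocess rescales pixel intensities. Assuming the input images are 8-bit, with values between 0 and 255, the arithmetic maps that range onto [-1, 1]. The plain-Python sketch below checks the two endpoints and the midpoint. The helper `scale_pixel` is a hypothetical stand-in that works on single numbers rather than Tensors.

```python
def scale_pixel(value):
    """Map a pixel intensity from [0, 255] to [-1, 1], as preprocess does."""
    return (value / 127.5) - 1


# Darkest pixel, mid-gray, and brightest pixel.
scaled = [scale_pixel(v) for v in (0, 127.5, 255)]
```

The three inputs map to -1.0, 0.0, and 1.0, so every value in between lands inside [-1, 1].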
To train, call the Estimator's train function:
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...
Warning:absl:1738 images were corrupted and were skipped
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteUQEOHH/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpf37wepof/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpf37wepof/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpf37wepof/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.68437946, step = 0
INFO:tensorflow:global_step/sec: 22.0536
INFO:tensorflow:loss = 0.70650136, step = 100 (4.536 sec)
INFO:tensorflow:global_step/sec: 23.6417
INFO:tensorflow:loss = 0.6610012, step = 200 (4.230 sec)
INFO:tensorflow:global_step/sec: 23.82
INFO:tensorflow:loss = 0.6323833, step = 300 (4.198 sec)
INFO:tensorflow:global_step/sec: 23.7343
INFO:tensorflow:loss = 0.6456127, step = 400 (4.213 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmpf37wepof/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.63354445.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fe924063780>
Similarly, to evaluate, call the Estimator's evaluate function:
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-22T19:13:33Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpf37wepof/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.04706s
INFO:tensorflow:Finished evaluation at 2020-09-22-19:13:35
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.584375, global_step = 500, loss = 0.61907357
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmpf37wepof/model.ckpt-500
{'accuracy': 0.584375, 'loss': 0.61907357, 'global_step': 500}
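The evaluation call uses steps=10 with a batch size of 32, so the reported accuracy is an average over a fairly small sample. The arithmetic below is a quick sanity check on the reported number, assuming every one of the 10 batches is full. The variable names are hypothetical.

```python
batch_size = 32
eval_steps = 10

# Number of examples scored, assuming all 10 batches contain 32 images.
examples_seen = batch_size * eval_steps

reported_accuracy = 0.584375  # taken from the evaluation output above
correct = round(reported_accuracy * examples_seen)
```

Under that assumption the reported accuracy corresponds to 187 correct predictions out of 320. Note also that the call reuses train_input_fn, so the figure describes performance on training data rather than on a held-out set.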
For more details, see the documentation for tf.keras.estimator.model_to_estimator.