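In TensorFlow 1, you customize the behavior of training with tf.estimator.SessionRunHook and tf.estimator.Estimator. This guide shows how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API. That API works with Keras Model.fit for training, and also with Model.evaluate and Model.predict. You will learn how to do this by implementing the same task with a SessionRunHook and with a Callback: measuring examples per second during training.
Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training, evaluation, and prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as in the Writing your own callbacks guide and the Training and evaluation with the built-in methods guide (the Using callbacks section).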
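To make "called at different points" concrete, here is a framework-free sketch of how a training loop might dispatch to callback methods. The `run_training_loop` function and `RecordingCallback` class are hypothetical illustrations, not Keras's actual implementation; only the method names mirror the tf.keras.callbacks.Callback API, and the sketch omits epochs and the evaluation/prediction hooks.

```python
class RecordingCallback:
    """Hypothetical callback that records which method ran, in call order."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("on_train_begin")

    def on_train_batch_begin(self, batch, logs=None):
        self.calls.append(f"on_train_batch_begin({batch})")

    def on_train_batch_end(self, batch, logs=None):
        self.calls.append(f"on_train_batch_end({batch})")

    def on_train_end(self, logs=None):
        self.calls.append("on_train_end")


def run_training_loop(num_batches, callbacks):
    """Hypothetical, simplified stand-in for the dispatch that Model.fit performs."""
    for cb in callbacks:
        cb.on_train_begin()
    for batch in range(num_batches):
        for cb in callbacks:
            cb.on_train_batch_begin(batch)
        # A real loop would run one optimization step here.
        for cb in callbacks:
            cb.on_train_batch_end(batch)
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordingCallback()
run_training_loop(num_batches=2, callbacks=[recorder])
print(recorder.calls)
```

The printed list shows the begin hook once, a begin/end pair around every batch, and the end hook once, which is the same shape of lifecycle that SessionRunHook exposes through begin, before_run, after_run, and end.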
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import time
from datetime import datetime
from absl import flags
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
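TensorFlow 1: Create a custom SessionRunHook with the tf.estimator APIs
The following TensorFlow 1 example shows how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.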
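The step numbers that the hook prints follow directly from the input pipeline in `_input_fn`: three training examples, `batch(1)`, and `repeat(100)` give 300 training steps. With `log_frequency = 10`, the hook logs at steps 0, 10, ..., 290. A quick check of that arithmetic in plain Python (the variable names are illustrative):

```python
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
batch_size = 1
repeat_count = 100
log_frequency = 10

# batch() is applied before repeat(), so each pass over the data yields
# ceil(len(features) / batch_size) batches, and repeat() multiplies that.
batches_per_pass = -(-len(features) // batch_size)
total_steps = batches_per_pass * repeat_count

# The hook's step counter runs 0 .. total_steps - 1 and logs every log_frequency steps.
logged_steps = [s for s in range(total_steps) if s % log_frequency == 0]

print(total_steps)        # 300
print(logged_steps[-1])   # 290
print(len(logged_steps))  # 30
```

This is also why the Estimator saves its final checkpoint at global step 300.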
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)
def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
class LoggerHook(tf1.train.SessionRunHook):
  """Logs loss and runtime."""
  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10
  def before_run(self, run_context):
    self._step += 1
  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpx2ok2ea3
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpx2ok2ea3', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpx2ok2ea3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
Time: 2022-12-14 21:05:33.564674 , Step #: 0 , Examples per second: 2.6388003375708364
INFO:tensorflow:loss = 0.2890536, step = 0
Time: 2022-12-14 21:05:33.597332 , Step #: 10 , Examples per second: 306.1669854154197
Time: 2022-12-14 21:05:33.604084 , Step #: 20 , Examples per second: 1480.9349622201821
Time: 2022-12-14 21:05:33.610736 , Step #: 30 , Examples per second: 1503.173135505143
Time: 2022-12-14 21:05:33.617525 , Step #: 40 , Examples per second: 1473.0294303575192
Time: 2022-12-14 21:05:33.624239 , Step #: 50 , Examples per second: 1489.4545454545455
Time: 2022-12-14 21:05:33.631684 , Step #: 60 , Examples per second: 1343.2088644078653
Time: 2022-12-14 21:05:33.638755 , Step #: 70 , Examples per second: 1414.1281186783547
Time: 2022-12-14 21:05:33.645996 , Step #: 80 , Examples per second: 1381.068159367797
Time: 2022-12-14 21:05:33.653013 , Step #: 90 , Examples per second: 1425.0829029627616
INFO:tensorflow:global_step/sec: 1037.69
Time: 2022-12-14 21:05:33.661947 , Step #: 100 , Examples per second: 1119.3466947772945
INFO:tensorflow:loss = 2.0069905e-05, step = 100 (0.097 sec)
Time: 2022-12-14 21:05:33.669810 , Step #: 110 , Examples per second: 1271.8105461050973
Time: 2022-12-14 21:05:33.676815 , Step #: 120 , Examples per second: 1427.4108358290225
Time: 2022-12-14 21:05:33.683562 , Step #: 130 , Examples per second: 1482.1909675595448
Time: 2022-12-14 21:05:33.690362 , Step #: 140 , Examples per second: 1470.5504522824485
Time: 2022-12-14 21:05:33.697251 , Step #: 150 , Examples per second: 1451.768370772905
Time: 2022-12-14 21:05:33.703805 , Step #: 160 , Examples per second: 1525.5342983923765
Time: 2022-12-14 21:05:33.710547 , Step #: 170 , Examples per second: 1483.2917211868303
Time: 2022-12-14 21:05:33.717187 , Step #: 180 , Examples per second: 1506.141913243321
Time: 2022-12-14 21:05:33.723858 , Step #: 190 , Examples per second: 1498.9293117003788
INFO:tensorflow:global_step/sec: 1418.02
Time: 2022-12-14 21:05:33.732304 , Step #: 200 , Examples per second: 1183.995483415667
INFO:tensorflow:loss = 7.567363e-05, step = 200 (0.070 sec)
Time: 2022-12-14 21:05:33.740206 , Step #: 210 , Examples per second: 1265.555488504013
Time: 2022-12-14 21:05:33.747279 , Step #: 220 , Examples per second: 1413.7467978967238
Time: 2022-12-14 21:05:33.754154 , Step #: 230 , Examples per second: 1454.6382742595547
Time: 2022-12-14 21:05:33.760685 , Step #: 240 , Examples per second: 1531.0472713998904
Time: 2022-12-14 21:05:33.767590 , Step #: 250 , Examples per second: 1448.2093778054002
Time: 2022-12-14 21:05:33.774400 , Step #: 260 , Examples per second: 1468.5424179825636
Time: 2022-12-14 21:05:33.781041 , Step #: 270 , Examples per second: 1505.655311052877
Time: 2022-12-14 21:05:33.787675 , Step #: 280 , Examples per second: 1507.603608784731
Time: 2022-12-14 21:05:33.794272 , Step #: 290 , Examples per second: 1515.776083264067
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmpfs/tmp/tmpx2ok2ea3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:Loss for final step: 2.7116595e-05.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fae99b2ac40>
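In the output above, the step 0 reading (about 2.6 examples per second) is far below the others. LoggerHook always computes `self.log_frequency / duration`, but its first interval is measured from `begin`, so it covers setup work such as creating the session together with a single step, while the formula still assumes `log_frequency` steps. The formula also equals examples per second only because the dataset uses `batch(1)`. The framework-free sketch below, assuming a fixed batch size, counts the steps actually completed since the last report and multiplies by the batch size. The `ThroughputMeter` class and its injectable `clock` parameter are hypothetical; the clock is injectable only so that the logic can be tested without real timing.

```python
import time


class ThroughputMeter:
    """Hypothetical helper: examples/second over the steps since the last report."""

    def __init__(self, batch_size, log_frequency=10, clock=time.perf_counter):
        self.batch_size = batch_size
        self.log_frequency = log_frequency
        self._clock = clock
        self._step = -1
        self._steps_since_report = 0
        self._start_time = self._clock()

    def step_end(self):
        """Call once after each training step; returns a rate on logging steps, else None."""
        self._step += 1
        self._steps_since_report += 1
        if self._step % self.log_frequency != 0:
            return None
        now = self._clock()
        duration = now - self._start_time
        # Count the examples actually processed in this interval.
        examples = self._steps_since_report * self.batch_size
        self._start_time = now
        self._steps_since_report = 0
        # Guard against a zero-length interval from a coarse clock.
        return examples / duration if duration > 0 else None


# Usage with a fake clock: every step takes 0.5 s and processes 4 examples.
fake_now = [0.0]
meter = ThroughputMeter(batch_size=4, log_frequency=10, clock=lambda: fake_now[0])
rates = []
for _ in range(21):
    fake_now[0] += 0.5
    rate = meter.step_end()
    if rate is not None:
        rates.append(rate)
print(rates)  # [8.0, 8.0, 8.0]
```

Because the first report divides one step's worth of examples by one step's duration, all three readings agree. `time.perf_counter` is used as the default clock because it is monotonic and intended for measuring intervals.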
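TensorFlow 2: Create a custom Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training or evaluation, you can configure a custom tf.keras.callbacks.Callback and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
In the example below, you will write a custom tf.keras.callbacks.Callback that logs a metric: examples per second, which should be comparable to the metric in the previous SessionRunHook example.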
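The two implementations in this guide line up method by method. The dictionary below records that correspondence as it is used here. It summarizes these two examples only and is not an exhaustive mapping of either API. The `end` entry is included for completeness, although neither example uses it.

```python
# Correspondence between the TF1 SessionRunHook methods and the Keras
# Callback methods used in this guide (a summary, not a full API mapping).
HOOK_TO_CALLBACK = {
    "begin": "on_train_begin",             # reset the step counter and timer
    "before_run": "on_train_batch_begin",  # increment the step counter
    "after_run": "on_train_batch_end",     # log throughput every N steps
    "end": "on_train_end",                 # not used by either example
}

for hook_method, callback_method in HOOK_TO_CALLBACK.items():
    print(f"SessionRunHook.{hook_method:<10} -> Callback.{callback_method}")
```

One difference the table does not capture: `after_run` receives `run_context` and `run_values`, while `on_train_batch_end` receives the batch index and a `logs` dictionary of metrics.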
class CustomCallback(tf.keras.callbacks.Callback):
  """Logs examples per second during training."""
  def on_train_begin(self, logs=None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10
  def on_train_batch_begin(self, batch, logs=None):
    self._step += 1
  def on_train_batch_end(self, batch, logs=None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)
callback = CustomCallback()
dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose=0)
# Inspect the training metrics recorded by Model.fit.
result.history
Time: 2022-12-14 21:05:34.715803 , Step #: 0 , Examples per second: 20.32853849177442
Time: 2022-12-14 21:05:34.735514 , Step #: 10 , Examples per second: 507.25071655762093
Time: 2022-12-14 21:05:34.752205 , Step #: 20 , Examples per second: 599.0921426632958
Time: 2022-12-14 21:05:34.769600 , Step #: 30 , Examples per second: 574.9008319969297
Time: 2022-12-14 21:05:34.786129 , Step #: 40 , Examples per second: 604.9680518094359
Time: 2022-12-14 21:05:34.802761 , Step #: 50 , Examples per second: 601.2477064220184
Time: 2022-12-14 21:05:34.819602 , Step #: 60 , Examples per second: 593.8164880438322
Time: 2022-12-14 21:05:34.836180 , Step #: 70 , Examples per second: 603.210561891476
Time: 2022-12-14 21:05:34.852972 , Step #: 80 , Examples per second: 595.5026763023014
Time: 2022-12-14 21:05:34.869650 , Step #: 90 , Examples per second: 599.5802956228379
Time: 2022-12-14 21:05:34.886414 , Step #: 100 , Examples per second: 596.5359616560709
Time: 2022-12-14 21:05:34.902936 , Step #: 110 , Examples per second: 605.2386724386724
Time: 2022-12-14 21:05:34.919523 , Step #: 120 , Examples per second: 602.889751329596
Time: 2022-12-14 21:05:34.936131 , Step #: 130 , Examples per second: 602.1108240022969
Time: 2022-12-14 21:05:34.952220 , Step #: 140 , Examples per second: 621.5533261214267
Time: 2022-12-14 21:05:34.968538 , Step #: 150 , Examples per second: 612.7991818248229
Time: 2022-12-14 21:05:34.984989 , Step #: 160 , Examples per second: 607.8877648627497
Time: 2022-12-14 21:05:35.001109 , Step #: 170 , Examples per second: 620.3490504644146
Time: 2022-12-14 21:05:35.016755 , Step #: 180 , Examples per second: 639.1222990887758
Time: 2022-12-14 21:05:35.032922 , Step #: 190 , Examples per second: 618.5376788084353
Time: 2022-12-14 21:05:35.049564 , Step #: 200 , Examples per second: 600.8945430581224
Time: 2022-12-14 21:05:35.066073 , Step #: 210 , Examples per second: 605.7281497313846
Time: 2022-12-14 21:05:35.082339 , Step #: 220 , Examples per second: 614.8022631995544
Time: 2022-12-14 21:05:35.098710 , Step #: 230 , Examples per second: 610.8269012320508
Time: 2022-12-14 21:05:35.115086 , Step #: 240 , Examples per second: 610.6490405613953
Time: 2022-12-14 21:05:35.131756 , Step #: 250 , Examples per second: 599.8890128436168
Time: 2022-12-14 21:05:35.148034 , Step #: 260 , Examples per second: 614.325009154156
Time: 2022-12-14 21:05:35.164035 , Step #: 270 , Examples per second: 624.9614828721708
Time: 2022-12-14 21:05:35.180192 , Step #: 280 , Examples per second: 618.9027593330383
Time: 2022-12-14 21:05:35.195872 , Step #: 290 , Examples per second: 637.7811568639377
{'loss': [0.20529791712760925]}
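To compare throughput between the two runs, it can help to parse the printed lines rather than read them by eye. The hypothetical `parse_examples_per_second` helper below assumes the exact line format printed by LoggerHook and CustomCallback above, ignores other lines (such as INFO logs), and skips step 0 by default, since its interval also includes setup time. The sample lines are copied from the Keras output above.

```python
import re
import statistics

# Matches the lines printed by LoggerHook and CustomCallback above.
LINE_RE = re.compile(r"Step #: (\d+) , Examples per second: ([0-9.eE+-]+)")


def parse_examples_per_second(lines, skip_first_step=True):
    """Return {step: examples_per_sec} for every matching log line."""
    rates = {}
    for line in lines:
        match = LINE_RE.search(line)
        if match is None:
            continue  # e.g. INFO or WARNING lines
        step, rate = int(match.group(1)), float(match.group(2))
        if skip_first_step and step == 0:
            continue
        rates[step] = rate
    return rates


sample = [
    "Time: 2022-12-14 21:05:34.715803 , Step #: 0 , Examples per second: 20.32853849177442",
    "Time: 2022-12-14 21:05:34.735514 , Step #: 10 , Examples per second: 507.25071655762093",
    "Time: 2022-12-14 21:05:34.752205 , Step #: 20 , Examples per second: 599.0921426632958",
    "Time: 2022-12-14 21:05:34.769600 , Step #: 30 , Examples per second: 574.9008319969297",
]
rates = parse_examples_per_second(sample)
print(sorted(rates))                      # [10, 20, 30]
print(statistics.median(rates.values()))  # 574.9008319969297
```

Feeding the full logs of both runs through the same helper gives one summary number per implementation. The median is less sensitive than the mean to occasional slow intervals.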
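Next steps
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide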