在 TensorFlow.org 上查看 | 在 Google Colab 运行 | 在 Github 上查看源代码 | 下载笔记本 |
在 TensorFlow 1 中,可以使用 tf.estimator.LoggingTensorHook
监视和记录张量,而 tf.estimator.StopAtStepHook
则在使用 tf.estimator.Estimator
进行训练时有助于在指定步骤停止训练。本笔记本演示了如何使用带有 Model.fit
的自定义 Keras 回调 (tf.keras.callbacks.Callback
) 从这些 API 迁移到 TensorFlow 2 中的对应项。
Keras 回调是在内置 Keras Model.fit
/Model.evaluate
/Model.predict
API 中的训练/评估/预测期间的不同点调用的对象。可以在 tf.keras.callbacks.Callback
API 文档以及编写自己的回调和使用内置方法进行训练和评估(使用回调 部分)指南中详细了解回调。要从 TensorFlow 1 中的 SessionRunHook
迁移到 TensorFlow 2 中的 Keras 回调,请查看迁移使用辅助逻辑的训练指南。
安装
从导入和用于演示目的的简单数据集开始:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
2022-12-14 20:26:46.140102: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:26:46.140199: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:26:46.140208: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1:使用 tf.estimator API 记录张量和停止训练
在 TensorFlow 1 中,定义各种钩子来控制训练行为。随后,将这些钩子传递给 tf.estimator.EstimatorSpec
。
在下面的示例中:
- 要监视/记录张量(例如模型权重或损失),可以使用
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
是它的别名)。 - 要在特定步骤停止训练,请使用
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
是它的别名)。
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp9nj_75jg INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp9nj_75jg', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp9nj_75jg/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 12.718544, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1990657] [-1.3781608]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 12.718544 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1491147] [-1.3281827]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.04995093] (0.027 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmp9nj_75jg/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 36.829548. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f226c14eeb0>
TensorFlow 2:使用自定义回调和 Model.fit 记录张量和停止训练
在 TensorFlow 2 中,当您使用内置 Keras Model.fit
(或 Model.evaluate
)进行训练/评估时,可以通过定义自定义 Keras tf.keras.callbacks.Callback
来配置张量监视和训练停止。随后,将它们传递给 Model.fit
(或 Model.evaluate
)的 callbacks
参数。(在编写自己的回调指南中了解详情。)
在下面的示例中:
- 要重新创建
StopAtStepHook
的功能,请定义一个自定义回调(下称StopAtStepCallback
),可以在其中重写on_batch_end
方法以在一定数量的步骤后停止训练。 - 要重新创建
LoggingTensorHook
行为,请定义一个自定义回调 (LoggingTensorCallback
),可以在其中手动记录和输出记录的张量,因为不支持按名称访问张量。此外,您还可以在自定义回调中实现记录频率。下面的示例将每两步打印一次权重。每 N 秒记录一次之类的其他策略也是可行的。
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
完成后,将新回调(StopAtStepCallback
和 LoggingTensorCallback
)传递给 Model.fit
的 callbacks
参数:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[1.3215477 ], [0.41005942]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.04979974], dtype=float32)> Logging Tensor Callback loss: 3.102187156677246 1/3 [=========>....................] - ETA: 0s - loss: 3.1022 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[1.2734439], [0.3627134]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.09329326], dtype=float32)> Logging Tensor Callback loss: 6.4134416580200195 3/3 [==============================] - 0s 4ms/step - loss: 6.4134 <keras.callbacks.History at 0x7f2150509220>
后续步骤
通过以下方式详细了解回调:
- API 文档:
tf.keras.callbacks.Callback
- 指南:编写自己的回调
- 指南:使用内置方法进行训练和评估(使用回调部分)
此外,您可能还会发现下列与迁移相关的资源十分有用:
- 提前停止迁移指南:
tf.keras.callbacks.EarlyStopping
是一个内置的提前停止回调 - TensorBoard 迁移指南:TensorBoard 支持跟踪和显示指标
- 使用辅助逻辑进行训练迁移指南:从 TensorFlow 1 中的
SessionRunHook
到 TensorFlow 2 中的 Keras 回调