In TensorFlow 1, you use `tf.estimator.LoggingTensorHook` to monitor and log tensors, while `tf.estimator.StopAtStepHook` helps stop training at a specified step when training with `tf.estimator.Estimator`. This notebook demonstrates how to migrate from these APIs to their equivalents in TensorFlow 2 using custom Keras callbacks (`tf.keras.callbacks.Callback`) with `Model.fit`.
Keras callbacks are objects that are called at different points during training, evaluation, and prediction in the built-in Keras `Model.fit`, `Model.evaluate`, and `Model.predict` APIs. You can learn more about callbacks in the `tf.keras.callbacks.Callback` API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides. For migrating from `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2, check out the Migrate training with assisted logic guide.
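To make "called at different points during training" concrete, the sketch below records which callback methods `Model.fit` invokes and in what order. `CallOrderCallback` is a hypothetical helper written only for this illustration; it overrides a few standard `tf.keras.callbacks.Callback` methods and reuses the same toy data as the rest of this notebook.

```python
import tensorflow as tf


class CallOrderCallback(tf.keras.callbacks.Callback):
  """Hypothetical helper: records the callback methods Keras calls."""

  def __init__(self):
    super().__init__()
    self.calls = []

  def on_train_begin(self, logs=None):
    self.calls.append("on_train_begin")

  def on_epoch_begin(self, epoch, logs=None):
    self.calls.append(f"on_epoch_begin({epoch})")

  def on_train_batch_end(self, batch, logs=None):
    self.calls.append(f"on_train_batch_end({batch})")

  def on_epoch_end(self, epoch, logs=None):
    self.calls.append(f"on_epoch_end({epoch})")

  def on_train_end(self, logs=None):
    self.calls.append("on_train_end")


features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(tf.keras.optimizers.Adagrad(learning_rate=0.05), "mse")

recorder = CallOrderCallback()
model.fit(dataset, epochs=1, callbacks=[recorder], verbose=0)
print(recorder.calls)
```

With three single-example batches and one epoch, the recorded list starts with `on_train_begin`, contains one `on_train_batch_end` entry per batch, and ends with `on_epoch_end(0)` followed by `on_train_end`.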
Setup
Start with imports and a simple dataset for demonstration purposes:
```python
import tensorflow as tf
import tensorflow.compat.v1 as tf1
```
```python
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]

# Define an input function.
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
```
TensorFlow 1: Log tensors and stop training with tf.estimator APIs
In TensorFlow 1, you define various hooks to control the training behavior. Then, you pass these hooks to `tf.estimator.EstimatorSpec`.
In the example below:
- To monitor/log tensors (for example, model weights or losses), you use `tf.estimator.LoggingTensorHook` (`tf.train.LoggingTensorHook` is its alias).
- To stop training at a specific step, you use `tf.estimator.StopAtStepHook` (`tf.train.StopAtStepHook` is its alias).
```python
def _model_fn(features, labels, mode):
  dense = tf1.layers.Dense(1)
  logits = dense(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  # Define the stop hook.
  stop_hook = tf1.train.StopAtStepHook(num_steps=2)

  # Access tensors to be logged by names.
  kernel_name = tf.identity(dense.weights[0])
  bias_name = tf.identity(dense.weights[1])
  logging_weight_hook = tf1.train.LoggingTensorHook(
      tensors=[kernel_name, bias_name],
      every_n_iter=1)
  # Log the training loss by the tensor object.
  logging_loss_hook = tf1.train.LoggingTensorHook(
      {'loss from LoggingTensorHook': loss},
      every_n_secs=3)

  # Pass all hooks to `EstimatorSpec`.
  return tf1.estimator.EstimatorSpec(mode,
                                     loss=loss,
                                     train_op=train_op,
                                     training_hooks=[stop_hook,
                                                     logging_weight_hook,
                                                     logging_loss_hook])
```
```python
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
```
```
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpun_ddee_
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpun_ddee_', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpun_ddee_/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.8706179, step = 0
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-0.819447 ]
 [-0.16550565]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.]
INFO:tensorflow:loss from LoggingTensorHook = 1.8706179
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-0.7697778 ]
 [-0.11565349]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.04966919] (0.028 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpun_ddee_/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Loss for final step: 5.1939325.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f265b4bea30>
```
TensorFlow 2: Log tensors and stop training with custom callbacks and Model.fit
In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom Keras `tf.keras.callbacks.Callback`s. Then, you pass them to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the Writing your own callbacks guide.)
In the example below:
- To recreate the functionality of `StopAtStepHook`, define a custom callback (named `StopAtStepCallback` below) where you override the `on_batch_end` method to stop training after a certain number of steps.
- To recreate the `LoggingTensorHook` behavior, define a custom callback (`LoggingTensorCallback`) where you record and output the logged tensors manually, since accessing tensors by name is not supported. You can also implement the logging frequency inside the custom callback. As written, the example below prints the weights and loss after each of the first `every_n_iter` steps (two here) and logs nothing after that. Other strategies, like logging every N steps or every N seconds, are also possible.
```python
class StopAtStepCallback(tf.keras.callbacks.Callback):
  def __init__(self, stop_step=None):
    super().__init__()
    self._stop_step = stop_step

  def on_batch_end(self, batch, logs=None):
    # `optimizer.iterations` counts the training steps taken so far.
    if (self._stop_step is not None and
        self.model.optimizer.iterations >= self._stop_step):
      self.model.stop_training = True
      print('\nstop training now')


class LoggingTensorCallback(tf.keras.callbacks.Callback):
  def __init__(self, every_n_iter):
    super().__init__()
    self._every_n_iter = every_n_iter
    self._log_count = every_n_iter

  def on_batch_end(self, batch, logs=None):
    if self._log_count > 0:
      self._log_count -= 1
      # `self.model` is the model this callback is attached to.
      print("Logging Tensor Callback: dense/kernel:",
            self.model.layers[0].weights[0])
      print("Logging Tensor Callback: dense/bias:",
            self.model.layers[0].weights[1])
      print("Logging Tensor Callback loss:", logs["loss"])
    else:
      # After the first `every_n_iter` steps the counter stays non-positive,
      # so nothing further is logged.
      self._log_count -= self._every_n_iter
```
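If you want `every_n_iter` to behave like a true logging interval (closer to `LoggingTensorHook(every_n_iter=...)`), one option is to key the decision off the global step. The sketch below is a hypothetical variant, `EveryNStepsLoggingCallback`, not part of any TensorFlow API; it assumes a model whose first layer is the `Dense` layer used in this notebook and logs only when `optimizer.iterations` is a multiple of `every_n_iter`.

```python
import tensorflow as tf


class EveryNStepsLoggingCallback(tf.keras.callbacks.Callback):
  """Hypothetical variant: log weights and loss once every `every_n_iter` steps."""

  def __init__(self, every_n_iter):
    super().__init__()
    self._every_n_iter = every_n_iter
    self.logged_steps = []  # Global steps at which logging happened.

  def on_train_batch_end(self, batch, logs=None):
    # `optimizer.iterations` is the number of training steps applied so far,
    # counted across epochs (unlike `batch`, which restarts every epoch).
    step = int(self.model.optimizer.iterations.numpy())
    if step % self._every_n_iter == 0:
      self.logged_steps.append(step)
      kernel, bias = self.model.layers[0].weights
      print(f"step {step}: kernel={kernel.numpy().ravel()}, "
            f"bias={bias.numpy()}, loss={float(logs['loss']):.4f}")


features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(tf.keras.optimizers.Adagrad(learning_rate=0.05), "mse")

# 3 batches x 2 epochs = 6 steps, so this logs at steps 2, 4, and 6.
logger = EveryNStepsLoggingCallback(every_n_iter=2)
model.fit(dataset, epochs=2, callbacks=[logger], verbose=0)
```

Using the global step rather than the per-epoch `batch` index keeps the interval consistent across epoch boundaries.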
When finished, pass the new callbacks (`StopAtStepCallback` and `LoggingTensorCallback`) to the `callbacks` parameter of `Model.fit`:
```python
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
                              LoggingTensorCallback(every_n_iter=2)])
```
```
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[1.3146261],
       [0.8792979]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.04989691], dtype=float32)>
Logging Tensor Callback loss: 6.043736457824707
1/3 [=========>....................] - ETA: 0s - loss: 6.0437
stop training now
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[1.2665784],
       [0.8320339]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.09322532], dtype=float32)>
Logging Tensor Callback loss: 12.170801162719727
3/3 [==============================] - 0s 4ms/step - loss: 12.1708
<keras.callbacks.History at 0x7f26585060d0>
```
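The TensorFlow 1 example also logs the loss with `LoggingTensorHook(..., every_n_secs=3)`. A time-based equivalent can be sketched as a callback that remembers when it last logged. `EveryNSecsLossCallback` is a hypothetical name chosen for this illustration; the demo uses a one-hour interval only so that the number of log lines is predictable in a short run.

```python
import time

import tensorflow as tf


class EveryNSecsLossCallback(tf.keras.callbacks.Callback):
  """Hypothetical sketch: log the loss at most once every `every_n_secs` seconds."""

  def __init__(self, every_n_secs):
    super().__init__()
    self._every_n_secs = every_n_secs
    self._last_log_time = None
    self.num_logs = 0  # How many times the loss was logged.

  def on_train_begin(self, logs=None):
    # Reset the timer so the first training step is always logged.
    self._last_log_time = None

  def on_train_batch_end(self, batch, logs=None):
    now = time.monotonic()
    if (self._last_log_time is None or
        now - self._last_log_time >= self._every_n_secs):
      self._last_log_time = now
      self.num_logs += 1
      print("loss from EveryNSecsLossCallback =", float(logs["loss"]))


features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(tf.keras.optimizers.Adagrad(learning_rate=0.05), "mse")

# With a one-hour interval, only the first step of this short run is logged.
loss_logger = EveryNSecsLossCallback(every_n_secs=3600)
model.fit(dataset, epochs=2, callbacks=[loss_logger], verbose=0)
```

`time.monotonic` is used instead of `time.time` so that system clock adjustments cannot shorten or lengthen the interval.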
Next steps
Learn more about callbacks in:
- API docs: `tf.keras.callbacks.Callback`
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Early stopping migration guide: `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The Training with assisted logic migration guide: From `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2
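For comparison with the custom `StopAtStepCallback`, here is a minimal sketch of the built-in `tf.keras.callbacks.EarlyStopping`, which stops based on a monitored metric rather than a fixed step count. The choice of `monitor="loss"`, `patience=2`, and a 20-epoch budget is an assumption for this toy dataset; whether training actually ends early depends on how the loss evolves.

```python
import tensorflow as tf

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(tf.keras.optimizers.Adagrad(learning_rate=0.05), "mse")

# Stop once the training loss has not improved for 2 consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=2)
history = model.fit(dataset, epochs=20, callbacks=[early_stopping], verbose=0)
print("Epochs run:", len(history.history["loss"]))
```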