Migrate SessionRunHook to Keras callbacks


In TensorFlow 1, to customize the behavior of training, you use tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing the same task, measuring examples per second during training, first as a SessionRunHook and then as a Keras Callback.

Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
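
For example, built-in callbacks like these are passed to the callbacks argument of Model.fit in the same way as the custom callback you will write below. Here is a minimal, self-contained sketch; the demo_model and demo_dataset names and the /tmp paths are placeholders introduced only for illustration:

import tensorflow as tf

# A tiny model and dataset used only to illustrate passing built-in callbacks.
demo_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
demo_model.compile(optimizer='adam', loss='mse')
demo_dataset = tf.data.Dataset.from_tensor_slices(
    ([[1.], [2.]], [[1.], [2.]])).batch(1)

builtin_callbacks = [
    # Save the model at the end of every epoch.
    tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/ckpt-{epoch:02d}'),
    # Write TensorBoard summaries during training.
    tf.keras.callbacks.TensorBoard(log_dir='/tmp/logs'),
]
demo_model.fit(demo_dataset, epochs=2, callbacks=builtin_callbacks, verbose=0)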

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1

import time
from datetime import datetime
from absl import flags
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs

The following TensorFlow 1 example shows how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

class LoggerHook(tf1.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpsur2vo6l
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpsur2vo6l', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpsur2vo6l/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
Time: 2022-12-14 03:29:55.384194 , Step #: 0 , Examples per second: 2.6260658149330043
INFO:tensorflow:loss = 3.0540311, step = 0
Time: 2022-12-14 03:29:55.416911 , Step #: 10 , Examples per second: 305.5736558356404
Time: 2022-12-14 03:29:55.423947 , Step #: 20 , Examples per second: 1421.0272394633419
Time: 2022-12-14 03:29:55.430885 , Step #: 30 , Examples per second: 1441.3415807560139
Time: 2022-12-14 03:29:55.437421 , Step #: 40 , Examples per second: 1529.8745258243362
Time: 2022-12-14 03:29:55.444351 , Step #: 50 , Examples per second: 1443.1268923754474
Time: 2022-12-14 03:29:55.451639 , Step #: 60 , Examples per second: 1372.0775949491315
Time: 2022-12-14 03:29:55.458439 , Step #: 70 , Examples per second: 1470.5504522824485
Time: 2022-12-14 03:29:55.465428 , Step #: 80 , Examples per second: 1430.9658489986693
Time: 2022-12-14 03:29:55.472126 , Step #: 90 , Examples per second: 1493.0067988466878
INFO:tensorflow:global_step/sec: 1043.82
Time: 2022-12-14 03:29:55.480772 , Step #: 100 , Examples per second: 1156.5707982903625
INFO:tensorflow:loss = 0.009998276, step = 100 (0.096 sec)
Time: 2022-12-14 03:29:55.488702 , Step #: 110 , Examples per second: 1261.0276299570066
Time: 2022-12-14 03:29:55.495626 , Step #: 120 , Examples per second: 1444.1703680749233
Time: 2022-12-14 03:29:55.502328 , Step #: 130 , Examples per second: 1492.2100469617192
Time: 2022-12-14 03:29:55.509197 , Step #: 140 , Examples per second: 1455.748993474941
Time: 2022-12-14 03:29:55.516025 , Step #: 150 , Examples per second: 1464.491620111732
Time: 2022-12-14 03:29:55.522747 , Step #: 160 , Examples per second: 1487.7639046538025
Time: 2022-12-14 03:29:55.529257 , Step #: 170 , Examples per second: 1535.981250228879
Time: 2022-12-14 03:29:55.535874 , Step #: 180 , Examples per second: 1511.406435804115
Time: 2022-12-14 03:29:55.542583 , Step #: 190 , Examples per second: 1490.460182651647
INFO:tensorflow:global_step/sec: 1424.74
Time: 2022-12-14 03:29:55.550799 , Step #: 200 , Examples per second: 1217.045527087021
INFO:tensorflow:loss = 0.0025755945, step = 200 (0.070 sec)
Time: 2022-12-14 03:29:55.558161 , Step #: 210 , Examples per second: 1358.435030444358
Time: 2022-12-14 03:29:55.564965 , Step #: 220 , Examples per second: 1469.622985283812
Time: 2022-12-14 03:29:55.571853 , Step #: 230 , Examples per second: 1451.8688774273944
Time: 2022-12-14 03:29:55.578271 , Step #: 240 , Examples per second: 1558.0624071322436
Time: 2022-12-14 03:29:55.585110 , Step #: 250 , Examples per second: 1462.2961335982986
Time: 2022-12-14 03:29:55.592251 , Step #: 260 , Examples per second: 1400.2483808506377
Time: 2022-12-14 03:29:55.598872 , Step #: 270 , Examples per second: 1510.4811293575337
Time: 2022-12-14 03:29:55.605597 , Step #: 280 , Examples per second: 1486.920022688599
Time: 2022-12-14 03:29:55.612161 , Step #: 290 , Examples per second: 1523.5394115510353
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmpfs/tmp/tmpsur2vo6l/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:Loss for final step: 0.0015507338.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f330df83700>
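
The hooks parameter is also accepted by tf.estimator.Estimator.evaluate, so the same hook can time evaluation. Below is a hedged sketch that reuses the estimator trained above and the eval_features/eval_labels from the setup; the _eval_input_fn helper is introduced here only for illustration:

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

# Estimator.evaluate accepts SessionRunHooks through the same hooks parameter.
estimator.evaluate(_eval_input_fn, hooks=[LoggerHook()])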

TensorFlow 2: Create a custom Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)

In the example below, you will write a custom tf.keras.callbacks.Callback that measures examples per second during training, so that the results are comparable to the metrics in the previous SessionRunHook example.

class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_begin(self, logs=None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def on_train_batch_begin(self, batch, logs=None):
    self._step += 1

  def on_train_batch_end(self, batch, logs=None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

callback = CustomCallback()

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")

# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose=0)
# Display the training metrics recorded in the History object.
result.history
Time: 2022-12-14 03:29:56.540425 , Step #: 0 , Examples per second: 20.293572619500353
Time: 2022-12-14 03:29:56.559273 , Step #: 10 , Examples per second: 530.467951636566
Time: 2022-12-14 03:29:56.576364 , Step #: 20 , Examples per second: 585.0693969786159
Time: 2022-12-14 03:29:56.594158 , Step #: 30 , Examples per second: 561.9528926284198
Time: 2022-12-14 03:29:56.611010 , Step #: 40 , Examples per second: 593.4132227897172
Time: 2022-12-14 03:29:56.628626 , Step #: 50 , Examples per second: 567.6723601223506
Time: 2022-12-14 03:29:56.645569 , Step #: 60 , Examples per second: 590.21501744906
Time: 2022-12-14 03:29:56.662407 , Step #: 70 , Examples per second: 593.8921613049388
Time: 2022-12-14 03:29:56.679832 , Step #: 80 , Examples per second: 573.8704028021016
Time: 2022-12-14 03:29:56.697418 , Step #: 90 , Examples per second: 568.6497918898033
Time: 2022-12-14 03:29:56.714880 , Step #: 100 , Examples per second: 572.6715910487295
Time: 2022-12-14 03:29:56.732110 , Step #: 110 , Examples per second: 580.3578198724247
Time: 2022-12-14 03:29:56.750562 , Step #: 120 , Examples per second: 541.96276052771
Time: 2022-12-14 03:29:56.768724 , Step #: 130 , Examples per second: 550.5997873373854
Time: 2022-12-14 03:29:56.786106 , Step #: 140 , Examples per second: 575.3187753758367
Time: 2022-12-14 03:29:56.803982 , Step #: 150 , Examples per second: 559.3822435016871
Time: 2022-12-14 03:29:56.820872 , Step #: 160 , Examples per second: 592.089668120668
Time: 2022-12-14 03:29:56.837365 , Step #: 170 , Examples per second: 606.2972867488688
Time: 2022-12-14 03:29:56.853743 , Step #: 180 , Examples per second: 610.5868137947098
Time: 2022-12-14 03:29:56.870270 , Step #: 190 , Examples per second: 605.0815084105138
Time: 2022-12-14 03:29:56.886886 , Step #: 200 , Examples per second: 601.8084511084008
Time: 2022-12-14 03:29:56.903426 , Step #: 210 , Examples per second: 604.6192212884347
Time: 2022-12-14 03:29:56.919841 , Step #: 220 , Examples per second: 609.1767849880904
Time: 2022-12-14 03:29:56.936514 , Step #: 230 , Examples per second: 599.7946488581275
Time: 2022-12-14 03:29:56.953708 , Step #: 240 , Examples per second: 581.58099807263
Time: 2022-12-14 03:29:56.970554 , Step #: 250 , Examples per second: 593.6147869284006
Time: 2022-12-14 03:29:56.986865 , Step #: 260 , Examples per second: 613.0768556143479
Time: 2022-12-14 03:29:57.003039 , Step #: 270 , Examples per second: 618.273264641283
Time: 2022-12-14 03:29:57.019090 , Step #: 280 , Examples per second: 623.0398098633393
Time: 2022-12-14 03:29:57.035298 , Step #: 290 , Examples per second: 616.9727280750787
{'loss': [0.07423444092273712]}
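
Because tf.keras.callbacks.Callback also defines on_test_* and on_predict_* methods, the same pattern extends to Model.evaluate and Model.predict. Below is a hedged sketch that reuses the model trained above and the eval_features/eval_labels from the setup; the EvalTimingCallback name is introduced here only for illustration:

class EvalTimingCallback(tf.keras.callbacks.Callback):
  """Sketch: measures examples per second during Model.evaluate."""

  def on_test_begin(self, logs=None):
    self._start_time = time.time()
    self._batches = 0

  def on_test_batch_end(self, batch, logs=None):
    self._batches += 1

  def on_test_end(self, logs=None):
    duration = time.time() - self._start_time
    # With a batch size of 1, the number of batches equals the number of examples.
    print('Evaluation examples per second:', self._batches / duration)

eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)
model.evaluate(eval_dataset, callbacks=[EvalTimingCallback()], verbose=0)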

Next steps

Learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.

You may also find other migration-related guides in the TensorFlow documentation useful.