Migrate SessionRunHook to Keras callbacks


In TensorFlow 1, to customize the behavior of training, you use tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing a SessionRunHook and a Callback that each measure examples per second during training.

Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training, evaluation, and prediction in the built-in Keras Model.fit, Model.evaluate, and Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as in the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
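To make "called at different points" concrete, here is a minimal pure-Python sketch of the dispatch order during training. The hook names (on_train_begin, on_train_batch_begin, on_train_batch_end, on_train_end) match the tf.keras.callbacks.Callback API, but RecordingCallback and run_training are hypothetical stand-ins written for illustration. This is not how Keras implements Model.fit.

```python
class RecordingCallback:
  """Stand-in that uses the same hook names as tf.keras.callbacks.Callback."""

  def __init__(self):
    self.events = []

  def on_train_begin(self, logs=None):
    self.events.append('on_train_begin')

  def on_train_batch_begin(self, batch, logs=None):
    self.events.append(f'on_train_batch_begin({batch})')

  def on_train_batch_end(self, batch, logs=None):
    self.events.append(f'on_train_batch_end({batch})')

  def on_train_end(self, logs=None):
    self.events.append('on_train_end')


def run_training(num_batches, callbacks):
  """Hypothetical driver: calls each hook where a fit-style loop would."""
  for cb in callbacks:
    cb.on_train_begin()
  for batch in range(num_batches):
    for cb in callbacks:
      cb.on_train_batch_begin(batch)
    # A real training loop would run one optimization step here.
    for cb in callbacks:
      cb.on_train_batch_end(batch)
  for cb in callbacks:
    cb.on_train_end()


recorder = RecordingCallback()
run_training(num_batches=2, callbacks=[recorder])
print(recorder.events)
```

The recorded list shows one begin/end pair around each batch, bracketed by the train-level hooks. Those are the same points where the examples in this guide start a timer and print throughput.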

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1

import time
from datetime import datetime
from absl import flags
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs

The following TensorFlow 1 example shows how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

class LoggerHook(tf1.train.SessionRunHook):
  """Logs the time, step number, and examples per second."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpe4lxk_r8
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpe4lxk_r8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpe4lxk_r8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
Time: 2021-10-26 01:34:53.978329 , Step #: 0 , Examples per second: 6.5659573368942015
INFO:tensorflow:loss = 0.272405, step = 0
Time: 2021-10-26 01:34:54.010834 , Step #: 10 , Examples per second: 307.6243353258279
Time: 2021-10-26 01:34:54.020112 , Step #: 20 , Examples per second: 1077.700865900974
Time: 2021-10-26 01:34:54.029483 , Step #: 30 , Examples per second: 1067.1171606665819
Time: 2021-10-26 01:34:54.039412 , Step #: 40 , Examples per second: 1007.1566814743667
Time: 2021-10-26 01:34:54.048087 , Step #: 50 , Examples per second: 1152.756355641061
Time: 2021-10-26 01:34:54.056877 , Step #: 60 , Examples per second: 1137.6234777184084
Time: 2021-10-26 01:34:54.066122 , Step #: 70 , Examples per second: 1081.6752630493088
Time: 2021-10-26 01:34:54.074645 , Step #: 80 , Examples per second: 1173.2647067050827
Time: 2021-10-26 01:34:54.083555 , Step #: 90 , Examples per second: 1122.3118912554853
INFO:tensorflow:global_step/sec: 866.456
Time: 2021-10-26 01:34:54.094488 , Step #: 100 , Examples per second: 914.6685275645499
INFO:tensorflow:loss = 0.00072448375, step = 100 (0.116 sec)
Time: 2021-10-26 01:34:54.104045 , Step #: 110 , Examples per second: 1046.3525009355121
Time: 2021-10-26 01:34:54.112493 , Step #: 120 , Examples per second: 1183.7949817956028
Time: 2021-10-26 01:34:54.120903 , Step #: 130 , Examples per second: 1189.0301913536498
Time: 2021-10-26 01:34:54.129681 , Step #: 140 , Examples per second: 1139.106488145352
Time: 2021-10-26 01:34:54.138138 , Step #: 150 , Examples per second: 1182.5933966786026
Time: 2021-10-26 01:34:54.146595 , Step #: 160 , Examples per second: 1182.4933746828306
Time: 2021-10-26 01:34:54.155248 , Step #: 170 , Examples per second: 1155.551147477753
Time: 2021-10-26 01:34:54.163869 , Step #: 180 , Examples per second: 1159.993362464738
Time: 2021-10-26 01:34:54.172881 , Step #: 190 , Examples per second: 1109.5455266917095
INFO:tensorflow:global_step/sec: 1129.39
Time: 2021-10-26 01:34:54.183226 , Step #: 200 , Examples per second: 966.6745027541543
INFO:tensorflow:loss = 0.004354417, step = 200 (0.088 sec)
Time: 2021-10-26 01:34:54.192698 , Step #: 210 , Examples per second: 1055.8082867643357
Time: 2021-10-26 01:34:54.201008 , Step #: 220 , Examples per second: 1203.288865937975
Time: 2021-10-26 01:34:54.209423 , Step #: 230 , Examples per second: 1188.3900946336487
Time: 2021-10-26 01:34:54.218621 , Step #: 240 , Examples per second: 1087.1987350631173
Time: 2021-10-26 01:34:54.227779 , Step #: 250 , Examples per second: 1091.9538673817397
Time: 2021-10-26 01:34:54.236563 , Step #: 260 , Examples per second: 1138.4571955919873
Time: 2021-10-26 01:34:54.244876 , Step #: 270 , Examples per second: 1202.9437577078613
Time: 2021-10-26 01:34:54.253524 , Step #: 280 , Examples per second: 1156.2838396647737
Time: 2021-10-26 01:34:54.262094 , Step #: 290 , Examples per second: 1166.8671581582973
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmp/tmpe4lxk_r8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:Loss for final step: 0.0026133624.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f9750e2efd0>
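Two details of the hook above are worth noting. It divides log_frequency by the elapsed time, which equals examples per second only because the input pipeline uses batch(1). The "Step #: 0" line also divides by 10 steps even though only one step has run, so the first reading is not comparable to the rest. The sketch below shows one way to account for both. ThroughputMeter, its parameters, and the injected clock are hypothetical names introduced here for illustration; they are not part of TensorFlow.

```python
import time


class ThroughputMeter:
  """Hypothetical helper: counts the examples processed between reports."""

  def __init__(self, batch_size, log_frequency=10, clock=time.time):
    self.batch_size = batch_size
    self.log_frequency = log_frequency
    self._clock = clock  # Injectable so the demo below is deterministic.
    self._steps_since_report = 0
    self._last_report_time = clock()

  def step(self):
    """Records one finished step.

    Returns examples per second when a report is due, otherwise None.
    """
    self._steps_since_report += 1
    if self._steps_since_report < self.log_frequency:
      return None
    now = self._clock()
    duration = now - self._last_report_time
    # Count the steps that actually ran, scaled by the batch size.
    examples = self._steps_since_report * self.batch_size
    self._steps_since_report = 0
    self._last_report_time = now
    return examples / duration if duration > 0 else float('inf')


# Deterministic demo with a fake clock: created at t=0.0, reports at 2.0 and 3.0.
fake_times = iter([0.0, 2.0, 3.0])
meter = ThroughputMeter(batch_size=4, log_frequency=10,
                        clock=lambda: next(fake_times))
rates = [meter.step() for _ in range(20)]
reports = [rate for rate in rates if rate is not None]
print(reports)  # [20.0, 40.0]
```

A hook or callback could call meter.step() from after_run or on_train_batch_end and print the value whenever it is not None.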

TensorFlow 2: Create a custom Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training or evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)

In the example below, you will write a custom tf.keras.callbacks.Callback that logs the time, the step number, and the examples per second. Its output should be comparable to that of the previous SessionRunHook example.

class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_begin(self, logs=None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def on_train_batch_begin(self, batch, logs=None):
    self._step += 1

  def on_train_batch_end(self, batch, logs=None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

callback = CustomCallback()

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")

# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose = 0)
# Provide the results of training metrics.
result.history
Time: 2021-10-26 01:34:54.545193 , Step #: 0 , Examples per second: 47.66297875435231
Time: 2021-10-26 01:34:54.558176 , Step #: 10 , Examples per second: 770.1198979123442
Time: 2021-10-26 01:34:54.570778 , Step #: 20 , Examples per second: 793.5191176192368
Time: 2021-10-26 01:34:54.583033 , Step #: 30 , Examples per second: 815.9807011400335
Time: 2021-10-26 01:34:54.595632 , Step #: 40 , Examples per second: 793.6993093007853
Time: 2021-10-26 01:34:54.607942 , Step #: 50 , Examples per second: 812.3458320421444
Time: 2021-10-26 01:34:54.619847 , Step #: 60 , Examples per second: 840.0368515922291
Time: 2021-10-26 01:34:54.632529 , Step #: 70 , Examples per second: 788.4919351806594
Time: 2021-10-26 01:34:54.646415 , Step #: 80 , Examples per second: 720.1881900444719
Time: 2021-10-26 01:34:54.659728 , Step #: 90 , Examples per second: 751.1154886194731
Time: 2021-10-26 01:34:54.672811 , Step #: 100 , Examples per second: 764.3517877318949
Time: 2021-10-26 01:34:54.685740 , Step #: 110 , Examples per second: 773.5000461041955
Time: 2021-10-26 01:34:54.698443 , Step #: 120 , Examples per second: 787.2192192192192
Time: 2021-10-26 01:34:54.711277 , Step #: 130 , Examples per second: 779.161449722279
Time: 2021-10-26 01:34:54.725101 , Step #: 140 , Examples per second: 723.355408388521
Time: 2021-10-26 01:34:54.738438 , Step #: 150 , Examples per second: 749.7861994994637
Time: 2021-10-26 01:34:54.752388 , Step #: 160 , Examples per second: 716.8280010937927
Time: 2021-10-26 01:34:54.765563 , Step #: 170 , Examples per second: 759.0538755270826
Time: 2021-10-26 01:34:54.779201 , Step #: 180 , Examples per second: 733.295569775167
Time: 2021-10-26 01:34:54.792040 , Step #: 190 , Examples per second: 778.8865366759517
Time: 2021-10-26 01:34:54.804998 , Step #: 200 , Examples per second: 771.664274938367
Time: 2021-10-26 01:34:54.818003 , Step #: 210 , Examples per second: 768.9762393663831
Time: 2021-10-26 01:34:54.831546 , Step #: 220 , Examples per second: 738.3428098649814
Time: 2021-10-26 01:34:54.845028 , Step #: 230 , Examples per second: 741.7245525924878
Time: 2021-10-26 01:34:54.858053 , Step #: 240 , Examples per second: 767.7375896910236
Time: 2021-10-26 01:34:54.871158 , Step #: 250 , Examples per second: 763.0585624101734
Time: 2021-10-26 01:34:54.883612 , Step #: 260 , Examples per second: 802.922010796738
Time: 2021-10-26 01:34:54.896472 , Step #: 270 , Examples per second: 777.6301981941895
Time: 2021-10-26 01:34:54.909765 , Step #: 280 , Examples per second: 752.2740561384629
Time: 2021-10-26 01:34:54.922856 , Step #: 290 , Examples per second: 763.8645759347284
{'loss': [0.33093082904815674]}
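The callbacks parameter is also accepted by Model.evaluate, where the corresponding hooks are on_test_begin, on_test_batch_end, and on_test_end. The following is a minimal sketch, assuming the eval_features and eval_labels data from the Setup section (repeated here so the block is self-contained). EvalBatchCounter is a hypothetical class written for illustration, and the model is freshly initialized rather than trained.

```python
import time

import tensorflow as tf


class EvalBatchCounter(tf.keras.callbacks.Callback):
  """Hypothetical callback: counts evaluation batches and times the run."""

  def on_test_begin(self, logs=None):
    self.batches_seen = 0
    self._start_time = time.time()

  def on_test_batch_end(self, batch, logs=None):
    self.batches_seen += 1

  def on_test_end(self, logs=None):
    self.duration = time.time() - self._start_time


# Same evaluation data as in the Setup section.
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)

# Untrained model; only a loss is needed for evaluation.
eval_model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
eval_model.compile(loss="mse")

eval_callback = EvalBatchCounter()
eval_model.evaluate(eval_dataset, callbacks=[eval_callback], verbose=0)
print('Evaluation batches:', eval_callback.batches_seen)
```

With three examples and a batch size of 1, the callback sees three evaluation batches.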

Next steps

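Learn more about callbacks in the tf.keras.callbacks.Callback API docs, and in the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.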

You may also find the following migration-related resources useful: