ترحيل LoggingTensorHook و StopAtStepHook إلى عمليات معاودة الاتصال Keras

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

في TensorFlow 1 ، يمكنك استخدام tf.estimator.LoggingTensorHook لمراقبة وتسجيل الموترات ، بينما يساعد tf.estimator.StopAtStepHook في إيقاف التدريب عند خطوة محددة عند التدريب باستخدام tf.estimator.Estimator . يوضح هذا الكمبيوتر الدفتري كيفية الترحيل من واجهات برمجة التطبيقات هذه إلى ما يعادلها في TensorFlow 2 باستخدام عمليات رد نداء Keras المخصصة ( tf.keras.callbacks.Callback ) باستخدام Model.fit .

عمليات رد نداء Keras هي كائنات يتم استدعاؤها في نقاط مختلفة أثناء التدريب / التقييم / التنبؤ في واجهات برمجة تطبيقات Keras Model.fit / Model.evaluate / Model.predict Model.predict . يمكنك معرفة المزيد حول عمليات الاسترجاعات في مستندات tf.keras.callbacks.Callback API ، بالإضافة إلى أدلة كتابة عمليات الاسترجاعات الخاصة بك والتدريب والتقييم باستخدام الطرق المضمنة (قسم استخدام عمليات الاسترجاعات ). للترحيل من SessionRunHook في TensorFlow 1 إلى Keras Callbacks في TensorFlow 2 ، تحقق من تدريب الترحيل باستخدام دليل المنطق المساعد .

يثبت

ابدأ بالواردات ومجموعة بيانات بسيطة لأغراض التوضيح:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]

# Define an input function.
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

TensorFlow 1: سجل الموترات وأوقف التدريب باستخدام tf.estimator APIs

في TensorFlow 1 ، يمكنك تحديد العديد من الخطافات للتحكم في سلوك التدريب. ثم تقوم بتمرير هذه الخطافات إلى tf.estimator.EstimatorSpec .

في المثال أدناه:

  • لمراقبة / تسجيل الموترات - على سبيل المثال ، أوزان النموذج أو الخسائر - يمكنك استخدام tf.estimator.LoggingTensorHook ( tf.train.LoggingTensorHook هو اسمها المستعار).
  • لإيقاف التدريب في خطوة معينة ، يمكنك استخدام tf.estimator.StopAtStepHook ( tf.train.StopAtStepHook هو الاسم المستعار الخاص به).
def _model_fn(features, labels, mode):
  dense = tf1.layers.Dense(1)
  logits = dense(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  # Define the stop hook.
  stop_hook = tf1.train.StopAtStepHook(num_steps=2)

  # Access tensors to be logged by names.
  kernel_name = tf.identity(dense.weights[0])
  bias_name = tf.identity(dense.weights[1])
  logging_weight_hook = tf1.train.LoggingTensorHook(
      tensors=[kernel_name, bias_name],
      every_n_iter=1)
  # Log the training loss by the tensor object.
  logging_loss_hook = tf1.train.LoggingTensorHook(
      {'loss from LoggingTensorHook': loss},
      every_n_secs=3)

  # Pass all hooks to `EstimatorSpec`.
  return tf1.estimator.EstimatorSpec(mode,
                                     loss=loss,
                                     train_op=train_op,
                                     training_hooks=[stop_hook,
                                                     logging_weight_hook,
                                                     logging_loss_hook])

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.025395721, step = 0
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143]
 [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.]
INFO:tensorflow:loss from LoggingTensorHook = 0.025395721
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082]
 [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Loss for final step: 0.09248222.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>

TensorFlow 2: سجل الموترات وأوقف التدريب باستخدام عمليات الاسترجاعات المخصصة و Model.fit

في TensorFlow 2 ، عند استخدام Keras Model.fit (أو Model.evaluate ) للتدريب / التقييم ، يمكنك تكوين مراقبة الموتر وإيقاف التدريب عن طريق تحديد Keras tf.keras.callbacks.Callback . ثم تقوم بتمريرها إلى معامل عمليات callbacks Model.fit (أو Model.evaluate ). (تعرف على المزيد في دليل كتابة عمليات الاسترجاعات الخاصة بك .)

في المثال أدناه:

  • لإعادة إنشاء وظائف StopAtStepHook ، حدد رد اتصال مخصص (يسمى StopAtStepCallback أدناه) حيث يمكنك تجاوز طريقة on_batch_end لإيقاف التدريب بعد عدد معين من الخطوات.
  • لإعادة إنشاء سلوك LoggingTensorHook ، حدد رد اتصال مخصص ( LoggingTensorCallback ) حيث تقوم بتسجيل وإخراج الموترات المسجلة يدويًا ، نظرًا لأن الوصول إلى الموترات حسب الأسماء غير مدعوم. يمكنك أيضًا تنفيذ تردد التسجيل داخل رد الاتصال المخصص. المثال أدناه سيطبع الأوزان كل خطوتين. من الممكن أيضًا استخدام استراتيجيات أخرى مثل تسجيل كل N ثانية.
class StopAtStepCallback(tf.keras.callbacks.Callback):
  def __init__(self, stop_step=None):
    super().__init__()
    self._stop_step = stop_step

  def on_batch_end(self, batch, logs=None):
    if self.model.optimizer.iterations >= self._stop_step:
      self.model.stop_training = True
      print('\nstop training now')

class LoggingTensorCallback(tf.keras.callbacks.Callback):
  def __init__(self, every_n_iter):
      super().__init__()
      self._every_n_iter = every_n_iter
      self._log_count = every_n_iter

  def on_batch_end(self, batch, logs=None):
    if self._log_count > 0:
      self._log_count -= 1
      print("Logging Tensor Callback: dense/kernel:",
            model.layers[0].weights[0])
      print("Logging Tensor Callback: dense/bias:",
            model.layers[0].weights[1])
      print("Logging Tensor Callback loss:", logs["loss"])
    else:
      self._log_count -= self._every_n_iter

عند الانتهاء ، قم بتمرير عمليات الاسترجاعات الجديدة - StopAtStepCallback و LoggingTensorCallback - إلى معلمة callbacks الخاصة بـ Model.fit :

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
                              LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-0.27049014],
       [-0.73790836]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)>
Logging Tensor Callback loss: 3.2473244667053223

stop training now
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-0.22285421],
       [-0.6911988 ]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)>
Logging Tensor Callback loss: 5.644947052001953
3/3 [==============================] - 0s 4ms/step - loss: 5.6449
<keras.callbacks.History at 0x7f053022be90>

الخطوات التالية

تعرف على المزيد حول عمليات رد الاتصال في:

قد تجد أيضًا الموارد التالية المتعلقة بالترحيل مفيدة: