عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
يشير التسامح مع الخطأ إلى آلية لحفظ حالات الكائنات القابلة للتتبع بشكل دوري ، مثل المعلمات والنماذج. يمكّنك هذا من استعادتها في حالة فشل البرنامج / الجهاز أثناء التدريب.
يوضح هذا الدليل أولاً كيفية إضافة التسامح مع الخطأ إلى التدريب باستخدام tf.estimator.Estimator
في TensorFlow 1 عن طريق تحديد التوفير المتري باستخدام tf.estimator.RunConfig
. بعد ذلك ، ستتعلم كيفية تنفيذ التسامح مع الخطأ للتدريب في Tensorflow 2 بطريقتين:
- إذا كنت تستخدم واجهة برمجة تطبيقات
Model.fit
، فيمكنك تمرير ردtf.keras.callbacks.BackupAndRestore
إليها. - إذا كنت تستخدم حلقة تدريب مخصصة (مع
tf.GradientTape
) ، فيمكنك حفظ نقاط التحقق بشكل تعسفي باستخدامtf.train.Checkpoint
وtf.train.CheckpointManager
APIs.
كلتا الطريقتين ستعملان على نسخ احتياطي واستعادة حالات التدريب في ملفات نقاط التفتيش .
يثبت
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: احفظ نقاط التفتيش باستخدام tf.estimator.RunConfig
في TensorFlow 1 ، يمكنك تكوين tf.estimator
لحفظ نقاط التحقق في كل خطوة عن طريق تكوين tf.estimator.RunConfig
.
في هذا المثال ، ابدأ بكتابة خطاف يلقي خطأ بشكل مصطنع أثناء نقطة التحقق الخامسة:
class InterruptHook(tf1.train.SessionRunHook):
# A hook for artificially interrupting training.
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
def after_run(self, run_context, run_values):
if self._step == 5:
raise RuntimeError('Interruption')
بعد ذلك ، قم بتكوين tf.estimator.Estimator
لحفظ كل نقطة تفتيش واستخدام مجموعة بيانات MNIST:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.
ابدأ تدريب النموذج. سيظهر استثناء مصطنع بواسطة الخطاف الذي حددته سابقًا.
try:
classifier.train(input_fn=train_input_fn,
hooks=[InterruptHook()],
max_steps=10)
except Exception as e:
print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 118.92192, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
أعد بناء tf.estimator.Estimator
باستخدام آخر نقطة تفتيش محفوظة وتابع التدريب:
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:loss = 105.44863, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 100.47882. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>
TensorFlow 2: النسخ الاحتياطي والاستعادة باستخدام رد الاتصال و Model.fit
في TensorFlow 2 ، إذا كنت تستخدم Keras Model.fit
API للتدريب ، فيمكنك توفير رد الاتصال tf.keras.callbacks.BackupAndRestore
لإضافة وظيفة تحمل الخطأ.
للمساعدة في توضيح ذلك ، دعنا نبدأ أولاً بتحديد فئة رد نداء تسبب بشكل مصطنع في حدوث خطأ أثناء نقطة التحقق الخامسة:
class InterruptingCallback(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def on_epoch_end(self, epoch, log=None):
if epoch == 4:
raise RuntimeError('Interruption')
بعد ذلك ، قم بتعريف نموذج Keras وإنشاء مثيل له ، وحدد وظيفة الخسارة ، واستدعاء Model.compile
، وقم بإعداد رد tf.keras.callbacks.BackupAndRestore
الذي سيحفظ نقاط التفتيش في دليل مؤقت:
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir
)
الآن ، ابدأ تدريب النموذج باستخدام Model.fit
. أثناء التدريب ، سيتم حفظ نقاط التفتيش بفضل backup_restore_callback
المعرّف أعلاه ، بينما يقوم InterruptingCallback
بإثارة استثناء مصطنع لمحاكاة الفشل.
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
Epoch 1/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615 Epoch 2/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718 Epoch 3/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797 Epoch 4/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption
بعد ذلك ، قم بإنشاء مثيل لنموذج Keras ، واستدعاء Model.compile
، واستمر في تدريب النموذج باستخدام Model.fit
من نقطة اختبار محفوظة مسبقًا:
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
Epoch 6/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791 Epoch 7/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827 Epoch 8/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819 Epoch 9/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812 Epoch 10/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813 <keras.callbacks.History at 0x7fcfe0647350>
TensorFlow 2: اكتب نقاط فحص يدوية باستخدام حلقة تدريب مخصصة
إذا كنت تستخدم حلقة تدريب مخصصة في TensorFlow 2 ، فيمكنك تنفيذ آلية تحمل الأخطاء باستخدام واجهات برمجة تطبيقات tf.train.Checkpoint
و tf.train.CheckpointManager
.
يوضح هذا المثال كيفية:
- استخدم كائن
tf.train.Checkpoint
لإنشاء نقطة تحقق يدويًا ، حيث يتم تعيين الكائنات القابلة للتتبع التي تريد حفظها كسمات. - استخدم
tf.train.CheckpointManager
لإدارة نقاط تفتيش متعددة.
ابدأ بتحديد وإنشاء مثيل نموذج Keras والمحسن ووظيفة الخسارة. بعد ذلك ، قم بإنشاء Checkpoint
يدير كائنين بحالتين قابلتين للتتبع (النموذج والمحسِّن) ، بالإضافة إلى CheckpointManager
لتسجيل وحفظ العديد من نقاط التفتيش في دليل مؤقت.
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
الآن ، قم بتنفيذ حلقة تدريب مخصصة حيث يتم تحميل آخر نقطة اختبار بعد المرحلة الأولى في كل مرة تبدأ فيها مرحلة جديدة:
for epoch in range(epochs):
if epoch > 0:
tf.train.load_checkpoint(save_path)
print(f"\nStart of epoch {epoch}")
for step in range(steps_per_epoch):
with tf.GradientTape() as tape:
logits = model(x_train, training=True)
loss_value = loss_fn(y_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
save_path = checkpoint_manager.save()
print(f"Checkpoint saved to {save_path}")
print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1 Training loss at step 0: 2.3636362552642822 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2 Training loss at step 1: 2.3626415729522705 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3 Training loss at step 2: 2.3613197803497314 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4 Training loss at step 3: 2.360600233078003 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5 Training loss at step 4: 2.3589422702789307 Start of epoch 1 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6 Training loss at step 0: 2.3563339710235596 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7 Training loss at step 1: 2.3568854331970215 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8 Training loss at step 2: 2.354109287261963 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9 Training loss at step 3: 2.3532731533050537 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10 Training loss at step 4: 2.351112127304077 Start of epoch 2 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11 Training loss at step 0: 2.348905563354492 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12 Training loss at step 1: 2.349478006362915 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13 Training loss at step 2: 2.3487260341644287 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14 Training loss at step 3: 2.345991611480713 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15 Training loss at step 4: 2.3451104164123535 Start of epoch 3 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16 Training loss at step 0: 2.3441312313079834 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17 Training loss at step 1: 2.341529130935669 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18 Training loss at step 2: 2.342329263687134 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19 Training loss at step 3: 2.340449571609497 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20 Training loss at step 4: 2.3367927074432373 Start of epoch 4 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21 Training loss at step 0: 2.3366076946258545 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22 Training loss at step 1: 2.335028886795044 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23 Training loss at step 2: 2.3338520526885986 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24 Training loss at step 3: 2.3345272541046143 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25 Training loss at step 4: 2.332385301589966
الخطوات التالية
لمعرفة المزيد حول التسامح مع الخطأ ونقاط الفحص في TensorFlow 2 ، ضع في اعتبارك الوثائق التالية:
-
tf.keras.callbacks.BackupAndRestore
مستندات واجهة برمجة تطبيقات رد الاتصال. -
tf.train.Checkpoint
وtf.train.CheckpointManager
API مستندات. - دليل نقاط التفتيش الخاصة بالتدريب ، بما في ذلك قسم نقاط التفتيش في الكتابة .
قد تجد أيضًا المواد التالية المتعلقة بالتدريب الموزع مفيدة:
- قسم تحمل الأخطاء في تدريب متعدد العاملين باستخدام Keras .
- قسم فشل مهمة التسليم في البرنامج التعليمي للتدريب على خادم المعلمة .