ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
ความทนทานต่อความผิดพลาดหมายถึงกลไกในการบันทึกสถานะของออบเจ็กต์ที่สามารถติดตามได้เป็นระยะ เช่น พารามิเตอร์และแบบจำลอง ซึ่งจะทำให้คุณสามารถกู้คืนได้ในกรณีที่โปรแกรม/เครื่องขัดข้องระหว่างการฝึก
คู่มือนี้จะสาธิตวิธีเพิ่มความทนทานต่อข้อผิดพลาดให้กับการฝึกอบรมด้วย tf.estimator.Estimator
ใน TensorFlow 1 โดยการระบุการบันทึกเมตริกด้วย tf.estimator.RunConfig
จากนั้น คุณจะได้เรียนรู้วิธีใช้ความทนทานต่อข้อผิดพลาดสำหรับการฝึกอบรมใน Tensorflow 2 ได้สองวิธี:
- หากคุณใช้ Keras
Model.fit
API คุณสามารถส่งการเรียกกลับtf.keras.callbacks.BackupAndRestore
- หากคุณใช้การวนรอบการฝึกแบบกำหนดเอง (ด้วย
tf.GradientTape
) คุณสามารถบันทึกจุดตรวจได้ตามอำเภอใจโดยใช้tf.train.Checkpoint
และtf.train.CheckpointManager
API
ทั้งสองวิธีนี้จะสำรองและกู้คืนสถานะการฝึกในไฟล์ จุดตรวจ
ติดตั้ง
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: บันทึกจุดตรวจด้วย tf.estimator.RunConfig
ใน TensorFlow 1 คุณสามารถกำหนดค่า tf.estimator
เพื่อบันทึกจุดตรวจทุกขั้นตอนโดยการกำหนดค่า tf.estimator.RunConfig
ในตัวอย่างนี้ ให้เริ่มต้นด้วยการเขียนเบ็ดที่โยนข้อผิดพลาดระหว่างจุดตรวจที่ห้า:
class InterruptHook(tf1.train.SessionRunHook):
# A hook for artificially interrupting training.
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
def after_run(self, run_context, run_values):
if self._step == 5:
raise RuntimeError('Interruption')
ถัดไป กำหนดค่า tf.estimator.Estimator
เพื่อบันทึกทุกจุดตรวจสอบและใช้ชุดข้อมูล MNIST:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.
เริ่มฝึกโมเดล hook ที่คุณกำหนดไว้ก่อนหน้านี้จะยกข้อยกเว้นเทียมขึ้น
try:
classifier.train(input_fn=train_input_fn,
hooks=[InterruptHook()],
max_steps=10)
except Exception as e:
print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 118.92192, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
สร้าง tf.estimator.Estimator
ใหม่โดยใช้จุดตรวจสอบที่บันทึกไว้ล่าสุดและดำเนินการฝึกอบรมต่อไป:
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:loss = 105.44863, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 100.47882. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>
TensorFlow 2: สำรองและกู้คืนด้วยการโทรกลับและ Model.fit
ใน TensorFlow 2 หากคุณใช้ Keras Model.fit
API สำหรับการฝึกอบรม คุณสามารถจัดเตรียมการเรียกกลับ tf.keras.callbacks.BackupAndRestore
เพื่อเพิ่มฟังก์ชันความทนทานต่อข้อบกพร่อง
เพื่อช่วยสาธิตสิ่งนี้ เรามาเริ่มด้วยการกำหนดคลาสการเรียกกลับที่ส่งข้อผิดพลาดเทียมระหว่างจุดตรวจที่ห้า:
class InterruptingCallback(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def on_epoch_end(self, epoch, log=None):
if epoch == 4:
raise RuntimeError('Interruption')
จากนั้น กำหนดและสร้างโมเดล Keras อย่างง่าย กำหนดฟังก์ชันการสูญเสีย เรียก Model.compile
และตั้งค่าการเรียกกลับ tf.keras.callbacks.BackupAndRestore
ที่จะบันทึกจุดตรวจสอบในไดเร็กทอรีชั่วคราว:
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir
)
ตอนนี้ เริ่มฝึกโมเดลด้วย Model.fit
ระหว่างการฝึก จุดตรวจจะถูกบันทึกไว้ด้วย backup_restore_callback
ที่กำหนดไว้ข้างต้น ในขณะที่ InterruptingCallback
จะเพิ่มข้อยกเว้นเทียมเพื่อจำลองความล้มเหลว
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
Epoch 1/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615 Epoch 2/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718 Epoch 3/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797 Epoch 4/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption
ถัดไป สร้างโมเดล Keras เรียก Model.compile
และฝึกโมเดลต่อด้วย Model.fit
จากจุดตรวจสอบที่บันทึกไว้ก่อนหน้านี้:
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
Epoch 6/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791 Epoch 7/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827 Epoch 8/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819 Epoch 9/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812 Epoch 10/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813 <keras.callbacks.History at 0x7fcfe0647350>
TensorFlow 2: เขียนจุดตรวจสอบด้วยตนเองด้วยลูปการฝึกแบบกำหนดเอง
หากคุณใช้การวนรอบการฝึกแบบกำหนดเองใน TensorFlow 2 คุณสามารถใช้กลไกความทนทานต่อข้อผิดพลาดกับ tf.train.Checkpoint
และ tf.train.CheckpointManager
API ได้
ตัวอย่างนี้สาธิตวิธีการ:
- ใช้อ็อบเจ็กต์
tf.train.Checkpoint
เพื่อสร้างจุดตรวจด้วยตนเอง โดยที่ออบเจ็กต์ที่ติดตามได้ที่คุณต้องการบันทึกถูกตั้งค่าเป็นแอตทริบิวต์ - ใช้
tf.train.CheckpointManager
เพื่อจัดการจุดตรวจหลายจุด
เริ่มต้นด้วยการกำหนดและสร้างอินสแตนซ์โมเดล Keras เครื่องมือเพิ่มประสิทธิภาพ และฟังก์ชันการสูญเสีย จากนั้น สร้าง Checkpoint
ที่จัดการสองวัตถุด้วยสถานะที่สามารถติดตามได้ (แบบจำลองและเครื่องมือเพิ่มประสิทธิภาพ) รวมถึง CheckpointManager
สำหรับการบันทึกและเก็บรักษาจุดตรวจสอบหลายจุดในไดเรกทอรีชั่วคราว
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
ตอนนี้ ใช้ลูปการฝึกแบบกำหนดเองซึ่งหลังจากยุคแรกทุกครั้งที่มีการโหลดยุคใหม่จุดตรวจสุดท้ายจะถูกโหลด:
for epoch in range(epochs):
if epoch > 0:
tf.train.load_checkpoint(save_path)
print(f"\nStart of epoch {epoch}")
for step in range(steps_per_epoch):
with tf.GradientTape() as tape:
logits = model(x_train, training=True)
loss_value = loss_fn(y_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
save_path = checkpoint_manager.save()
print(f"Checkpoint saved to {save_path}")
print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1 Training loss at step 0: 2.3636362552642822 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2 Training loss at step 1: 2.3626415729522705 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3 Training loss at step 2: 2.3613197803497314 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4 Training loss at step 3: 2.360600233078003 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5 Training loss at step 4: 2.3589422702789307 Start of epoch 1 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6 Training loss at step 0: 2.3563339710235596 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7 Training loss at step 1: 2.3568854331970215 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8 Training loss at step 2: 2.354109287261963 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9 Training loss at step 3: 2.3532731533050537 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10 Training loss at step 4: 2.351112127304077 Start of epoch 2 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11 Training loss at step 0: 2.348905563354492 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12 Training loss at step 1: 2.349478006362915 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13 Training loss at step 2: 2.3487260341644287 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14 Training loss at step 3: 2.345991611480713 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15 Training loss at step 4: 2.3451104164123535 Start of epoch 3 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16 Training loss at step 0: 2.3441312313079834 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17 Training loss at step 1: 2.341529130935669 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18 Training loss at step 2: 2.342329263687134 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19 Training loss at step 3: 2.340449571609497 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20 Training loss at step 4: 2.3367927074432373 Start of epoch 4 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21 Training loss at step 0: 2.3366076946258545 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22 Training loss at step 1: 2.335028886795044 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23 Training loss at step 2: 2.3338520526885986 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24 Training loss at step 3: 2.3345272541046143 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25 Training loss at step 4: 2.332385301589966
ขั้นตอนถัดไป
หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับความทนทานต่อข้อผิดพลาดและจุดตรวจสอบใน TensorFlow 2 ให้พิจารณาเอกสารต่อไปนี้:
- เอกสาร
tf.keras.callbacks.BackupAndRestore
callback API - เอกสาร
tf.train.Checkpoint
และtf.train.CheckpointManager
API - คู่มือ จุดตรวจการฝึกอบรม รวมถึงส่วน จุดตรวจการเขียน
คุณอาจพบว่าเนื้อหาที่เกี่ยวข้องกับ การฝึกอบรมแบบกระจาย ต่อไปนี้มีประโยชน์:
- ส่วน ความทนทานต่อข้อผิดพลาด ในการ ฝึกอบรมผู้ปฏิบัติงานหลายคนพร้อมบทช่วยสอน Keras
- ส่วน ความล้มเหลวของงานมอบหมาย ในบทช่วยสอน การฝึกอบรมเซิร์ฟเวอร์พารามิเตอร์