Di chuyển cơ chế chịu lỗi

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Khả năng chịu lỗi đề cập đến cơ chế lưu định kỳ trạng thái của các đối tượng có thể theo dõi, chẳng hạn như các tham số và mô hình. Điều này cho phép bạn khôi phục chúng trong trường hợp chương trình / máy bị lỗi trong quá trình đào tạo.

Hướng dẫn này đầu tiên trình bày cách thêm khả năng chịu lỗi vào đào tạo với tf.estimator.Estimator trong TensorFlow 1 bằng cách chỉ định lưu số liệu với tf.estimator.RunConfig . Sau đó, bạn sẽ học cách triển khai khả năng chịu lỗi cho đào tạo trong Tensorflow 2 theo hai cách:

Cả hai phương pháp này sẽ sao lưu và khôi phục trạng thái đào tạo trong các tệp điểm kiểm tra .

Thành lập

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Lưu các điểm kiểm tra với tf.estimator.RunConfig

Trong TensorFlow 1, bạn có thể định cấu hình tf.estimator để lưu các điểm kiểm tra mỗi bước bằng cách định cấu hình tf.estimator.RunConfig .

Trong ví dụ này, hãy bắt đầu bằng cách viết một hook tạo ra một lỗi giả tạo trong lần kiểm tra thứ năm:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Tiếp theo, định cấu hình tf.estimator.Estimator để lưu mọi điểm kiểm tra và sử dụng tập dữ liệu MNIST:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Bắt đầu đào tạo mô hình. Một ngoại lệ nhân tạo sẽ được nâng lên bởi hook mà bạn đã xác định trước đó.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 118.92192, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption

Xây dựng lại tf.estimator.Estimator bằng cách sử dụng điểm kiểm tra đã lưu cuối cùng và tiếp tục đào tạo:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)
classifier.train(input_fn=train_input_fn,
                   max_steps = 10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 105.44863, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 100.47882.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>

TensorFlow 2: Sao lưu và khôi phục bằng callback và Model.fit

Trong TensorFlow 2, nếu bạn sử dụng API Model.fit để đào tạo, bạn có thể cung cấp lệnh gọi lại tf.keras.callbacks.BackupAndRestore để thêm chức năng chịu lỗi.

Để giúp chứng minh điều này, trước tiên hãy bắt đầu bằng cách xác định một lớp gọi lại tạo lỗi giả tạo trong điểm kiểm tra thứ năm:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, log=None):
    if epoch == 4:
      raise RuntimeError('Interruption')

Sau đó, xác định và khởi tạo một mô hình Keras đơn giản, xác định hàm mất mát, gọi Model.compile và thiết lập một lệnh gọi lại tf.keras.callbacks.BackupAndRestore sẽ lưu các điểm kiểm tra trong một thư mục tạm thời:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir = log_dir
)

Bây giờ, hãy bắt đầu đào tạo người mẫu với Model.fit . Trong quá trình huấn luyện, các điểm kiểm tra sẽ được lưu nhờ backup_restore_callback được định nghĩa ở trên, trong khi InterruptingCallback sẽ đưa ra một ngoại lệ nhân tạo để mô phỏng một lỗi.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615
Epoch 2/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718
Epoch 3/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797
Epoch 4/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption

Tiếp theo, khởi tạo mô hình Keras, gọi Model.compile và tiếp tục đào tạo mô hình với Model.fit từ một điểm kiểm tra đã lưu trước đó:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791
Epoch 7/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827
Epoch 8/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819
Epoch 9/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812
Epoch 10/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813
<keras.callbacks.History at 0x7fcfe0647350>

TensorFlow 2: Viết các điểm kiểm tra thủ công với vòng lặp đào tạo tùy chỉnh

Nếu bạn sử dụng vòng lặp đào tạo tùy chỉnh trong TensorFlow 2, bạn có thể triển khai cơ chế chịu lỗi với các API tf.train.Checkpointtf.train.CheckpointManager .

Ví dụ này minh họa cách:

  • Sử dụng đối tượng tf.train.Checkpoint để tạo thủ công một điểm kiểm tra, trong đó các đối tượng có thể theo dõi bạn muốn lưu được đặt làm thuộc tính.
  • Sử dụng tf.train.CheckpointManager để quản lý nhiều điểm kiểm tra.

Bắt đầu bằng cách xác định và khởi tạo mô hình Keras, trình tối ưu hóa và hàm mất mát. Sau đó, tạo một Checkpoint quản lý hai đối tượng có trạng thái có thể theo dõi (mô hình và trình tối ưu hóa), cũng như một CheckpointManager để ghi nhật ký và giữ một số điểm kiểm tra trong một thư mục tạm thời.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)

Bây giờ, hãy triển khai một vòng lặp đào tạo tùy chỉnh trong đó sau kỷ nguyên đầu tiên mỗi khi kỷ nguyên mới bắt đầu, điểm kiểm tra cuối cùng được tải:

for epoch in range(epochs):
  if epoch > 0:
      tf.train.load_checkpoint(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:

      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

      grads = tape.gradient(loss_value, model.trainable_weights)
      optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1
Training loss at step 0: 2.3636362552642822
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2
Training loss at step 1: 2.3626415729522705
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3
Training loss at step 2: 2.3613197803497314
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4
Training loss at step 3: 2.360600233078003
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5
Training loss at step 4: 2.3589422702789307

Start of epoch 1
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6
Training loss at step 0: 2.3563339710235596
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7
Training loss at step 1: 2.3568854331970215
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8
Training loss at step 2: 2.354109287261963
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9
Training loss at step 3: 2.3532731533050537
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10
Training loss at step 4: 2.351112127304077

Start of epoch 2
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11
Training loss at step 0: 2.348905563354492
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12
Training loss at step 1: 2.349478006362915
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13
Training loss at step 2: 2.3487260341644287
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14
Training loss at step 3: 2.345991611480713
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15
Training loss at step 4: 2.3451104164123535

Start of epoch 3
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16
Training loss at step 0: 2.3441312313079834
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17
Training loss at step 1: 2.341529130935669
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18
Training loss at step 2: 2.342329263687134
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19
Training loss at step 3: 2.340449571609497
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20
Training loss at step 4: 2.3367927074432373

Start of epoch 4
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21
Training loss at step 0: 2.3366076946258545
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22
Training loss at step 1: 2.335028886795044
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23
Training loss at step 2: 2.3338520526885986
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24
Training loss at step 3: 2.3345272541046143
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25
Training loss at step 4: 2.332385301589966

Bước tiếp theo

Để tìm hiểu thêm về khả năng chịu lỗi và kiểm tra trong TensorFlow 2, hãy xem tài liệu sau:

Bạn cũng có thể thấy hữu ích tài liệu sau liên quan đến đào tạo phân tán :