Migrate the fault tolerance mechanism


Fault tolerance refers to a mechanism for periodically saving the state of trackable objects, such as parameters and models. This lets you recover them if the program or machine fails during training.
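As a framework-free illustration of this idea, the sketch below writes a small training state to disk after every step and reloads it after a simulated failure. The `train` helper, the `fail_at` parameter, and the `state.json` file name are hypothetical and exist only for this illustration; TensorFlow's own checkpoint format and APIs are different.

```python
import json
import os
import tempfile


def train(state_path, total_steps, fail_at=None):
    """Runs a toy loop, saving state after every step (hypothetical helper)."""
    # Resume from the saved state if one exists; otherwise start fresh.
    if os.path.exists(state_path):
        with open(state_path) as f:
            state = json.load(f)
    else:
        state = {"step": 0, "weight": 0.0}

    while state["step"] < total_steps:
        if fail_at is not None and state["step"] == fail_at:
            raise RuntimeError("Interruption")
        state["weight"] += 0.5  # Stand-in for a real parameter update.
        state["step"] += 1
        with open(state_path, "w") as f:
            json.dump(state, f)
    return state


state_path = os.path.join(tempfile.mkdtemp(), "state.json")
try:
    # The first run fails partway through, after six steps have been saved.
    train(state_path, total_steps=10, fail_at=6)
except RuntimeError as e:
    print(f"{type(e).__name__}: {e}")

# The second run picks up at the saved step instead of starting over.
final_state = train(state_path, total_steps=10)
print(final_state)
```

The second call does not repeat the first six steps, because it starts from the state saved before the failure.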

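This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by configuring checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:

  • If you use the Keras Model.fit API, pass it the tf.keras.callbacks.BackupAndRestore callback.
  • If you use a custom training loop, save checkpoints manually with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.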

Both methods back up and restore the training state in checkpoint files.

Setup

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can make a tf.estimator.Estimator save a checkpoint at every step by configuring tf.estimator.RunConfig.

In this example, start by writing a hook that artificially raises an error once its internal step counter reaches 5:

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')
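To see exactly when this hook fires, the following plain-Python stand-in mirrors its counting logic without a TensorFlow session. The `CountingHook` class and the `run_with_hook` driver are hypothetical simplifications: a real `SessionRunHook` receives `run_context` and `run_values` arguments and is driven by the monitored session, not by a loop like this one.

```python
class CountingHook:
    """Plain-Python stand-in that mirrors InterruptHook's counting logic."""

    def begin(self):
        self._step = -1

    def before_run(self):
        self._step += 1

    def after_run(self):
        if self._step == 5:
            raise RuntimeError("Interruption")


def run_with_hook(hook, max_steps):
    """Mimics the call order: begin once, then before_run/after_run per step."""
    hook.begin()
    completed = 0
    try:
        for _ in range(max_steps):
            hook.before_run()
            completed += 1  # The training step itself would run here.
            hook.after_run()
    except RuntimeError:
        pass
    return completed


completed_steps = run_with_hook(CountingHook(), max_steps=10)
print(completed_steps)
```

Because the counter starts at -1 and is incremented before each step, the error is raised after the sixth training step has run, not the fifth.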

Next, configure the tf.estimator.Estimator to save a checkpoint at every step, and set up the MNIST dataset as its input:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

Start training the model. The hook you defined earlier will raise an artificial exception.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 118.92192, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 105.44863, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 100.47882.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.BackupAndRestore callback to add fault tolerance.

To demonstrate this, first define a callback class that artificially raises an error at the end of the fifth epoch:

class InterruptingCallback(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def on_epoch_end(self, epoch, logs=None):
    if epoch == 4:
      raise RuntimeError('Interruption')

Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that saves its checkpoints in a temporary directory:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir
)

Now, start training the model with Model.fit. During training, checkpoints are saved by the backup_restore_callback defined above, while the InterruptingCallback raises an artificial exception to simulate a failure.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615
Epoch 2/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718
Epoch 3/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797
Epoch 4/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption

Next, instantiate the Keras model again, call Model.compile, and continue training with Model.fit from the previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 6/10
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791
Epoch 7/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827
Epoch 8/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819
Epoch 9/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812
Epoch 10/10
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813
<keras.callbacks.History at 0x7fcfe0647350>

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to:

  • Use a tf.train.Checkpoint object to create a checkpoint manually, with the trackable objects you want to save set as its attributes.
  • Use a tf.train.CheckpointManager to manage multiple checkpoints.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable state (the model and the optimizer), as well as a CheckpointManager that tracks and retains several checkpoints in a temporary directory.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, log_dir, max_to_keep=2)
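The effect of max_to_keep can be checked in isolation. The sketch below is a minimal illustration that uses a single tf.Variable as a stand-in for the model and optimizer state; the names `demo_step`, `demo_checkpoint`, and `demo_manager` are chosen here to avoid clashing with the objects defined above.

```python
import tempfile

import tensorflow as tf

# A single variable stands in for a model's state (assumption for brevity).
demo_step = tf.Variable(0, dtype=tf.int64)
demo_checkpoint = tf.train.Checkpoint(step=demo_step)
demo_manager = tf.train.CheckpointManager(
    demo_checkpoint, tempfile.mkdtemp(), max_to_keep=2)

for _ in range(5):
    demo_step.assign_add(1)
    demo_manager.save()

# Older checkpoints are deleted as new ones are written, so only the two
# most recent paths remain, ordered from oldest to newest.
retained = list(demo_manager.checkpoints)
print(retained)
print(demo_manager.latest_checkpoint)
```

With max_to_keep=2, five saves leave only the last two checkpoints on disk, which keeps disk usage bounded when you save at every step.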

Now, implement a custom training loop in which, after the first epoch, the last checkpoint is loaded every time a new epoch starts:

for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer state from the most recent checkpoint.
    checkpoint.restore(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    # Compute and apply the gradients outside the tape context.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1
Training loss at step 0: 2.3636362552642822
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2
Training loss at step 1: 2.3626415729522705
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3
Training loss at step 2: 2.3613197803497314
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4
Training loss at step 3: 2.360600233078003
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5
Training loss at step 4: 2.3589422702789307

Start of epoch 1
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6
Training loss at step 0: 2.3563339710235596
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7
Training loss at step 1: 2.3568854331970215
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8
Training loss at step 2: 2.354109287261963
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9
Training loss at step 3: 2.3532731533050537
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10
Training loss at step 4: 2.351112127304077

Start of epoch 2
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11
Training loss at step 0: 2.348905563354492
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12
Training loss at step 1: 2.349478006362915
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13
Training loss at step 2: 2.3487260341644287
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14
Training loss at step 3: 2.345991611480713
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15
Training loss at step 4: 2.3451104164123535

Start of epoch 3
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16
Training loss at step 0: 2.3441312313079834
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17
Training loss at step 1: 2.341529130935669
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18
Training loss at step 2: 2.342329263687134
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19
Training loss at step 3: 2.340449571609497
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20
Training loss at step 4: 2.3367927074432373

Start of epoch 4
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21
Training loss at step 0: 2.3366076946258545
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22
Training loss at step 1: 2.335028886795044
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23
Training loss at step 2: 2.3338520526885986
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24
Training loss at step 3: 2.3345272541046143
Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25
Training loss at step 4: 2.332385301589966
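The loop above reloads checkpoints within a single run. To resume after an actual process restart, one option is CheckpointManager.restore_or_initialize, which restores the latest checkpoint in the directory if one exists and otherwise leaves the freshly created objects as they are. The sketch below simulates two "processes" in one script, using a tf.Variable counter as a stand-in for real training state; the `make_state` helper and its names are hypothetical.

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()


def make_state():
    """Builds fresh state plus its checkpoint objects (hypothetical helper)."""
    counter = tf.Variable(0, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(counter=counter)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=2)
    return counter, manager


# First "process": no checkpoint exists yet, so nothing is restored.
counter, manager = make_state()
first_restore = manager.restore_or_initialize()
for _ in range(3):
    counter.assign_add(1)
    manager.save()

# Second "process": freshly created state picks up the latest checkpoint.
counter, manager = make_state()
second_restore = manager.restore_or_initialize()
print(first_restore, second_restore, int(counter.numpy()))
```

In a real training script, the call to restore_or_initialize would sit right after the model, optimizer, and manager are created, so the same code path handles both a fresh start and a resumed run.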

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, consult the API documentation for tf.keras.callbacks.BackupAndRestore, tf.train.Checkpoint, and tf.train.CheckpointManager.

You may also find the following material related to distributed training useful: