TensorFlow.org'da görüntüleyin | Google Colab'da çalıştırın | Kaynağı GitHub'da görüntüleyin | Not defterini indir |
Hata toleransı, parametreler ve modeller gibi izlenebilir nesnelerin durumlarını periyodik olarak kaydetme mekanizmasını ifade eder. Bu, eğitim sırasında bir program/makine arızası durumunda onları kurtarmanızı sağlar.
Bu kılavuz ilk olarak, tf.estimator.RunConfig ile metrik kaydetmeyi belirterek TensorFlow 1'de tf.estimator.Estimator
ile eğitime hata toleransının nasıl ekleneceğini tf.estimator.RunConfig
. Ardından, Tensorflow 2'de eğitim için hata toleransını iki şekilde nasıl uygulayacağınızı öğreneceksiniz:
-
Model.fit
API'sini kullanıyorsanız,tf.keras.callbacks.BackupAndRestore
geri aramasını buna iletebilirsiniz. - Özel bir eğitim döngüsü kullanıyorsanız (
tf.GradientTape
ile),tf.train.Checkpoint
vetf.train.CheckpointManager
API'lerini kullanarak kontrol noktalarını keyfi olarak kaydedebilirsiniz.
Bu yöntemlerin her ikisi de denetim noktası dosyalarındaki eğitim durumlarını yedekleyecek ve geri yükleyecektir.
Kurmak
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
tutucu1 l10n-yermnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Kontrol noktalarını tf.estimator.RunConfig ile kaydedin
TensorFlow 1'de, tf.estimator.RunConfig
öğesini yapılandırarak her adımda kontrol noktalarını kaydetmek için bir tf.estimator
yapılandırabilirsiniz.
Bu örnekte, beşinci kontrol noktası sırasında yapay olarak hata veren bir kanca yazarak başlayın:
class InterruptHook(tf1.train.SessionRunHook):
# A hook for artificially interrupting training.
def begin(self):
self._step = -1
def before_run(self, run_context):
self._step += 1
def after_run(self, run_context, run_values):
if self._step == 5:
raise RuntimeError('Interruption')
Ardından, her kontrol noktasını kaydetmek ve MNIST veri kümesini kullanmak için tf.estimator.Estimator
yapılandırın:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
tutucu4 l10n-yerINFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20837/314197976.py:17: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.
Modeli eğitmeye başlayın. Daha önce tanımladığınız kanca tarafından yapay bir istisna oluşturulacaktır.
try:
classifier.train(input_fn=train_input_fn,
hooks=[InterruptHook()],
max_steps=10)
except Exception as e:
print(f'{type(e).__name__}:{e}')
tutucu6 l10n-yerWARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:loss = 118.92192, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmpv15yxr9g/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... RuntimeError:Interruption
Son kaydedilen kontrol noktasını kullanarak tf.estimator.Estimator
yeniden oluşturun ve eğitime devam edin:
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
classifier.train(input_fn=train_input_fn,
max_steps = 10)
tutucu8 l10n-yerINFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpv15yxr9g', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpv15yxr9g/model.ckpt-6 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1161: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file utilities to get mtimes. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:loss = 105.44863, step = 6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpv15yxr9g/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 100.47882. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fcfe8165150>
TensorFlow 2: Geri arama ve Model.fit ile yedekleyin ve geri yükleyin
TensorFlow 2'de, eğitim için Model.fit
API'sini kullanıyorsanız, hata toleransı işlevini eklemek için tf.keras.callbacks.BackupAndRestore
geri aramasını sağlayabilirsiniz.
Bunu göstermeye yardımcı olmak için, ilk olarak beşinci kontrol noktasında yapay olarak hata veren bir geri çağırma sınıfı tanımlayarak başlayalım:
class InterruptingCallback(tf.keras.callbacks.Callback):
# A callback for artificially interrupting training.
def on_epoch_end(self, epoch, log=None):
if epoch == 4:
raise RuntimeError('Interruption')
Ardından, basit bir Keras modeli tanımlayın ve somutlaştırın, kayıp işlevini tanımlayın, Model.compile
çağırın ve kontrol noktalarını geçici bir dizine kaydedecek bir tf.keras.callbacks.BackupAndRestore
geri çağrısını ayarlayın:
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
backup_dir = log_dir
)
Şimdi modeli Model.fit
ile eğitmeye başlayın. Eğitim sırasında, yukarıda tanımlanan backup_restore_callback
sayesinde kontrol noktaları kaydedilirken, InterruptingCallback
bir arızayı simüle etmek için yapay bir istisna oluşturur.
try:
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback, InterruptingCallback()])
except Exception as e:
print(f'{type(e).__name__}:{e}')
tutucu12 l10n-yerEpoch 1/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.2186 - accuracy: 0.9352 - val_loss: 0.1267 - val_accuracy: 0.9615 Epoch 2/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0967 - accuracy: 0.9700 - val_loss: 0.0910 - val_accuracy: 0.9718 Epoch 3/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0687 - accuracy: 0.9784 - val_loss: 0.0679 - val_accuracy: 0.9797 Epoch 4/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0527 - accuracy: 0.9829 - val_loss: 0.0623 - val_accuracy: 0.9814 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0434 - accuracy: 0.9857RuntimeError:Interruption
Ardından, Keras modelini somutlaştırın, Model.compile
çağırın ve modeli önceden kaydedilmiş bir kontrol noktasından Model.fit
ile eğitmeye devam edin:
model = create_model()
model.compile(optimizer='adam',
loss=loss,
metrics=['accuracy'],
steps_per_execution=10)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[backup_restore_callback])
tutucu14 l10n-yerEpoch 6/10 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0370 - accuracy: 0.9879 - val_loss: 0.0732 - val_accuracy: 0.9791 Epoch 7/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0306 - accuracy: 0.9898 - val_loss: 0.0601 - val_accuracy: 0.9827 Epoch 8/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0655 - val_accuracy: 0.9819 Epoch 9/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0244 - accuracy: 0.9918 - val_loss: 0.0746 - val_accuracy: 0.9812 Epoch 10/10 1875/1875 [==============================] - 2s 1ms/step - loss: 0.0221 - accuracy: 0.9923 - val_loss: 0.0818 - val_accuracy: 0.9813 <keras.callbacks.History at 0x7fcfe0647350>
TensorFlow 2: Özel bir eğitim döngüsüyle manuel kontrol noktaları yazın
TensorFlow 2'de özel bir eğitim döngüsü kullanıyorsanız, tf.train.Checkpoint
ve tf.train.CheckpointManager
API'leriyle bir hataya dayanıklılık mekanizması uygulayabilirsiniz.
Bu örnek, aşağıdakilerin nasıl yapılacağını gösterir:
- Kaydetmek istediğiniz izlenebilir nesnelerin nitelik olarak ayarlandığı bir kontrol noktası oluşturmak için bir
tf.train.Checkpoint
nesnesi kullanın. - Birden çok kontrol noktasını yönetmek için bir
tf.train.CheckpointManager
kullanın.
Keras modelini, optimize ediciyi ve kayıp işlevini tanımlayarak ve somutlaştırarak başlayın. Ardından, izlenebilir durumlara sahip iki nesneyi (model ve optimize edici) yöneten bir Checkpoint
ve ayrıca birkaç kontrol noktasını geçici bir dizinde günlüğe kaydetmek ve tutmak için bir CheckpointManager
oluşturun.
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, log_dir, max_to_keep=2)
Şimdi, ilk çağdan sonra her yeni çağ başladığında son kontrol noktasının yüklendiği özel bir eğitim döngüsü uygulayın:
for epoch in range(epochs):
if epoch > 0:
tf.train.load_checkpoint(save_path)
print(f"\nStart of epoch {epoch}")
for step in range(steps_per_epoch):
with tf.GradientTape() as tape:
logits = model(x_train, training=True)
loss_value = loss_fn(y_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
save_path = checkpoint_manager.save()
print(f"Checkpoint saved to {save_path}")
print(f"Training loss at step {step}: {loss_value}")
tutucu17 l10n-yerStart of epoch 0 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-1 Training loss at step 0: 2.3636362552642822 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-2 Training loss at step 1: 2.3626415729522705 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-3 Training loss at step 2: 2.3613197803497314 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-4 Training loss at step 3: 2.360600233078003 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-5 Training loss at step 4: 2.3589422702789307 Start of epoch 1 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-6 Training loss at step 0: 2.3563339710235596 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-7 Training loss at step 1: 2.3568854331970215 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-8 Training loss at step 2: 2.354109287261963 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-9 Training loss at step 3: 2.3532731533050537 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-10 Training loss at step 4: 2.351112127304077 Start of epoch 2 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-11 Training loss at step 0: 2.348905563354492 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-12 Training loss at step 1: 2.349478006362915 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-13 Training loss at step 2: 2.3487260341644287 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-14 Training loss at step 3: 2.345991611480713 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-15 Training loss at step 4: 2.3451104164123535 Start of epoch 3 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-16 Training loss at step 0: 2.3441312313079834 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-17 Training loss at step 1: 2.341529130935669 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-18 Training loss at step 2: 2.342329263687134 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-19 Training loss at step 3: 2.340449571609497 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-20 Training loss at step 4: 2.3367927074432373 Start of epoch 4 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-21 Training loss at step 0: 2.3366076946258545 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-22 Training loss at step 1: 2.335028886795044 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-23 Training loss at step 2: 2.3338520526885986 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-24 Training loss at step 3: 2.3345272541046143 Checkpoint saved to /tmp/tmpnr4ss2g8/ckpt-25 Training loss at step 4: 2.332385301589966
Sonraki adımlar
TensorFlow 2'de hata toleransı ve kontrol noktası hakkında daha fazla bilgi edinmek için aşağıdaki belgeleri inceleyin:
-
tf.keras.callbacks.BackupAndRestore
geri arama API'si belgeleri. -
tf.train.Checkpoint
vetf.train.CheckpointManager
API belgeleri. - Kontrol noktaları yazma bölümü de dahil olmak üzere Eğitim kontrol noktaları kılavuzu.
Dağıtılmış eğitimle ilgili aşağıdaki materyalleri de faydalı bulabilirsiniz:
- Keras öğreticisiyle Çoklu çalışan eğitimindeki Hata toleransı bölümü.
- Parametre sunucusu eğitim öğreticisindeki Teslim görevi hatası bölümü.