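Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This lets you recover them if a program or machine failure occurs during training.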
This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by configuring checkpoint saving with tf.estimator.RunConfig. Then, you will learn two ways to implement fault tolerance for training in TensorFlow 2:
- If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.BackupAndRestore callback to it.
- If you use a custom training loop (with tf.GradientTape), you can save checkpoints at arbitrary points using the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
Both of these methods back up and restore the training state in checkpoint files.
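The idea shared by both methods can be shown without TensorFlow. The following is a minimal plain-Python sketch, not a TensorFlow API. A JSON file stands in for a checkpoint, and the helpers save_state, load_state and train are hypothetical names used only for illustration. The loop backs up its state after every step. After a simulated failure, a rerun therefore continues from the last saved step instead of starting over.

```python
import json
import os
import tempfile


def save_state(path, state):
    """Write `state` to a temporary file, then replace the previous backup."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    # On POSIX systems the rename is atomic, so a crash never leaves a half-written backup.
    os.replace(tmp_path, path)


def load_state(path, default):
    """Return the last saved state, or `default` if no backup exists yet."""
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)


def train(path, total_steps, fail_at=None):
    """Toy training loop (hypothetical): each step adds its index to a running sum."""
    state = load_state(path, {"step": 0, "value": 0})
    while state["step"] < total_steps:
        if state["step"] == fail_at:
            raise RuntimeError("Interruption")  # simulated program/machine failure
        state["value"] += state["step"]
        state["step"] += 1
        save_state(path, state)  # back up after every step
    return state


backup_path = os.path.join(tempfile.mkdtemp(), "state.json")
try:
    train(backup_path, total_steps=10, fail_at=5)
except RuntimeError as e:
    print(f"{type(e).__name__}:{e}")  # RuntimeError:Interruption

final_state = train(backup_path, total_steps=10)  # resumes at step 5
print(final_state)  # {'step': 10, 'value': 45}
```

The interrupted-then-resumed run ends in the same state as an uninterrupted run, which is the property the TensorFlow mechanisms in this guide aim to provide.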
Setup
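Install tf-nightly. The save_freq argument of tf.keras.callbacks.BackupAndRestore, which saves checkpoints at a given step frequency, was introduced in TensorFlow 2.10, and any TensorFlow release from 2.10 onward also provides it.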
pip install tf-nightly
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
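The save_freq example later in this guide requires TensorFlow 2.10 or later. If you are not sure which build is installed, a small check like the one below may help. supports_backup_save_freq is a hypothetical helper. It compares only the leading major.minor numbers of a version string such as tf.__version__, so suffixes like the "-dev" date tag of nightly builds are ignored.

```python
import re


def supports_backup_save_freq(version):
    """Return True if a TensorFlow version string is 2.10 or newer.

    Only the leading "major.minor" part is compared, so strings with a
    nightly/dev suffix (for example "2.12.0-dev20221214") are accepted too.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 10)


# Example usage (requires TensorFlow):
#   import tensorflow as tf
#   print(supports_backup_save_freq(tf.__version__))
print(supports_backup_save_freq("2.9.3"))   # False
print(supports_backup_save_freq("2.10.0"))  # True
```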
TensorFlow 1: Save checkpoints with tf.estimator.RunConfig
In TensorFlow 1, you can make a tf.estimator save a checkpoint at every step by passing it a tf.estimator.RunConfig with save_checkpoints_steps=1.
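In this example, start by writing a hook that artificially raises an error after the sixth training step. The hook's counter _step is 0 during the first step, and the error fires once it reaches 5.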
class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')
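The hook sets _step to -1 in begin and increments it before each run. The condition _step == 5 therefore fires in after_run of the sixth training step. The plain-Python model below makes that count explicit. It calls begin once, then before_run and after_run around each step. InterruptHookModel and completed_runs_before_interrupt are hypothetical names. The run_context and run_values arguments of the real SessionRunHook methods are omitted because the counter does not use them.

```python
class InterruptHookModel:
    """Plain-Python stand-in for InterruptHook (no TensorFlow needed)."""

    def begin(self):
        self._step = -1

    def before_run(self):
        self._step += 1

    def after_run(self):
        if self._step == 5:
            raise RuntimeError("Interruption")


def completed_runs_before_interrupt(hook, max_steps=10):
    """Mimic the hook call order of a training session and count finished steps."""
    hook.begin()
    completed = 0
    try:
        for _ in range(max_steps):
            hook.before_run()
            completed += 1  # this step's session.run has finished
            hook.after_run()
    except RuntimeError:
        pass  # the simulated interruption stops the loop
    return completed


print(completed_runs_before_interrupt(InterruptHookModel()))  # 6
```

Six completed steps is consistent with the training log that follows, where the last checkpoint written before the interruption is for step 6.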
Next, configure tf.estimator.Estimator to save a checkpoint at every step and to use the MNIST dataset.
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpi091tdq3', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Start training the model. The hook you defined earlier raises an artificial exception partway through.
try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:loss = 116.79694, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
RuntimeError:Interruption
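Rebuild the tf.estimator.Estimator with the same model_dir and continue training. The Estimator automatically restores its state from the last checkpoint saved in that directory.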
classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

classifier.train(input_fn=train_input_fn,
                 max_steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpi091tdq3/model.ckpt-6
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:loss = 100.25842, step = 6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmpi091tdq3/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 96.2075.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f225bc54d30>
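The restarted Estimator picked up model.ckpt-6, the newest checkpoint in model_dir. Because max_steps is a total target rather than a number of additional steps, it trained only until global step 10. That bookkeeping can be sketched in plain Python. latest_step is a hypothetical helper, and the file listing below is illustrative rather than captured from a real run. In TensorFlow code, tf.train.latest_checkpoint(path) returns the path of the newest checkpoint in a directory.

```python
import re


def latest_step(filenames, prefix="model.ckpt"):
    """Return the largest global step N among files named '<prefix>-N.index', or None."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.index$")
    steps = []
    for name in filenames:
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None


# Illustrative (hypothetical) contents of model_dir after the interruption:
# with the default keep_checkpoint_max of 5, checkpoints 2 to 6 remain.
files = ["checkpoint", "graph.pbtxt"] + [
    f"model.ckpt-{n}.{ext}"
    for n in range(2, 7)
    for ext in ("index", "meta", "data-00000-of-00001")
]

restored_step = latest_step(files)
max_steps = 10  # total number of steps, not additional steps
print(restored_step, max_steps - restored_step)  # 6 4
```

Passing steps instead of max_steps to train would make each call run that many additional steps, so a restarted job would overshoot the intended total.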
TensorFlow 2: Back up and restore with a callback and Model.fit
In TensorFlow 2, if you use the Keras Model.fit API for training, you can add fault tolerance by passing it the tf.keras.callbacks.BackupAndRestore callback.
To demonstrate this, first define a Keras Callback class that artificially raises an error at the end of the fourth epoch (epoch index 3, since Keras counts epochs from 0).
class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')
Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that saves checkpoints to a temporary directory at epoch boundaries.
def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir)
Start training the model with Model.fit. The tf.keras.callbacks.BackupAndRestore callback instantiated above saves checkpoints during training. Meanwhile, the InterruptAtEpoch class raises an artificial exception to simulate a failure after the fourth epoch.
try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 2s 11ms/step - loss: 0.4660 - accuracy: 0.8693 - val_loss: 0.2196 - val_accuracy: 0.9391
Epoch 2/10
100/100 [==============================] - 1s 8ms/step - loss: 0.2022 - accuracy: 0.9430 - val_loss: 0.1582 - val_accuracy: 0.9549
Epoch 3/10
100/100 [==============================] - 1s 8ms/step - loss: 0.1475 - accuracy: 0.9580 - val_loss: 0.1253 - val_accuracy: 0.9629
Epoch 4/10
 90/100 [==========================>...] - ETA: 0s - loss: 0.1174 - accuracy: 0.9661
RuntimeError:Interruption
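Next, instantiate the Keras model again, call Model.compile, and continue training with Model.fit. Because the same BackupAndRestore callback is passed, training resumes from the previously saved checkpoint.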
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 5/10
100/100 [==============================] - 2s 19ms/step - loss: 0.0947 - accuracy: 0.9733 - val_loss: 0.0896 - val_accuracy: 0.9731
Epoch 6/10
100/100 [==============================] - 1s 5ms/step - loss: 0.0812 - accuracy: 0.9769 - val_loss: 0.0824 - val_accuracy: 0.9761
Epoch 7/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0671 - accuracy: 0.9814 - val_loss: 0.0781 - val_accuracy: 0.9769
Epoch 8/10
100/100 [==============================] - 1s 5ms/step - loss: 0.0595 - accuracy: 0.9829 - val_loss: 0.0709 - val_accuracy: 0.9795
Epoch 9/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0516 - accuracy: 0.9850 - val_loss: 0.0734 - val_accuracy: 0.9779
Epoch 10/10
100/100 [==============================] - 1s 5ms/step - loss: 0.0469 - accuracy: 0.9866 - val_loss: 0.0683 - val_accuracy: 0.9792
<keras.callbacks.History at 0x7f217c299e20>
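Training resumed at "Epoch 5/10" because epoch index 3 (the fourth epoch) had already been backed up when InterruptAtEpoch raised. This relies on backup_restore_callback being listed before InterruptAtEpoch(). Keras calls the callbacks passed to fit in list order, so the backup for that epoch is written before the exception. With the opposite order, we would expect training to resume at "Epoch 4/10" instead. The plain-Python model below illustrates the difference. resume_epoch is a hypothetical helper, not a Keras API, and it assumes both callbacks act in on_epoch_end.

```python
def resume_epoch(num_epochs, interrupt_epoch, backup_first=True):
    """Return the 0-based epoch index a restarted fit() would begin with.

    Models an epoch-level backup callback and an interrupting callback that
    both act at the end of each epoch and are invoked in list order.
    """
    order = ("backup", "interrupt") if backup_first else ("interrupt", "backup")
    last_saved = None  # index of the last epoch whose end was backed up
    for epoch in range(num_epochs):
        for name in order:
            if name == "backup":
                last_saved = epoch
            elif epoch == interrupt_epoch:
                # Training stops here; a restart continues after the last backup.
                return 0 if last_saved is None else last_saved + 1
    return num_epochs  # never interrupted: nothing left to resume


print(resume_epoch(10, 3, backup_first=True))   # 4, displayed as "Epoch 5/10"
print(resume_epoch(10, 3, backup_first=False))  # 3, displayed as "Epoch 4/10"
```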
Define another Callback class that artificially raises an error at the 140th training step, counted across epochs.
class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    # Count batches across all epochs, not just within the current one.
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')
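Note: This section uses the save_freq feature of BackupAndRestore, which is available in TensorFlow 2.10 and later. Before TensorFlow 2.10 was released, it was available only in tf-nightly.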
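To save a checkpoint every 30 steps, set the save_freq argument of the BackupAndRestore callback to 30. InterruptAtStep raises an artificial exception to simulate a failure at epoch 1, step 40 (total step count 140), where epochs are counted from 0. The last checkpoint is therefore saved at epoch 1, step 20 (total step count 120).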
log_dir_2 = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir_2, save_freq=30
)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
Epoch 1/10
100/100 [==============================] - 2s 11ms/step - loss: 0.4730 - accuracy: 0.8676 - val_loss: 0.2210 - val_accuracy: 0.9369
Epoch 2/10
 37/100 [==========>...................] - ETA: 0s - loss: 0.2252 - accuracy: 0.9364
Interrupting at step count 140
RuntimeError:Interruption
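The step numbers in this scenario follow from simple arithmetic. The plain-Python helper below models the behavior described above: a backup every 30 batches counted across epochs, and a failure after the 140th batch. last_backup_position is a hypothetical name, not a Keras API. A smaller save_freq reduces the work that must be redone after a failure, at the cost of writing checkpoints more often.

```python
def last_backup_position(save_freq, interrupt_total_step, steps_per_epoch):
    """Return (epoch_index, steps_done_in_epoch) of the last backup before a failure.

    Assumptions (taken from the scenario described in this guide):
    - a backup is written every `save_freq` batches, counted across epochs;
    - the failure happens right after batch number `interrupt_total_step`
      (1-based) finishes, and a backup due at that same batch is written first.
    """
    last_saved_total = (interrupt_total_step // save_freq) * save_freq
    return divmod(last_saved_total, steps_per_epoch)


print(divmod(140, 100))  # (1, 40): where InterruptAtStep fails
epoch_index, steps_done = last_backup_position(
    save_freq=30, interrupt_total_step=140, steps_per_epoch=100)
print(epoch_index, steps_done)  # 1 20
print(f"Resumes in 'Epoch {epoch_index + 1}/10' at step {steps_done + 1}")
lost_steps = 140 - (epoch_index * 100 + steps_done)
print(lost_steps)  # 20 steps are redone after the restart
```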
Next, instantiate the Keras model, call Model.compile, and continue training from the previously saved checkpoint with Model.fit. Training resumes in the second epoch (shown as "Epoch 2/10", epoch index 1) at step 21.
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])
Epoch 2/10
100/100 [==============================] - 2s 18ms/step - loss: 0.1896 - accuracy: 0.9465 - val_loss: 0.1512 - val_accuracy: 0.9555
Epoch 3/10
100/100 [==============================] - 1s 5ms/step - loss: 0.1452 - accuracy: 0.9584 - val_loss: 0.1201 - val_accuracy: 0.9642
Epoch 4/10
100/100 [==============================] - 0s 5ms/step - loss: 0.1135 - accuracy: 0.9674 - val_loss: 0.0995 - val_accuracy: 0.9707
Epoch 5/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0899 - accuracy: 0.9737 - val_loss: 0.0900 - val_accuracy: 0.9722
Epoch 6/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0779 - accuracy: 0.9781 - val_loss: 0.0812 - val_accuracy: 0.9760
Epoch 7/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0654 - accuracy: 0.9813 - val_loss: 0.0718 - val_accuracy: 0.9774
Epoch 8/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0593 - accuracy: 0.9828 - val_loss: 0.0707 - val_accuracy: 0.9789
Epoch 9/10
100/100 [==============================] - 1s 5ms/step - loss: 0.0487 - accuracy: 0.9864 - val_loss: 0.0659 - val_accuracy: 0.9799
Epoch 10/10
100/100 [==============================] - 0s 5ms/step - loss: 0.0444 - accuracy: 0.9872 - val_loss: 0.0654 - val_accuracy: 0.9799
<keras.callbacks.History at 0x7f2230038430>
TensorFlow 2: Write manual checkpoints with a custom training loop
If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.
This example demonstrates how to:
- Use a tf.train.Checkpoint object to manually create a checkpoint, where the trackable objects you want to save are set as attributes.
- Use a tf.train.CheckpointManager to manage multiple checkpoints.
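Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable state (the model and the optimizer). Also create a CheckpointManager that writes and keeps several checkpoints in a temporary directory.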
model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, log_dir, max_to_keep=2)
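With max_to_keep=2, the CheckpointManager deletes older checkpoints as new ones are written. The custom training loop in this section saves once per step for 5 epochs of 5 steps, so only the two most recent of its 25 checkpoints stay on disk. The retention policy can be modeled in plain Python. kept_checkpoints is a hypothetical helper, and it follows the ckpt-N naming seen in the training output. In TensorFlow, the checkpoint_manager.checkpoints property lists the checkpoint paths currently kept.

```python
from collections import deque


def kept_checkpoints(num_saves, max_to_keep, prefix="ckpt"):
    """Checkpoint names still kept after `num_saves` saves with a keep-newest policy."""
    kept = deque(maxlen=max_to_keep)  # appending to a full deque drops the oldest item
    for save_number in range(1, num_saves + 1):  # numbering starts at 1: ckpt-1, ckpt-2, ...
        kept.append(f"{prefix}-{save_number}")
    return list(kept)


# epochs=5 and steps_per_epoch=5 with one save per step: 25 saves in total.
print(kept_checkpoints(num_saves=5 * 5, max_to_keep=2))  # ['ckpt-24', 'ckpt-25']
```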
Now, implement a custom training loop that restores the last saved checkpoint at the start of every epoch after the first.
for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer state from the most recent checkpoint.
    checkpoint.restore(save_path)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Write a new checkpoint after every step; the manager keeps the newest two.
    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")
Start of epoch 0
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-1
Training loss at step 0: 2.4224987030029297
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-2
Training loss at step 1: 2.422628402709961
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-3
Training loss at step 2: 2.4189400672912598
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-4
Training loss at step 3: 2.4165825843811035
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-5
Training loss at step 4: 2.4144229888916016

Start of epoch 1
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-6
Training loss at step 0: 2.4147567749023438
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-7
Training loss at step 1: 2.4123194217681885
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-8
Training loss at step 2: 2.410810708999634
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-9
Training loss at step 3: 2.4087791442871094
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-10
Training loss at step 4: 2.407498359680176

Start of epoch 2
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-11
Training loss at step 0: 2.4056396484375
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-12
Training loss at step 1: 2.4038097858428955
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-13
Training loss at step 2: 2.401495933532715
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-14
Training loss at step 3: 2.3997390270233154
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-15
Training loss at step 4: 2.397336959838867

Start of epoch 3
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-16
Training loss at step 0: 2.3974244594573975
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-17
Training loss at step 1: 2.394087076187134
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-18
Training loss at step 2: 2.393651008605957
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-19
Training loss at step 3: 2.3912947177886963
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-20
Training loss at step 4: 2.389580726623535

Start of epoch 4
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-21
Training loss at step 0: 2.388636350631714
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-22
Training loss at step 1: 2.386532783508301
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-23
Training loss at step 2: 2.3842995166778564
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-24
Training loss at step 3: 2.3836612701416016
Checkpoint saved to /tmpfs/tmp/tmpz95vb9ch/ckpt-25
Training loss at step 4: 2.3818771839141846
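This loop saves exactly once per step, and CheckpointManager numbers checkpoints from 1. The checkpoint number alone therefore tells a restarted program where to continue. The mapping can be sketched in plain Python with the hypothetical helpers below. In TensorFlow, you would first restore checkpoint_manager.latest_checkpoint with checkpoint.restore and then derive the position. Storing an explicit step counter (for example a tf.Variable) as an attribute of the tf.train.Checkpoint is a more robust alternative when saves are not one per step.

```python
def position_of_checkpoint(ckpt_number, steps_per_epoch):
    """Map ckpt-N (N starts at 1, one save per step) to the (epoch, step) that wrote it."""
    return divmod(ckpt_number - 1, steps_per_epoch)


def resume_position(latest_ckpt_number, steps_per_epoch):
    """Return the (epoch, step) a restarted loop should run next; (0, 0) if no checkpoint."""
    if latest_ckpt_number is None:
        return (0, 0)
    return divmod(latest_ckpt_number, steps_per_epoch)


print(position_of_checkpoint(13, steps_per_epoch=5))  # (2, 2), matching the log above
print(resume_position(13, steps_per_epoch=5))         # (2, 3)
print(resume_position(25, steps_per_epoch=5))         # (5, 0): all 5 epochs are done
```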
Next steps
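To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:
- The tf.keras.callbacks.BackupAndRestore callback API docs.
- The tf.train.Checkpoint and tf.train.CheckpointManager API docs.
- The Training checkpoints guide, including the Writing checkpoints section.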
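You may also find the following material related to distributed training useful:
- The Fault tolerance section in the Multi-worker training with Keras tutorial.
- The Handling task failure section in the Parameter server training tutorial.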