This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when, for example, the validation loss reaches a certain threshold.
There are three ways to implement early stopping in TensorFlow 2:
- Use the built-in Keras callback (tf.keras.callbacks.EarlyStopping) and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
2022-12-14 20:12:26.232370: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:12:26.232462: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:12:26.232471: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining the functions for MNIST dataset loading and preprocessing, and the model definition to be used with tf.estimator.Estimator.
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass the should_stop_fn parameter to make_early_stopping_hook as a function that accepts no arguments. Training stops once should_stop_fn returns True.
The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
start_time = time.time()
max_train_seconds = 20
def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpef3j9d0b INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpef3j9d0b', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... 
INFO:tensorflow:loss = 2.3648038, step = 0 INFO:tensorflow:loss = 2.3648038, step = 0 INFO:tensorflow:global_step/sec: 151.444 INFO:tensorflow:global_step/sec: 151.444 INFO:tensorflow:loss = 1.3138256, step = 100 (0.662 sec) INFO:tensorflow:loss = 1.3138256, step = 100 (0.662 sec) INFO:tensorflow:global_step/sec: 162.146 INFO:tensorflow:global_step/sec: 162.146 INFO:tensorflow:loss = 0.77403545, step = 200 (0.617 sec) INFO:tensorflow:loss = 0.77403545, step = 200 (0.617 sec) INFO:tensorflow:global_step/sec: 170.071 INFO:tensorflow:global_step/sec: 170.071 INFO:tensorflow:loss = 0.68814844, step = 300 (0.589 sec) INFO:tensorflow:loss = 0.68814844, step = 300 (0.589 sec) INFO:tensorflow:global_step/sec: 165.17 INFO:tensorflow:global_step/sec: 165.17 INFO:tensorflow:loss = 0.6460906, step = 400 (0.605 sec) INFO:tensorflow:loss = 0.6460906, step = 400 (0.605 sec) INFO:tensorflow:global_step/sec: 305.273 INFO:tensorflow:global_step/sec: 305.273 INFO:tensorflow:loss = 0.49146488, step = 500 (0.327 sec) INFO:tensorflow:loss = 0.49146488, step = 500 (0.327 sec) INFO:tensorflow:global_step/sec: 506.734 INFO:tensorflow:global_step/sec: 506.734 INFO:tensorflow:loss = 0.4392424, step = 600 (0.197 sec) INFO:tensorflow:loss = 0.4392424, step = 600 (0.197 sec) INFO:tensorflow:global_step/sec: 513.491 INFO:tensorflow:global_step/sec: 513.491 INFO:tensorflow:loss = 0.37929723, step = 700 (0.194 sec) INFO:tensorflow:loss = 0.37929723, step = 700 (0.194 sec) INFO:tensorflow:global_step/sec: 510.466 INFO:tensorflow:global_step/sec: 510.466 INFO:tensorflow:loss = 0.5218885, step = 800 (0.196 sec) INFO:tensorflow:loss = 0.5218885, step = 800 (0.196 sec) INFO:tensorflow:global_step/sec: 516.69 INFO:tensorflow:global_step/sec: 516.69 INFO:tensorflow:loss = 0.39052337, step = 900 (0.193 sec) INFO:tensorflow:loss = 0.39052337, step = 900 (0.193 sec) INFO:tensorflow:global_step/sec: 516.456 INFO:tensorflow:global_step/sec: 516.456 INFO:tensorflow:loss = 0.42972472, step = 1000 (0.194 sec) INFO:tensorflow:loss = 0.42972472, step = 1000 (0.194 sec) INFO:tensorflow:global_step/sec: 545.414 INFO:tensorflow:global_step/sec: 545.414 INFO:tensorflow:loss = 0.4583263, step = 1100 (0.183 sec) INFO:tensorflow:loss = 0.4583263, step = 1100 (0.183 sec) INFO:tensorflow:global_step/sec: 538.936 INFO:tensorflow:global_step/sec: 538.936 INFO:tensorflow:loss = 0.390843, step = 1200 (0.186 sec) INFO:tensorflow:loss = 0.390843, step = 1200 (0.186 sec) INFO:tensorflow:global_step/sec: 544.264 INFO:tensorflow:global_step/sec: 544.264 INFO:tensorflow:loss = 0.46664625, step = 1300 (0.184 sec) INFO:tensorflow:loss = 0.46664625, step = 1300 (0.184 sec) INFO:tensorflow:global_step/sec: 548.873 INFO:tensorflow:global_step/sec: 548.873 INFO:tensorflow:loss = 0.28304386, step = 1400 (0.182 sec) INFO:tensorflow:loss = 0.28304386, step = 1400 (0.182 sec) INFO:tensorflow:global_step/sec: 454.856 INFO:tensorflow:global_step/sec: 454.856 INFO:tensorflow:loss = 0.2891518, step = 1500 (0.220 sec) INFO:tensorflow:loss = 0.2891518, step = 1500 (0.220 sec) INFO:tensorflow:global_step/sec: 485.895 INFO:tensorflow:global_step/sec: 485.895 INFO:tensorflow:loss = 0.39843252, step = 1600 (0.207 sec) INFO:tensorflow:loss = 0.39843252, step = 1600 (0.207 sec) INFO:tensorflow:global_step/sec: 424.282 INFO:tensorflow:global_step/sec: 424.282 INFO:tensorflow:loss = 0.3823722, step = 1700 (0.236 sec) INFO:tensorflow:loss = 0.3823722, step = 1700 (0.236 sec) INFO:tensorflow:global_step/sec: 545.962 INFO:tensorflow:global_step/sec: 545.962 INFO:tensorflow:loss = 
0.30911946, step = 1800 (0.182 sec) INFO:tensorflow:loss = 0.30911946, step = 1800 (0.182 sec) INFO:tensorflow:global_step/sec: 480.805 INFO:tensorflow:global_step/sec: 480.805 INFO:tensorflow:loss = 0.5225426, step = 1900 (0.208 sec) INFO:tensorflow:loss = 0.5225426, step = 1900 (0.208 sec) INFO:tensorflow:global_step/sec: 518.423 INFO:tensorflow:global_step/sec: 518.423 INFO:tensorflow:loss = 0.22036824, step = 2000 (0.193 sec) INFO:tensorflow:loss = 0.22036824, step = 2000 (0.193 sec) INFO:tensorflow:global_step/sec: 553.222 INFO:tensorflow:global_step/sec: 553.222 INFO:tensorflow:loss = 0.2787759, step = 2100 (0.181 sec) INFO:tensorflow:loss = 0.2787759, step = 2100 (0.181 sec) INFO:tensorflow:global_step/sec: 549.704 INFO:tensorflow:global_step/sec: 549.704 INFO:tensorflow:loss = 0.29477462, step = 2200 (0.183 sec) INFO:tensorflow:loss = 0.29477462, step = 2200 (0.183 sec) INFO:tensorflow:global_step/sec: 543.757 INFO:tensorflow:global_step/sec: 543.757 INFO:tensorflow:loss = 0.3432158, step = 2300 (0.184 sec) INFO:tensorflow:loss = 0.3432158, step = 2300 (0.184 sec) INFO:tensorflow:global_step/sec: 522.905 INFO:tensorflow:global_step/sec: 522.905 INFO:tensorflow:loss = 0.24653733, step = 2400 (0.191 sec) INFO:tensorflow:loss = 0.24653733, step = 2400 (0.191 sec) INFO:tensorflow:global_step/sec: 549.746 INFO:tensorflow:global_step/sec: 549.746 INFO:tensorflow:loss = 0.22215012, step = 2500 (0.182 sec) INFO:tensorflow:loss = 0.22215012, step = 2500 (0.182 sec) INFO:tensorflow:global_step/sec: 544.876 INFO:tensorflow:global_step/sec: 544.876 INFO:tensorflow:loss = 0.14438769, step = 2600 (0.184 sec) INFO:tensorflow:loss = 0.14438769, step = 2600 (0.184 sec) INFO:tensorflow:global_step/sec: 550.019 INFO:tensorflow:global_step/sec: 550.019 INFO:tensorflow:loss = 0.29830757, step = 2700 (0.182 sec) INFO:tensorflow:loss = 0.29830757, step = 2700 (0.182 sec) INFO:tensorflow:global_step/sec: 548.555 INFO:tensorflow:global_step/sec: 548.555 INFO:tensorflow:loss = 0.47932947, step = 2800 (0.182 sec) INFO:tensorflow:loss = 0.47932947, step = 2800 (0.182 sec) INFO:tensorflow:global_step/sec: 494.071 INFO:tensorflow:global_step/sec: 494.071 INFO:tensorflow:loss = 0.24411874, step = 2900 (0.202 sec) INFO:tensorflow:loss = 0.24411874, step = 2900 (0.202 sec) INFO:tensorflow:global_step/sec: 521.425 INFO:tensorflow:global_step/sec: 521.425 INFO:tensorflow:loss = 0.3316818, step = 3000 (0.192 sec) INFO:tensorflow:loss = 0.3316818, step = 3000 (0.192 sec) INFO:tensorflow:global_step/sec: 532.014 INFO:tensorflow:global_step/sec: 532.014 INFO:tensorflow:loss = 0.20751971, step = 3100 (0.188 sec) INFO:tensorflow:loss = 0.20751971, step = 3100 (0.188 sec) INFO:tensorflow:global_step/sec: 522.784 INFO:tensorflow:global_step/sec: 522.784 INFO:tensorflow:loss = 0.4179578, step = 3200 (0.191 sec) INFO:tensorflow:loss = 0.4179578, step = 3200 (0.191 sec) INFO:tensorflow:global_step/sec: 490.206 INFO:tensorflow:global_step/sec: 490.206 INFO:tensorflow:loss = 0.33943966, step = 3300 (0.204 sec) INFO:tensorflow:loss = 0.33943966, step = 3300 (0.204 sec) INFO:tensorflow:global_step/sec: 561.505 INFO:tensorflow:global_step/sec: 561.505 INFO:tensorflow:loss = 0.26698294, step = 3400 (0.179 sec) INFO:tensorflow:loss = 0.26698294, step = 3400 (0.179 sec) INFO:tensorflow:global_step/sec: 563.807 INFO:tensorflow:global_step/sec: 563.807 INFO:tensorflow:loss = 0.19843455, step = 3500 (0.177 sec) INFO:tensorflow:loss = 0.19843455, step = 3500 (0.177 sec) INFO:tensorflow:global_step/sec: 558.398 
INFO:tensorflow:global_step/sec: 558.398 INFO:tensorflow:loss = 0.22747274, step = 3600 (0.178 sec) INFO:tensorflow:loss = 0.22747274, step = 3600 (0.178 sec) INFO:tensorflow:global_step/sec: 552.048 INFO:tensorflow:global_step/sec: 552.048 INFO:tensorflow:loss = 0.26726133, step = 3700 (0.181 sec) INFO:tensorflow:loss = 0.26726133, step = 3700 (0.181 sec) INFO:tensorflow:global_step/sec: 523.826 INFO:tensorflow:global_step/sec: 523.826 INFO:tensorflow:loss = 0.37733287, step = 3800 (0.191 sec) INFO:tensorflow:loss = 0.37733287, step = 3800 (0.191 sec) INFO:tensorflow:global_step/sec: 536.735 INFO:tensorflow:global_step/sec: 536.735 INFO:tensorflow:loss = 0.21054713, step = 3900 (0.187 sec) INFO:tensorflow:loss = 0.21054713, step = 3900 (0.187 sec) INFO:tensorflow:global_step/sec: 536.298 INFO:tensorflow:global_step/sec: 536.298 INFO:tensorflow:loss = 0.3076029, step = 4000 (0.186 sec) INFO:tensorflow:loss = 0.3076029, step = 4000 (0.186 sec) INFO:tensorflow:global_step/sec: 538.71 INFO:tensorflow:global_step/sec: 538.71 INFO:tensorflow:loss = 0.21131831, step = 4100 (0.186 sec) INFO:tensorflow:loss = 0.21131831, step = 4100 (0.186 sec) INFO:tensorflow:global_step/sec: 536.263 INFO:tensorflow:global_step/sec: 536.263 INFO:tensorflow:loss = 0.2575843, step = 4200 (0.186 sec) INFO:tensorflow:loss = 0.2575843, step = 4200 (0.186 sec) INFO:tensorflow:global_step/sec: 488.188 INFO:tensorflow:global_step/sec: 488.188 INFO:tensorflow:loss = 0.30193213, step = 4300 (0.205 sec) INFO:tensorflow:loss = 0.30193213, step = 4300 (0.205 sec) INFO:tensorflow:global_step/sec: 514.547 INFO:tensorflow:global_step/sec: 514.547 INFO:tensorflow:loss = 0.30978817, step = 4400 (0.194 sec) INFO:tensorflow:loss = 0.30978817, step = 4400 (0.194 sec) INFO:tensorflow:global_step/sec: 518.293 INFO:tensorflow:global_step/sec: 518.293 INFO:tensorflow:loss = 0.2733007, step = 4500 (0.193 sec) INFO:tensorflow:loss = 0.2733007, step = 4500 (0.193 sec) INFO:tensorflow:global_step/sec: 521.221 INFO:tensorflow:global_step/sec: 521.221 INFO:tensorflow:loss = 0.3213666, step = 4600 (0.192 sec) INFO:tensorflow:loss = 0.3213666, step = 4600 (0.192 sec) INFO:tensorflow:global_step/sec: 514.958 INFO:tensorflow:global_step/sec: 514.958 INFO:tensorflow:loss = 0.14884642, step = 4700 (0.195 sec) INFO:tensorflow:loss = 0.14884642, step = 4700 (0.195 sec) INFO:tensorflow:global_step/sec: 448.422 INFO:tensorflow:global_step/sec: 448.422 INFO:tensorflow:loss = 0.31162184, step = 4800 (0.223 sec) INFO:tensorflow:loss = 0.31162184, step = 4800 (0.223 sec) INFO:tensorflow:global_step/sec: 493.985 INFO:tensorflow:global_step/sec: 493.985 INFO:tensorflow:loss = 0.37944964, step = 4900 (0.203 sec) INFO:tensorflow:loss = 0.37944964, step = 4900 (0.203 sec) INFO:tensorflow:global_step/sec: 551.664 INFO:tensorflow:global_step/sec: 551.664 INFO:tensorflow:loss = 0.2854257, step = 5000 (0.182 sec) INFO:tensorflow:loss = 0.2854257, step = 5000 (0.182 sec) INFO:tensorflow:global_step/sec: 547.726 INFO:tensorflow:global_step/sec: 547.726 INFO:tensorflow:loss = 0.32582244, step = 5100 (0.182 sec) INFO:tensorflow:loss = 0.32582244, step = 5100 (0.182 sec) INFO:tensorflow:global_step/sec: 444.448 INFO:tensorflow:global_step/sec: 444.448 INFO:tensorflow:loss = 0.20980233, step = 5200 (0.225 sec) INFO:tensorflow:loss = 0.20980233, step = 5200 (0.225 sec) INFO:tensorflow:global_step/sec: 423.918 INFO:tensorflow:global_step/sec: 423.918 INFO:tensorflow:loss = 0.24278633, step = 5300 (0.236 sec) INFO:tensorflow:loss = 0.24278633, step = 5300 (0.236 sec) 
INFO:tensorflow:global_step/sec: 441.673 INFO:tensorflow:global_step/sec: 441.673 INFO:tensorflow:loss = 0.16715293, step = 5400 (0.227 sec) INFO:tensorflow:loss = 0.16715293, step = 5400 (0.227 sec) INFO:tensorflow:global_step/sec: 546.928 INFO:tensorflow:global_step/sec: 546.928 INFO:tensorflow:loss = 0.2231634, step = 5500 (0.183 sec) INFO:tensorflow:loss = 0.2231634, step = 5500 (0.183 sec) INFO:tensorflow:global_step/sec: 429.178 INFO:tensorflow:global_step/sec: 429.178 INFO:tensorflow:loss = 0.1799423, step = 5600 (0.233 sec) INFO:tensorflow:loss = 0.1799423, step = 5600 (0.233 sec) INFO:tensorflow:global_step/sec: 437.64 INFO:tensorflow:global_step/sec: 437.64 INFO:tensorflow:loss = 0.16861719, step = 5700 (0.228 sec) INFO:tensorflow:loss = 0.16861719, step = 5700 (0.228 sec) INFO:tensorflow:Requesting early stopping at global step 5737 INFO:tensorflow:Requesting early stopping at global step 5737 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5738... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5738... INFO:tensorflow:Saving checkpoints for 5738 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt. INFO:tensorflow:Saving checkpoints for 5738 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5738... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5738... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-14T20:12:48 INFO:tensorflow:Starting evaluation at 2022-12-14T20:12:48 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738 INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [10/100] INFO:tensorflow:Evaluation [10/100] INFO:tensorflow:Evaluation [20/100] INFO:tensorflow:Evaluation [20/100] INFO:tensorflow:Evaluation [30/100] INFO:tensorflow:Evaluation [30/100] INFO:tensorflow:Evaluation [40/100] INFO:tensorflow:Evaluation [40/100] INFO:tensorflow:Evaluation [50/100] INFO:tensorflow:Evaluation [50/100] INFO:tensorflow:Evaluation [60/100] INFO:tensorflow:Evaluation [60/100] INFO:tensorflow:Evaluation [70/100] INFO:tensorflow:Evaluation [70/100] INFO:tensorflow:Inference Time : 1.02504s INFO:tensorflow:Inference Time : 1.02504s INFO:tensorflow:Finished evaluation at 2022-12-14-20:12:49 INFO:tensorflow:Finished evaluation at 2022-12-14-20:12:49 INFO:tensorflow:Saving dict for global step 5738: global_step = 5738, loss = 0.24106143 INFO:tensorflow:Saving dict for global step 5738: global_step = 5738, loss = 0.24106143 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5738: /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5738: /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738 INFO:tensorflow:Loss for final step: 0.33733124. INFO:tensorflow:Loss for final step: 0.33733124. ({'loss': 0.24106143, 'global_step': 5738}, [])
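Note: besides writing your own should_stop_fn, tf.estimator.experimental also provides ready-made early stopping hooks that watch an evaluation metric, such as stop_if_no_decrease_hook. The snippet below is only a minimal sketch (the metric name and step counts are illustrative) and is not part of the run above:

# A sketch only: stop training once the evaluated `loss` metric has not
# decreased for 1000 global steps. The values below are illustrative.
no_decrease_hook = tf1.estimator.experimental.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    run_every_secs=60)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[no_decrease_hook])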
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving (check the Training and evaluation with the built-in methods guide or the API docs for more information).
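For instance, a common configuration, shown here only as a sketch with illustrative values, monitors the validation loss, requires a minimum improvement per epoch, and restores the weights from the best epoch when training stops:

# A sketch with illustrative values: watch `val_loss`, require at least a
# 1e-3 improvement, and roll the model back to its best weights on stop.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=3,
    restore_best_weights=True)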
Below is an example of an early stopping callback that monitors the loss and stops training after the number of epochs that show no improvement reaches 3 (patience):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100 469/469 [==============================] - 4s 5ms/step - loss: 0.2317 - sparse_categorical_accuracy: 0.9306 - val_loss: 0.1307 - val_sparse_categorical_accuracy: 0.9584 Epoch 2/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0992 - sparse_categorical_accuracy: 0.9705 - val_loss: 0.1059 - val_sparse_categorical_accuracy: 0.9660 Epoch 3/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9790 - val_loss: 0.1004 - val_sparse_categorical_accuracy: 0.9698 Epoch 4/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0544 - sparse_categorical_accuracy: 0.9827 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9697 Epoch 5/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0434 - sparse_categorical_accuracy: 0.9862 - val_loss: 0.1268 - val_sparse_categorical_accuracy: 0.9648 Epoch 6/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0346 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1027 - val_sparse_categorical_accuracy: 0.9749 Epoch 7/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0310 - sparse_categorical_accuracy: 0.9900 - val_loss: 0.1099 - val_sparse_categorical_accuracy: 0.9714 Epoch 8/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0284 - sparse_categorical_accuracy: 0.9904 - val_loss: 0.1170 - val_sparse_categorical_accuracy: 0.9726 Epoch 9/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9908 - val_loss: 0.1171 - val_sparse_categorical_accuracy: 0.9737 Epoch 10/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0246 - sparse_categorical_accuracy: 0.9914 - val_loss: 0.1372 - val_sparse_categorical_accuracy: 0.9725 Epoch 11/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9913 - val_loss: 0.1291 - val_sparse_categorical_accuracy: 0.9758 Epoch 12/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1330 - val_sparse_categorical_accuracy: 0.9763 Epoch 13/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1739 - val_sparse_categorical_accuracy: 0.9679 Epoch 14/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0237 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1382 - val_sparse_categorical_accuracy: 0.9756 Epoch 15/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0185 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1342 - val_sparse_categorical_accuracy: 0.9770 Epoch 16/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0208 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.1450 - val_sparse_categorical_accuracy: 0.9761 16
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, the training process is stopped once self.model.stop_training is set to True.
class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1421 - val_sparse_categorical_accuracy: 0.9760 Epoch 2/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1660 - val_sparse_categorical_accuracy: 0.9753 Epoch 3/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.1581 - val_sparse_categorical_accuracy: 0.9767 Epoch 4/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1733 - val_sparse_categorical_accuracy: 0.9760 Epoch 5/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.1816 - val_sparse_categorical_accuracy: 0.9765 Epoch 6/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0192 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.1859 - val_sparse_categorical_accuracy: 0.9775 Epoch 7/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0151 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1713 - val_sparse_categorical_accuracy: 0.9770 Epoch 8/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0155 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1759 - val_sparse_categorical_accuracy: 0.9763 Epoch 9/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1969 - val_sparse_categorical_accuracy: 0.9767 Epoch 10/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.1819 - val_sparse_categorical_accuracy: 0.9774 Epoch 11/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.1877 - val_sparse_categorical_accuracy: 0.9775 Epoch 12/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.1904 - val_sparse_categorical_accuracy: 0.9766 Epoch 13/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1869 - val_sparse_categorical_accuracy: 0.9788 Epoch 14/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1906 - val_sparse_categorical_accuracy: 0.9781 Epoch 15/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.1851 - val_sparse_categorical_accuracy: 0.9805 Epoch 16/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1918 - val_sparse_categorical_accuracy: 0.9799 Epoch 17/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0090 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.2095 - val_sparse_categorical_accuracy: 0.9752 Epoch 18/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0123 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2479 - val_sparse_categorical_accuracy: 0.9760 Epoch 19/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0142 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2056 - val_sparse_categorical_accuracy: 0.9791 Epoch 20/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - 
sparse_categorical_accuracy: 0.9970 - val_loss: 0.2276 - val_sparse_categorical_accuracy: 0.9763 Epoch 21/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2612 - val_sparse_categorical_accuracy: 0.9755 Epoch 22/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2596 - val_sparse_categorical_accuracy: 0.9737 Epoch 23/100 469/469 [==============================] - 1s 2ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2826 - val_sparse_categorical_accuracy: 0.9754 23
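A custom callback can also stop on a metric instead of a time budget. The class below is a hypothetical sketch (not part of the original example) that stops training when val_loss has not improved for a given number of epochs and restores the best weights; it can be passed to Model.fit through callbacks just like LimitTrainingTime above.

class EarlyStoppingAtMinValLoss(tf.keras.callbacks.Callback):
  """Sketch: stop when `val_loss` stops improving and restore the best weights."""

  def __init__(self, patience=3):
    super().__init__()
    self.patience = patience

  def on_train_begin(self, logs=None):
    self.wait = 0
    self.best = float('inf')
    self.best_weights = None

  def on_epoch_end(self, epoch, logs=None):
    current = (logs or {}).get('val_loss')
    if current is None:
      return
    if current < self.best:
      # New best validation loss: reset patience and remember these weights.
      self.best = current
      self.wait = 0
      self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.model.stop_training = True
        if self.best_weights is not None:
          self.model.set_weights(self.best_weights)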
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.
Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
Define the parameter update functions with tf.GradientTape, and use the @tf.function decorator for a speedup:
@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))
  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
  if wait >= patience:
    break
Start of epoch 0 Training loss at step 0: 2.3333 Seen so far: 128 samples Training loss at step 200: 0.2772 Seen so far: 25728 samples Training loss at step 400: 0.2551 Seen so far: 51328 samples Training acc over epoch: 0.9306 Validation acc: 0.9608 Time taken: 1.99s Start of epoch 1 Training loss at step 0: 0.0677 Seen so far: 128 samples Training loss at step 200: 0.1620 Seen so far: 25728 samples Training loss at step 400: 0.1523 Seen so far: 51328 samples Training acc over epoch: 0.9692 Validation acc: 0.9648 Time taken: 1.08s Start of epoch 2 Training loss at step 0: 0.0360 Seen so far: 128 samples Training loss at step 200: 0.0929 Seen so far: 25728 samples Training loss at step 400: 0.1308 Seen so far: 51328 samples Training acc over epoch: 0.9795 Validation acc: 0.9681 Time taken: 1.03s Start of epoch 3 Training loss at step 0: 0.0199 Seen so far: 128 samples Training loss at step 200: 0.0391 Seen so far: 25728 samples Training loss at step 400: 0.0789 Seen so far: 51328 samples Training acc over epoch: 0.9836 Validation acc: 0.9713 Time taken: 1.11s Start of epoch 4 Training loss at step 0: 0.0194 Seen so far: 128 samples Training loss at step 200: 0.0289 Seen so far: 25728 samples Training loss at step 400: 0.0561 Seen so far: 51328 samples Training acc over epoch: 0.9864 Validation acc: 0.9734 Time taken: 1.06s Start of epoch 5 Training loss at step 0: 0.0175 Seen so far: 128 samples Training loss at step 200: 0.0540 Seen so far: 25728 samples Training loss at step 400: 0.0402 Seen so far: 51328 samples Training acc over epoch: 0.9873 Validation acc: 0.9702 Time taken: 1.10s
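If you also want to keep the best weights in a custom loop, you can store a copy whenever the validation loss improves and restore it when the loop stops. The following is only a minimal sketch of that bookkeeping, assuming the model, metrics, and training/validation steps from the loop above:

# Sketch: the same stopping rule, extended to remember and restore the best
# weights (assumes `model`, `epochs`, `patience`, and the metrics defined above).
wait = 0
best = float('inf')
best_weights = model.get_weights()

for epoch in range(epochs):
  # ... run the training and validation steps as in the loop above ...
  val_loss = val_loss_metric.result()
  val_loss_metric.reset_states()

  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
    best_weights = model.get_weights()
  if wait >= patience:
    model.set_weights(best_weights)
    break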
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including for early stopping at a minimum loss.
- Learn about training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit guide, which uses the EarlyStopping callback.