Migrate early stopping


This notebook demonstrates how you can set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.

In TensorFlow 2, there are three ways to implement early stopping: use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit, define a custom callback and pass it to Keras Model.fit, or write a custom early stopping rule in a custom training loop (with tf.GradientTape). The sections below walk through each approach.

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
2022-12-14 20:12:26.232370: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:12:26.232462: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:12:26.232471: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

First, define functions for loading and preprocessing the MNIST dataset, and the model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a function that accepts no arguments as its should_stop_fn parameter; training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits training to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpef3j9d0b
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpef3j9d0b', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3648038, step = 0
INFO:tensorflow:global_step/sec: 151.444
INFO:tensorflow:loss = 1.3138256, step = 100 (0.662 sec)
INFO:tensorflow:global_step/sec: 162.146
INFO:tensorflow:loss = 0.77403545, step = 200 (0.617 sec)
...
INFO:tensorflow:global_step/sec: 429.178
INFO:tensorflow:loss = 0.1799423, step = 5600 (0.233 sec)
INFO:tensorflow:global_step/sec: 437.64
INFO:tensorflow:loss = 0.16861719, step = 5700 (0.228 sec)
INFO:tensorflow:Requesting early stopping at global step 5737
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5738...
INFO:tensorflow:Saving checkpoints for 5738 into /tmpfs/tmp/tmpef3j9d0b/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5738...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:12:48
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 1.02504s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:12:49
INFO:tensorflow:Saving dict for global step 5738: global_step = 5738, loss = 0.24106143
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5738: /tmpfs/tmp/tmpef3j9d0b/model.ckpt-5738
INFO:tensorflow:Loss for final step: 0.33733124.
({'loss': 0.24106143, 'global_step': 5738}, [])
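
Besides a fully custom should_stop_fn, tf.estimator also ships convenience wrappers built on the same hook mechanism for common metric-based rules. The sketch below is not part of the original notebook; it reuses the estimator and input functions defined above, and the argument values are illustrative only:

# A hedged sketch: request early stopping if the 'loss' metric has not
# decreased within the last 1000 global steps; the rule is checked every
# 60 seconds.
no_decrease_hook = tf1.estimator.experimental.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    run_every_secs=60)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[no_decrease_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

# Uncomment to run; training stops once the rule above is triggered.
# tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)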

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving (check the Training and evaluation with the built-in methods guide or the API docs for more information).
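
In many workflows you would monitor a validation metric rather than the training loss. A minimal sketch of that configuration (the argument values here are illustrative and not taken from this notebook) could be:

# A hedged sketch: stop when `val_loss` has not improved by at least
# `min_delta` for `patience` epochs, then restore the best weights seen.
val_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=3,
    restore_best_weights=True)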

Below is an example of an early stopping callback that monitors the loss and stops training after 3 epochs (patience) with no improvement:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only a fraction of the 100 epochs are run during training because of early stopping.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 4s 5ms/step - loss: 0.2317 - sparse_categorical_accuracy: 0.9306 - val_loss: 0.1307 - val_sparse_categorical_accuracy: 0.9584
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0992 - sparse_categorical_accuracy: 0.9705 - val_loss: 0.1059 - val_sparse_categorical_accuracy: 0.9660
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9790 - val_loss: 0.1004 - val_sparse_categorical_accuracy: 0.9698
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0544 - sparse_categorical_accuracy: 0.9827 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9697
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0434 - sparse_categorical_accuracy: 0.9862 - val_loss: 0.1268 - val_sparse_categorical_accuracy: 0.9648
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0346 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1027 - val_sparse_categorical_accuracy: 0.9749
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0310 - sparse_categorical_accuracy: 0.9900 - val_loss: 0.1099 - val_sparse_categorical_accuracy: 0.9714
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0284 - sparse_categorical_accuracy: 0.9904 - val_loss: 0.1170 - val_sparse_categorical_accuracy: 0.9726
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9908 - val_loss: 0.1171 - val_sparse_categorical_accuracy: 0.9737
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0246 - sparse_categorical_accuracy: 0.9914 - val_loss: 0.1372 - val_sparse_categorical_accuracy: 0.9725
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0272 - sparse_categorical_accuracy: 0.9913 - val_loss: 0.1291 - val_sparse_categorical_accuracy: 0.9758
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1330 - val_sparse_categorical_accuracy: 0.9763
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1739 - val_sparse_categorical_accuracy: 0.9679
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0237 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1382 - val_sparse_categorical_accuracy: 0.9756
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0185 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1342 - val_sparse_categorical_accuracy: 0.9770
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0208 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.1450 - val_sparse_categorical_accuracy: 0.9761
16

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1421 - val_sparse_categorical_accuracy: 0.9760
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1660 - val_sparse_categorical_accuracy: 0.9753
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.1581 - val_sparse_categorical_accuracy: 0.9767
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1733 - val_sparse_categorical_accuracy: 0.9760
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.1816 - val_sparse_categorical_accuracy: 0.9765
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0192 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.1859 - val_sparse_categorical_accuracy: 0.9775
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0151 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1713 - val_sparse_categorical_accuracy: 0.9770
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0155 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.1759 - val_sparse_categorical_accuracy: 0.9763
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1969 - val_sparse_categorical_accuracy: 0.9767
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.1819 - val_sparse_categorical_accuracy: 0.9774
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.1877 - val_sparse_categorical_accuracy: 0.9775
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.1904 - val_sparse_categorical_accuracy: 0.9766
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1869 - val_sparse_categorical_accuracy: 0.9788
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1906 - val_sparse_categorical_accuracy: 0.9781
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.1851 - val_sparse_categorical_accuracy: 0.9805
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1918 - val_sparse_categorical_accuracy: 0.9799
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0090 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.2095 - val_sparse_categorical_accuracy: 0.9752
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0123 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2479 - val_sparse_categorical_accuracy: 0.9760
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0142 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2056 - val_sparse_categorical_accuracy: 0.9791
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2276 - val_sparse_categorical_accuracy: 0.9763
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2612 - val_sparse_categorical_accuracy: 0.9755
Epoch 22/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.2596 - val_sparse_categorical_accuracy: 0.9737
Epoch 23/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2826 - val_sparse_categorical_accuracy: 0.9754
23
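
The same self.model.stop_training mechanism also supports metric-based rules. As a rough sketch (not part of the original notebook), a custom callback that mimics EarlyStopping by watching the val_loss value that Model.fit reports at the end of each epoch could look like this:

class EarlyStopOnValLoss(tf.keras.callbacks.Callback):
  """Stops training when `val_loss` has not improved for `patience` epochs."""

  def __init__(self, patience=3):
    super().__init__()
    self.patience = patience
    self.best = float('inf')
    self.wait = 0

  def on_epoch_end(self, epoch, logs=None):
    val_loss = (logs or {}).get('val_loss')
    if val_loss is None:
      return  # `Model.fit` was called without validation data.
    if val_loss < self.best:
      self.best = val_loss
      self.wait = 0
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.model.stop_training = True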

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by defining another simple model, an optimizer, a loss function, and metrics with the Keras APIs:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)

Define the parameter update functions with tf.GradientTape, and the @tf.function decorator for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
      logits = model(x, training=True)
      loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.3333
Seen so far: 128 samples
Training loss at step 200: 0.2772
Seen so far: 25728 samples
Training loss at step 400: 0.2551
Seen so far: 51328 samples
Training acc over epoch: 0.9306
Validation acc: 0.9608
Time taken: 1.99s

Start of epoch 1
Training loss at step 0: 0.0677
Seen so far: 128 samples
Training loss at step 200: 0.1620
Seen so far: 25728 samples
Training loss at step 400: 0.1523
Seen so far: 51328 samples
Training acc over epoch: 0.9692
Validation acc: 0.9648
Time taken: 1.08s

Start of epoch 2
Training loss at step 0: 0.0360
Seen so far: 128 samples
Training loss at step 200: 0.0929
Seen so far: 25728 samples
Training loss at step 400: 0.1308
Seen so far: 51328 samples
Training acc over epoch: 0.9795
Validation acc: 0.9681
Time taken: 1.03s

Start of epoch 3
Training loss at step 0: 0.0199
Seen so far: 128 samples
Training loss at step 200: 0.0391
Seen so far: 25728 samples
Training loss at step 400: 0.0789
Seen so far: 51328 samples
Training acc over epoch: 0.9836
Validation acc: 0.9713
Time taken: 1.11s

Start of epoch 4
Training loss at step 0: 0.0194
Seen so far: 128 samples
Training loss at step 200: 0.0289
Seen so far: 25728 samples
Training loss at step 400: 0.0561
Seen so far: 51328 samples
Training acc over epoch: 0.9864
Validation acc: 0.9734
Time taken: 1.06s

Start of epoch 5
Training loss at step 0: 0.0175
Seen so far: 128 samples
Training loss at step 200: 0.0540
Seen so far: 25728 samples
Training loss at step 400: 0.0402
Seen so far: 51328 samples
Training acc over epoch: 0.9873
Validation acc: 0.9702
Time taken: 1.10s
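
When you stop early in a custom loop, it is often useful to also keep and restore the best weights seen so far. The sketch below is not part of the original notebook; it reuses the model, datasets, step functions, metrics, epochs, and patience defined above and adds that bookkeeping with model.get_weights and model.set_weights:

wait = 0
best_val_loss = float('inf')
best_weights = None

for epoch in range(epochs):
  for x_batch, y_batch in ds_train:
    train_step(x_batch, y_batch)
  for x_batch, y_batch in ds_test:
    test_step(x_batch, y_batch)

  val_loss = float(val_loss_metric.result())
  # Reset all metrics so each epoch is measured independently.
  for metric in (train_acc_metric, train_loss_metric,
                 val_acc_metric, val_loss_metric):
    metric.reset_states()

  # Keep a copy of the best weights and stop after `patience` bad epochs.
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    best_weights = model.get_weights()
    wait = 0
  else:
    wait += 1
    if wait >= patience:
      break

# Roll back to the best weights before evaluating or exporting the model.
if best_weights is not None:
  model.set_weights(best_weights)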

Next steps