This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
In TensorFlow 2, there are three ways to implement early stopping:
- Use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit.
- Define a custom callback and pass it to Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions for loading and preprocessing the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook as the should_stop_fn parameter. Training stops once should_stop_fn returns True.
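Because should_stop_fn receives no arguments, any state it needs (a start time, a limit) has to be captured from the enclosing scope. The following is a minimal sketch of that pattern in plain Python, independent of TensorFlow. The factory name make_time_limit_stop_fn and its clock parameter are hypothetical helpers introduced here for illustration; they are not part of the TensorFlow API.

```python
import time
from typing import Callable


def make_time_limit_stop_fn(
    max_seconds: float,
    clock: Callable[[], float] = time.monotonic,
) -> Callable[[], bool]:
  """Returns a zero-argument function that is True once max_seconds elapsed.

  Hypothetical helper: `clock` is injectable only so the behavior can be
  checked deterministically without sleeping.
  """
  start = clock()  # Captured once, when the factory is called.

  def should_stop_fn() -> bool:
    return clock() - start > max_seconds

  return should_stop_fn


# Demonstration with a fake clock instead of real wall time.
fake_now = [0.0]
stop_fn = make_time_limit_stop_fn(20, clock=lambda: fake_now[0])

before_limit = stop_fn()  # 0 seconds elapsed.
fake_now[0] = 25.0
after_limit = stop_fn()   # 25 seconds elapsed.
print(before_limit, after_limit)
```

Under these assumptions, a function built with make_time_limit_stop_fn(20) could be passed as should_stop_fn in the same way as a hand-written closure.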
The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpocmc6_bo
...
INFO:tensorflow:Running training and evaluation locally (non-distributed).
...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:loss = 2.3545606, step = 0
INFO:tensorflow:global_step/sec: 94.5711
INFO:tensorflow:loss = 1.3383636, step = 100 (1.060 sec)
...
INFO:tensorflow:loss = 0.14105138, step = 6900 (0.217 sec)
INFO:tensorflow:Requesting early stopping at global step 6916
INFO:tensorflow:Saving checkpoints for 6917 into /tmp/tmpocmc6_bo/model.ckpt.
...
INFO:tensorflow:Starting evaluation at 2021-09-22T20:07:35
INFO:tensorflow:Restoring parameters from /tmp/tmpocmc6_bo/model.ckpt-6917
...
INFO:tensorflow:Inference Time : 0.79520s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:07:36
INFO:tensorflow:Saving dict for global step 6917: global_step = 6917, loss = 0.227278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6917: /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Loss for final step: 0.13882703.
({'loss': 0.227278, 'global_step': 6917}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (For more information, see the Training and evaluation with the built-in methods guide or the API docs.)
Below is an example of an early stopping callback that monitors the training loss (monitor='loss') and stops training after 3 consecutive epochs with no improvement (patience=3):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 5s 8ms/step - loss: 0.2371 - sparse_categorical_accuracy: 0.9293 - val_loss: 0.1334 - val_sparse_categorical_accuracy: 0.9611
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1028 - sparse_categorical_accuracy: 0.9686 - val_loss: 0.1062 - val_sparse_categorical_accuracy: 0.9667
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9783 - val_loss: 0.0993 - val_sparse_categorical_accuracy: 0.9707
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.1040 - val_sparse_categorical_accuracy: 0.9680
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9865 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9716
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.1167 - val_sparse_categorical_accuracy: 0.9691
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0321 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.1396 - val_sparse_categorical_accuracy: 0.9672
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1397 - val_sparse_categorical_accuracy: 0.9671
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1296 - val_sparse_categorical_accuracy: 0.9715
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9715
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0274 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9710
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0241 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9718
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0205 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1451 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1562 - val_sparse_categorical_accuracy: 0.9750
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9748
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1712 - val_sparse_categorical_accuracy: 0.9731
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1715 - val_sparse_categorical_accuracy: 0.9755
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9730
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1919 - val_sparse_categorical_accuracy: 0.9732
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1703 - val_sparse_categorical_accuracy: 0.9777
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1725 - val_sparse_categorical_accuracy: 0.9764
21
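To see why the run above ended after 21 epochs, the patience rule can be replayed in plain Python over the training losses printed in that output. This is a simplified sketch of the idea only, assuming that "improvement" means a strictly lower loss than the best seen so far; it is not the Keras implementation and ignores options such as min_delta, baseline, and restore_best_weights. The function name first_stopping_epoch is hypothetical.

```python
from typing import Optional, Sequence


def first_stopping_epoch(losses: Sequence[float], patience: int) -> Optional[int]:
  """Returns the 1-based epoch at which a patience rule would stop training.

  Simplified model: an epoch improves only if its loss is strictly below the
  best loss so far. Returns None if the rule never triggers.
  """
  best = float('inf')
  wait = 0  # Epochs since the last improvement.
  for epoch, loss in enumerate(losses, start=1):
    if loss < best:
      best = loss
      wait = 0
    else:
      wait += 1
      if wait >= patience:
        return epoch
  return None


# Training losses copied from the epoch log above (epochs 1 through 21).
logged_losses = [
    0.2371, 0.1028, 0.0703, 0.0552, 0.0420, 0.0387, 0.0321,
    0.0285, 0.0263, 0.0250, 0.0274, 0.0241, 0.0205, 0.0196,
    0.0214, 0.0178, 0.0177, 0.0141, 0.0188, 0.0190, 0.0153,
]

stop_epoch = first_stopping_epoch(logged_losses, patience=3)
print(stop_epoch)  # 21
```

The best loss (0.0141) occurs at epoch 18; epochs 19, 20, and 21 all fail to beat it, so the counter reaches the patience of 3 at epoch 21, which matches the length of history.history['loss'] reported above.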
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, training stops once the callback sets self.model.stop_training to True:
class LimitTrainingTime(tf.keras.callbacks.Callback):
  """Stops training once a wall-clock time budget (in seconds) is exceeded."""

  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    # Record when training starts.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    # Check the elapsed time after every batch and request a stop if needed.
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.1911 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1999 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9770
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.2279 - val_sparse_categorical_accuracy: 0.9753
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2272 - val_sparse_categorical_accuracy: 0.9755
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2352 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2421 - val_sparse_categorical_accuracy: 0.9734
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9785
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2472 - val_sparse_categorical_accuracy: 0.9752
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2166 - val_sparse_categorical_accuracy: 0.9768
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9781
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2310 - val_sparse_categorical_accuracy: 0.9777
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2617 - val_sparse_categorical_accuracy: 0.9781
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.3007 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3014 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2815 - val_sparse_categorical_accuracy: 0.9750
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9765
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9777
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0098 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.2594 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3008 - val_sparse_categorical_accuracy: 0.9755
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2662 - val_sparse_categorical_accuracy: 0.9765
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2587 - val_sparse_categorical_accuracy: 0.9797
22
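The flag-based pattern used by the callback above can be illustrated without Keras. The following is a minimal sketch under stated assumptions: ToyModel, ToyTimeLimit, and toy_fit are hypothetical names invented for this illustration, and toy_fit is a simplified loop, not Keras's actual implementation. A fake clock replaces time.time() so that the result is deterministic.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model; it only carries the stop flag."""

    def __init__(self):
        self.stop_training = False


class ToyTimeLimit:
    """Hypothetical analogue of LimitTrainingTime with an injectable clock."""

    def __init__(self, max_time_s, clock):
        self.max_time_s = max_time_s
        self.clock = clock
        self.start_time = None
        self.model = None

    def on_train_begin(self):
        self.start_time = self.clock()

    def on_train_batch_end(self):
        # Same rule as in the callback above: raise the flag once over budget.
        if self.clock() - self.start_time > self.max_time_s:
            self.model.stop_training = True


def toy_fit(model, callback, epochs, steps_per_epoch):
    """Simplified loop (an assumption): check the flag after every batch."""
    callback.model = model
    callback.on_train_begin()
    batches_run = 0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            batches_run += 1  # A real loop would run one training step here.
            callback.on_train_batch_end()
            if model.stop_training:
                return batches_run
    return batches_run


# A fake clock that advances by one "second" per call.
ticks = iter(range(1000))
batches = toy_fit(
    ToyModel(),
    ToyTimeLimit(max_time_s=5, clock=lambda: next(ticks)),
    epochs=100,
    steps_per_epoch=10,
)
print(batches)
```

Injecting the clock is a design choice for this sketch only; it lets the time-limit rule be exercised in a test without waiting for real time to pass.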
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you are not training and evaluating with the built-in Keras methods.
Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
# The model outputs logits, so the loss metrics also need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(
    from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(
    from_logits=True)
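The loss function above is created with from_logits=True because the final Dense(10) layer has no softmax activation, so the model returns raw scores (logits). As a rough illustration of what that setting means for a single example, here is a plain-Python sketch; the helper names softmax, sparse_ce_from_logits, and sparse_ce_from_probs are hypothetical and the input values are made up, so treat it as a worked formula and not as the Keras implementation.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_logits(label, logits):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def sparse_ce_from_probs(label, probs):
    """Cross-entropy when the inputs are already probabilities."""
    return -math.log(probs[label])


logits = [2.0, 0.5, -1.0]  # made-up scores for three classes
label = 0                  # integer ("sparse") class label

loss_a = sparse_ce_from_logits(label, logits)
loss_b = sparse_ce_from_probs(label, softmax(logits))
print(loss_a, loss_b)  # the two routes agree up to rounding
```

Feeding logits to a function that expects probabilities (or the reverse) produces a different number, which is why the setting has to match what the model actually outputs.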
Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:
@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
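The metric objects updated in train_step and test_step are stateful: update_state accumulates over batches, result() reads the running value, and the state must be cleared between epochs. The sketch below mimics that life cycle in plain Python; RunningAccuracy is a hypothetical class written for this illustration (it compares predicted class ids, not logits) and is not the Keras metric.

```python
class RunningAccuracy:
    """Hypothetical accumulate/read/reset metric, loosely modeled on Keras metrics."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, labels, predictions):
        # Accumulate counts across batches instead of storing per-batch values.
        for y, p in zip(labels, predictions):
            self.correct += int(y == p)
            self.total += 1

    def result(self):
        return self.correct / self.total if self.total else 0.0

    def reset_state(self):
        self.correct = 0
        self.total = 0


metric = RunningAccuracy()
metric.update_state([1, 0, 1], [1, 1, 1])  # first batch: 2 of 3 correct
metric.update_state([0], [0])              # second batch: 1 of 1 correct
epoch_acc = metric.result()                # accuracy over both batches
metric.reset_state()                       # start the next epoch from scratch
print(epoch_acc)
```

Forgetting the reset would make each epoch's reported value an average over all epochs so far, which would blur the signal an early stopping rule depends on.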
Next, write a custom training loop in which you implement the early stopping rule manually.
The example below shows how to stop training when the validation loss has not improved for a certain number of epochs:
epochs = 100
patience = 5
wait = 0
# Lower validation loss is better, so start from infinity.
best = float('inf')

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))
  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
  if wait >= patience:
    break
Start of epoch 0
Training loss at step 0: 2.3073
Seen so far: 128 samples
Training loss at step 200: 0.2164
Seen so far: 25728 samples
Training loss at step 400: 0.2186
Seen so far: 51328 samples
Training acc over epoch: 0.9321
Validation acc: 0.9644
Time taken: 1.73s

Start of epoch 1
Training loss at step 0: 0.0733
Seen so far: 128 samples
Training loss at step 200: 0.1581
Seen so far: 25728 samples
Training loss at step 400: 0.1625
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9681
Time taken: 1.23s

Start of epoch 2
Training loss at step 0: 0.0501
Seen so far: 128 samples
Training loss at step 200: 0.1389
Seen so far: 25728 samples
Training loss at step 400: 0.1495
Seen so far: 51328 samples
Training acc over epoch: 0.9779
Validation acc: 0.9703
Time taken: 1.17s

Start of epoch 3
Training loss at step 0: 0.0513
Seen so far: 128 samples
Training loss at step 200: 0.0638
Seen so far: 25728 samples
Training loss at step 400: 0.0930
Seen so far: 51328 samples
Training acc over epoch: 0.9830
Validation acc: 0.9719
Time taken: 1.20s

Start of epoch 4
Training loss at step 0: 0.0251
Seen so far: 128 samples
Training loss at step 200: 0.0482
Seen so far: 25728 samples
Training loss at step 400: 0.0872
Seen so far: 51328 samples
Training acc over epoch: 0.9849
Validation acc: 0.9672
Time taken: 1.18s

Start of epoch 5
Training loss at step 0: 0.0417
Seen so far: 128 samples
Training loss at step 200: 0.0302
Seen so far: 25728 samples
Training loss at step 400: 0.0362
Seen so far: 51328 samples
Training acc over epoch: 0.9878
Validation acc: 0.9703
Time taken: 1.21s
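The patience rule from the loop can also be factored into a small helper that is easy to test on its own. This is a minimal sketch, assuming a quantity where lower is better (such as validation loss); the EarlyStopper class, its min_delta parameter, and the epochs_run function are hypothetical additions for illustration, and the loss values are made up.

```python
class EarlyStopper:
    """Hypothetical helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best = float("inf")    # lower is better, so start from infinity
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # Improvement: remember it and restart the patience counter.
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def epochs_run(val_losses, patience):
    """Return how many epochs run before the rule triggers (or all of them)."""
    stopper = EarlyStopper(patience)
    for epoch, loss in enumerate(val_losses, start=1):
        if stopper.should_stop(loss):
            return epoch
    return len(val_losses)


# Made-up validation losses: the best value (0.24) occurs at epoch 5,
# followed by three epochs without improvement.
losses = [0.30, 0.25, 0.26, 0.27, 0.24, 0.25, 0.26, 0.27, 0.20, 0.10]
stopped_after = epochs_run(losses, patience=3)
print(stopped_after)
```

In this made-up sequence the later values 0.20 and 0.10 are never reached, which shows the trade-off behind the patience setting: a small value saves time but can stop before a late improvement.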
Next steps
- Learn more about the built-in Keras early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about training and evaluation with the built-in Keras methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.