Migrate early stopping


This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
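The stopping rule itself is framework-independent. As a minimal plain-Python sketch (the helper name and the loss values in the usage line are made up for illustration):

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True when the last `patience` validation losses failed to
    improve on the best loss seen before them by more than `min_delta`."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return all(l >= best_before - min_delta for l in val_losses[-patience:])

# The loss plateaus after the third epoch, so training should stop.
print(should_stop([1.0, 0.6, 0.5, 0.51, 0.52, 0.50]))  # → True
# The loss is still improving, so training continues.
print(should_stop([1.0, 0.6, 0.5, 0.4]))  # → False
```

This is the rule that both the TensorFlow 1 hook and the Keras callbacks below implement in their own APIs.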

In TensorFlow 2, there are three ways to implement early stopping:

- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

First, define functions for MNIST dataset loading and preprocessing, and the model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to the hook's should_stop_fn parameter. Training stops as soon as should_stop_fn returns True.

The example below demonstrates how to implement an early stopping technique that limits training to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp9nqn356r
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp9nqn356r', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp9nqn356r/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3256147, step = 0
INFO:tensorflow:global_step/sec: 169.311
INFO:tensorflow:loss = 1.4084164, step = 100 (0.592 sec)
INFO:tensorflow:global_step/sec: 297.925
INFO:tensorflow:loss = 0.8429121, step = 200 (0.336 sec)
INFO:tensorflow:global_step/sec: 332.787
INFO:tensorflow:loss = 0.7294538, step = 300 (0.301 sec)
INFO:tensorflow:global_step/sec: 329.617
INFO:tensorflow:loss = 0.6392085, step = 400 (0.303 sec)
INFO:tensorflow:global_step/sec: 348.077
INFO:tensorflow:loss = 0.52678365, step = 500 (0.287 sec)
INFO:tensorflow:global_step/sec: 572.597
INFO:tensorflow:loss = 0.47424147, step = 600 (0.175 sec)
INFO:tensorflow:global_step/sec: 571.817
INFO:tensorflow:loss = 0.39473337, step = 700 (0.175 sec)
INFO:tensorflow:global_step/sec: 575.548
INFO:tensorflow:loss = 0.5178499, step = 800 (0.174 sec)
INFO:tensorflow:global_step/sec: 584.312
INFO:tensorflow:loss = 0.380741, step = 900 (0.171 sec)
INFO:tensorflow:global_step/sec: 514.355
INFO:tensorflow:loss = 0.4594915, step = 1000 (0.195 sec)
INFO:tensorflow:global_step/sec: 575.808
INFO:tensorflow:loss = 0.4344558, step = 1100 (0.173 sec)
INFO:tensorflow:global_step/sec: 573.102
INFO:tensorflow:loss = 0.39675385, step = 1200 (0.175 sec)
INFO:tensorflow:global_step/sec: 578.706
INFO:tensorflow:loss = 0.4755693, step = 1300 (0.173 sec)
INFO:tensorflow:global_step/sec: 574.495
INFO:tensorflow:loss = 0.3039328, step = 1400 (0.174 sec)
INFO:tensorflow:global_step/sec: 496.976
INFO:tensorflow:loss = 0.30139494, step = 1500 (0.201 sec)
INFO:tensorflow:global_step/sec: 551.468
INFO:tensorflow:loss = 0.4051479, step = 1600 (0.182 sec)
INFO:tensorflow:global_step/sec: 584.976
INFO:tensorflow:loss = 0.41139278, step = 1700 (0.171 sec)
INFO:tensorflow:global_step/sec: 571.57
INFO:tensorflow:loss = 0.3128811, step = 1800 (0.175 sec)
INFO:tensorflow:global_step/sec: 531.12
INFO:tensorflow:loss = 0.53163254, step = 1900 (0.188 sec)
INFO:tensorflow:global_step/sec: 579.244
INFO:tensorflow:loss = 0.20764115, step = 2000 (0.173 sec)
INFO:tensorflow:global_step/sec: 536.316
INFO:tensorflow:loss = 0.28562748, step = 2100 (0.186 sec)
INFO:tensorflow:global_step/sec: 574.273
INFO:tensorflow:loss = 0.3138504, step = 2200 (0.174 sec)
INFO:tensorflow:global_step/sec: 572.667
INFO:tensorflow:loss = 0.33238938, step = 2300 (0.174 sec)
INFO:tensorflow:global_step/sec: 522.621
INFO:tensorflow:loss = 0.24549136, step = 2400 (0.191 sec)
INFO:tensorflow:global_step/sec: 574.483
INFO:tensorflow:loss = 0.21948919, step = 2500 (0.174 sec)
INFO:tensorflow:global_step/sec: 568.095
INFO:tensorflow:loss = 0.15908587, step = 2600 (0.177 sec)
INFO:tensorflow:global_step/sec: 538.559
INFO:tensorflow:loss = 0.3201595, step = 2700 (0.185 sec)
INFO:tensorflow:global_step/sec: 578.767
INFO:tensorflow:loss = 0.4630896, step = 2800 (0.173 sec)
INFO:tensorflow:global_step/sec: 507.109
INFO:tensorflow:loss = 0.25870368, step = 2900 (0.197 sec)
INFO:tensorflow:global_step/sec: 572.554
INFO:tensorflow:loss = 0.32750684, step = 3000 (0.174 sec)
INFO:tensorflow:global_step/sec: 533.407
INFO:tensorflow:loss = 0.21683025, step = 3100 (0.188 sec)
INFO:tensorflow:global_step/sec: 504.484
INFO:tensorflow:loss = 0.42377055, step = 3200 (0.198 sec)
INFO:tensorflow:global_step/sec: 495.595
INFO:tensorflow:loss = 0.3331779, step = 3300 (0.202 sec)
INFO:tensorflow:global_step/sec: 527.809
INFO:tensorflow:loss = 0.25893703, step = 3400 (0.190 sec)
INFO:tensorflow:global_step/sec: 562.23
INFO:tensorflow:loss = 0.21271876, step = 3500 (0.177 sec)
INFO:tensorflow:global_step/sec: 572.326
INFO:tensorflow:loss = 0.21785083, step = 3600 (0.175 sec)
INFO:tensorflow:global_step/sec: 536.469
INFO:tensorflow:loss = 0.28410798, step = 3700 (0.187 sec)
INFO:tensorflow:global_step/sec: 464.124
INFO:tensorflow:loss = 0.38055325, step = 3800 (0.215 sec)
INFO:tensorflow:global_step/sec: 569.71
INFO:tensorflow:loss = 0.22856858, step = 3900 (0.176 sec)
INFO:tensorflow:global_step/sec: 572.71
INFO:tensorflow:loss = 0.2895166, step = 4000 (0.174 sec)
INFO:tensorflow:global_step/sec: 574.271
INFO:tensorflow:loss = 0.2146263, step = 4100 (0.174 sec)
INFO:tensorflow:global_step/sec: 570.225
INFO:tensorflow:loss = 0.2812382, step = 4200 (0.176 sec)
INFO:tensorflow:global_step/sec: 518.09
INFO:tensorflow:loss = 0.2987792, step = 4300 (0.193 sec)
INFO:tensorflow:global_step/sec: 574.251
INFO:tensorflow:loss = 0.32051677, step = 4400 (0.174 sec)
INFO:tensorflow:global_step/sec: 572.393
INFO:tensorflow:loss = 0.26353228, step = 4500 (0.175 sec)
INFO:tensorflow:global_step/sec: 574.802
INFO:tensorflow:loss = 0.31429395, step = 4600 (0.174 sec)
INFO:tensorflow:global_step/sec: 523.816
INFO:tensorflow:loss = 0.14255509, step = 4700 (0.191 sec)
INFO:tensorflow:global_step/sec: 563.154
INFO:tensorflow:loss = 0.2682479, step = 4800 (0.178 sec)
INFO:tensorflow:global_step/sec: 528.399
INFO:tensorflow:loss = 0.37475693, step = 4900 (0.189 sec)
INFO:tensorflow:global_step/sec: 565.211
INFO:tensorflow:loss = 0.27868846, step = 5000 (0.177 sec)
INFO:tensorflow:global_step/sec: 565.509
INFO:tensorflow:loss = 0.33586797, step = 5100 (0.177 sec)
INFO:tensorflow:global_step/sec: 518.517
INFO:tensorflow:loss = 0.21364841, step = 5200 (0.193 sec)
INFO:tensorflow:global_step/sec: 567.92
INFO:tensorflow:loss = 0.23852614, step = 5300 (0.176 sec)
INFO:tensorflow:global_step/sec: 575.316
INFO:tensorflow:loss = 0.16083148, step = 5400 (0.174 sec)
INFO:tensorflow:global_step/sec: 582.596
INFO:tensorflow:loss = 0.2215988, step = 5500 (0.172 sec)
INFO:tensorflow:global_step/sec: 580.536
INFO:tensorflow:loss = 0.17718056, step = 5600 (0.172 sec)
INFO:tensorflow:global_step/sec: 488.131
INFO:tensorflow:loss = 0.16356084, step = 5700 (0.205 sec)
INFO:tensorflow:global_step/sec: 548.261
INFO:tensorflow:loss = 0.2838351, step = 5800 (0.183 sec)
INFO:tensorflow:global_step/sec: 562.286
INFO:tensorflow:loss = 0.2158921, step = 5900 (0.178 sec)
INFO:tensorflow:global_step/sec: 562.004
INFO:tensorflow:loss = 0.2538747, step = 6000 (0.178 sec)
INFO:tensorflow:global_step/sec: 517.104
INFO:tensorflow:loss = 0.18471983, step = 6100 (0.193 sec)
INFO:tensorflow:global_step/sec: 584.601
INFO:tensorflow:loss = 0.25000086, step = 6200 (0.171 sec)
INFO:tensorflow:global_step/sec: 570.692
INFO:tensorflow:loss = 0.2516309, step = 6300 (0.175 sec)
INFO:tensorflow:global_step/sec: 569.891
INFO:tensorflow:loss = 0.28382745, step = 6400 (0.176 sec)
INFO:tensorflow:global_step/sec: 570.507
INFO:tensorflow:loss = 0.26250225, step = 6500 (0.175 sec)
INFO:tensorflow:global_step/sec: 525.937
INFO:tensorflow:loss = 0.19535686, step = 6600 (0.190 sec)
INFO:tensorflow:global_step/sec: 546.51
INFO:tensorflow:loss = 0.2414819, step = 6700 (0.183 sec)
INFO:tensorflow:global_step/sec: 547.428
INFO:tensorflow:loss = 0.41377378, step = 6800 (0.183 sec)
INFO:tensorflow:global_step/sec: 582.703
INFO:tensorflow:loss = 0.13949364, step = 6900 (0.172 sec)
INFO:tensorflow:global_step/sec: 575.777
INFO:tensorflow:loss = 0.33384934, step = 7000 (0.174 sec)
INFO:tensorflow:global_step/sec: 522.722
INFO:tensorflow:loss = 0.1819811, step = 7100 (0.191 sec)
INFO:tensorflow:global_step/sec: 575.838
INFO:tensorflow:loss = 0.2469174, step = 7200 (0.174 sec)
INFO:tensorflow:global_step/sec: 557.654
INFO:tensorflow:loss = 0.18281877, step = 7300 (0.179 sec)
INFO:tensorflow:global_step/sec: 568.367
INFO:tensorflow:loss = 0.26482195, step = 7400 (0.176 sec)
INFO:tensorflow:global_step/sec: 572.098
INFO:tensorflow:loss = 0.13162345, step = 7500 (0.175 sec)
INFO:tensorflow:global_step/sec: 484.593
INFO:tensorflow:loss = 0.23782606, step = 7600 (0.207 sec)
INFO:tensorflow:global_step/sec: 542.566
INFO:tensorflow:loss = 0.25690144, step = 7700 (0.184 sec)
INFO:tensorflow:Requesting early stopping at global step 7787
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7788...
INFO:tensorflow:Saving checkpoints for 7788 into /tmpfs/tmp/tmp9nqn356r/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7788...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-31T00:13:15
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9nqn356r/model.ckpt-7788
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.99282s
INFO:tensorflow:Finished evaluation at 2022-08-31-00:13:16
INFO:tensorflow:Saving dict for global step 7788: global_step = 7788, loss = 0.2119798
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7788: /tmpfs/tmp/tmp9nqn356r/model.ckpt-7788
INFO:tensorflow:Loss for final step: 0.30151576.
({'loss': 0.2119798, 'global_step': 7788}, [])

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback tf.keras.callbacks.EarlyStopping to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check out Training and evaluation with the built-in methods or the API docs for more details.)

Below is an example of an early stopping callback that monitors the loss and stops training after the number of epochs that show no improvement reaches 3 (patience):

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 5ms/step - loss: 0.2306 - sparse_categorical_accuracy: 0.9300 - val_loss: 0.1288 - val_sparse_categorical_accuracy: 0.9587
Epoch 2/100
469/469 [==============================] - 1s 2ms/step - loss: 0.1021 - sparse_categorical_accuracy: 0.9691 - val_loss: 0.1058 - val_sparse_categorical_accuracy: 0.9670
Epoch 3/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0704 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0982 - val_sparse_categorical_accuracy: 0.9701
Epoch 4/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0520 - sparse_categorical_accuracy: 0.9833 - val_loss: 0.1022 - val_sparse_categorical_accuracy: 0.9707
Epoch 5/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0414 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1099 - val_sparse_categorical_accuracy: 0.9708
Epoch 6/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0373 - sparse_categorical_accuracy: 0.9876 - val_loss: 0.0995 - val_sparse_categorical_accuracy: 0.9733
Epoch 7/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0344 - sparse_categorical_accuracy: 0.9882 - val_loss: 0.1064 - val_sparse_categorical_accuracy: 0.9739
Epoch 8/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0302 - sparse_categorical_accuracy: 0.9896 - val_loss: 0.1270 - val_sparse_categorical_accuracy: 0.9709
Epoch 9/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9898 - val_loss: 0.1401 - val_sparse_categorical_accuracy: 0.9706
Epoch 10/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9904 - val_loss: 0.1186 - val_sparse_categorical_accuracy: 0.9763
Epoch 11/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0232 - sparse_categorical_accuracy: 0.9925 - val_loss: 0.1362 - val_sparse_categorical_accuracy: 0.9744
Epoch 12/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0252 - sparse_categorical_accuracy: 0.9916 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9742
Epoch 13/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1386 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0208 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1421 - val_sparse_categorical_accuracy: 0.9752
Epoch 15/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0187 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1520 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1645 - val_sparse_categorical_accuracy: 0.9746
Epoch 17/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1617 - val_sparse_categorical_accuracy: 0.9750
Epoch 18/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.1993 - val_sparse_categorical_accuracy: 0.9721
Epoch 19/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.1581 - val_sparse_categorical_accuracy: 0.9778
Epoch 20/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0154 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9748
Epoch 21/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.2079 - val_sparse_categorical_accuracy: 0.9751
Epoch 22/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0164 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1835 - val_sparse_categorical_accuracy: 0.9753
Epoch 23/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0148 - sparse_categorical_accuracy: 0.9954 - val_loss: 0.2001 - val_sparse_categorical_accuracy: 0.9776
Epoch 24/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0162 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.2203 - val_sparse_categorical_accuracy: 0.9731
Epoch 25/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0148 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1891 - val_sparse_categorical_accuracy: 0.9771
Epoch 26/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1966 - val_sparse_categorical_accuracy: 0.9765
Epoch 27/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2041 - val_sparse_categorical_accuracy: 0.9764
Epoch 28/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.2245 - val_sparse_categorical_accuracy: 0.9765
Epoch 29/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0170 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.2325 - val_sparse_categorical_accuracy: 0.9759
29
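In practice you will usually monitor the validation loss rather than the training loss, and roll back to the best weights when training stops. The sketch below shows this variant of tf.keras.callbacks.EarlyStopping on a tiny synthetic dataset (the random data and the small model are stand-ins, not part of this notebook):

```python
import numpy as np
import tensorflow as tf

# Tiny random stand-in for MNIST so the snippet runs on its own.
x = np.random.rand(256, 28, 28).astype('float32')
y = np.random.randint(0, 10, size=(256,))

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10)])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Stop when `val_loss` has not improved for 2 consecutive epochs,
# then restore the weights from the best epoch.
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=2, restore_best_weights=True)

history = model.fit(x, y, epochs=50, validation_split=0.25,
                    callbacks=[callback], verbose=0)
```

With random labels the validation loss stops improving almost immediately, so fit typically returns after only a few epochs instead of 50.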

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2266 - val_sparse_categorical_accuracy: 0.9770
Epoch 2/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0113 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2376 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0108 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2351 - val_sparse_categorical_accuracy: 0.9743
Epoch 4/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2160 - val_sparse_categorical_accuracy: 0.9764
Epoch 5/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9728
Epoch 6/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0125 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2290 - val_sparse_categorical_accuracy: 0.9765
Epoch 7/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0186 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2652 - val_sparse_categorical_accuracy: 0.9745
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0175 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.2816 - val_sparse_categorical_accuracy: 0.9755
Epoch 9/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0107 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.3001 - val_sparse_categorical_accuracy: 0.9730
Epoch 10/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0117 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.2864 - val_sparse_categorical_accuracy: 0.9742
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0092 - sparse_categorical_accuracy: 0.9977 - val_loss: 0.2702 - val_sparse_categorical_accuracy: 0.9745
Epoch 12/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0100 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2953 - val_sparse_categorical_accuracy: 0.9746
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.3248 - val_sparse_categorical_accuracy: 0.9709
Epoch 14/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0171 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2778 - val_sparse_categorical_accuracy: 0.9773
Epoch 15/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0116 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2742 - val_sparse_categorical_accuracy: 0.9774
Epoch 16/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0063 - sparse_categorical_accuracy: 0.9982 - val_loss: 0.2808 - val_sparse_categorical_accuracy: 0.9770
Epoch 17/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0114 - sparse_categorical_accuracy: 0.9976 - val_loss: 0.2934 - val_sparse_categorical_accuracy: 0.9778
Epoch 18/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0094 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.3028 - val_sparse_categorical_accuracy: 0.9772
Epoch 19/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0091 - sparse_categorical_accuracy: 0.9982 - val_loss: 0.3114 - val_sparse_categorical_accuracy: 0.9763
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0162 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3448 - val_sparse_categorical_accuracy: 0.9758
Epoch 21/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.3130 - val_sparse_categorical_accuracy: 0.9766
Epoch 22/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0094 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.3528 - val_sparse_categorical_accuracy: 0.9745
Epoch 23/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0100 - sparse_categorical_accuracy: 0.9977 - val_loss: 0.3339 - val_sparse_categorical_accuracy: 0.9772
Epoch 24/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9977 - val_loss: 0.3521 - val_sparse_categorical_accuracy: 0.9744
Epoch 25/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0117 - sparse_categorical_accuracy: 0.9976 - val_loss: 0.3300 - val_sparse_categorical_accuracy: 0.9778
Epoch 26/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0137 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.3613 - val_sparse_categorical_accuracy: 0.9760
26
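
The callback above stops on wall-clock time, but the same self.model.stop_training mechanism works for metric-based rules. Below is a minimal sketch (the class name ValLossPatience is hypothetical) that stops once val_loss has gone patience epochs without improving, similar in spirit to the built-in tf.keras.callbacks.EarlyStopping:

```python
import tensorflow as tf

class ValLossPatience(tf.keras.callbacks.Callback):
  """Stop training when `val_loss` has not improved for `patience` epochs."""

  def __init__(self, patience=3):
    super().__init__()
    self.patience = patience

  def on_train_begin(self, logs=None):
    # Reset the bookkeeping at the start of every `fit` call.
    self.best = float('inf')
    self.wait = 0

  def on_epoch_end(self, epoch, logs=None):
    val_loss = (logs or {}).get('val_loss')
    if val_loss is None:  # No validation data was passed to `fit`.
      return
    if val_loss < self.best:
      self.best = val_loss
      self.wait = 0
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.model.stop_training = True
```

As with LimitTrainingTime, you would pass an instance of this class in the callbacks list of Model.fit.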

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by defining another simple model, an optimizer, a loss function, and metrics with the Keras APIs:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()

Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
      logits = model(x, training=True)
      loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss does not improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.3916
Seen so far: 128 samples
Training loss at step 200: 0.2400
Seen so far: 25728 samples
Training loss at step 400: 0.2299
Seen so far: 51328 samples
Training acc over epoch: 0.9309
Validation acc: 0.9617
Time taken: 1.64s

Start of epoch 1
Training loss at step 0: 0.0970
Seen so far: 128 samples
Training loss at step 200: 0.1743
Seen so far: 25728 samples
Training loss at step 400: 0.1372
Seen so far: 51328 samples
Training acc over epoch: 0.9693
Validation acc: 0.9685
Time taken: 1.05s

Start of epoch 2
Training loss at step 0: 0.0772
Seen so far: 128 samples
Training loss at step 200: 0.1270
Seen so far: 25728 samples
Training loss at step 400: 0.0727
Seen so far: 51328 samples
Training acc over epoch: 0.9783
Validation acc: 0.9702
Time taken: 1.01s

Start of epoch 3
Training loss at step 0: 0.0297
Seen so far: 128 samples
Training loss at step 200: 0.0619
Seen so far: 25728 samples
Training loss at step 400: 0.0624
Seen so far: 51328 samples
Training acc over epoch: 0.9836
Validation acc: 0.9720
Time taken: 0.99s

Start of epoch 4
Training loss at step 0: 0.0245
Seen so far: 128 samples
Training loss at step 200: 0.0324
Seen so far: 25728 samples
Training loss at step 400: 0.0662
Seen so far: 51328 samples
Training acc over epoch: 0.9867
Validation acc: 0.9722
Time taken: 1.02s

Start of epoch 5
Training loss at step 0: 0.0334
Seen so far: 128 samples
Training loss at step 200: 0.0492
Seen so far: 25728 samples
Training loss at step 400: 0.0469
Seen so far: 51328 samples
Training acc over epoch: 0.9875
Validation acc: 0.9714
Time taken: 1.00s
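
The patience bookkeeping can also be factored out of the loop into a small helper, which keeps the training loop readable. This is a minimal sketch (the EarlyStopper class name is hypothetical) of the rule described above — reset the counter on improvement, stop after patience epochs without one — and it is independent of TensorFlow:

```python
class EarlyStopper:
  """Track the best validation loss seen so far and signal when
  `patience` consecutive epochs have passed without improvement."""

  def __init__(self, patience=5):
    self.patience = patience
    self.best = float('inf')
    self.wait = 0

  def should_stop(self, val_loss):
    if val_loss < self.best:
      # Improvement: remember it and reset the counter.
      self.best = val_loss
      self.wait = 0
    else:
      self.wait += 1
    return self.wait >= self.patience
```

Inside the custom loop, the early stopping block then reduces to a single check such as `if stopper.should_stop(float(val_loss)): break`.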

Next steps