Migrate early stopping

This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when a chosen condition is met, for example, when the validation loss reaches a certain threshold.
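As a rough illustration of the threshold rule mentioned above, the sketch below scans a sequence of per-epoch validation losses and reports the first epoch at which training would stop. The function name first_epoch_at_or_below and the loss values are made up for this example; they are not part of TensorFlow.

```python
from typing import Iterable, Optional


def first_epoch_at_or_below(val_losses: Iterable[float],
                            threshold: float) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss <= threshold:
            # The validation loss reached the threshold: stop here.
            return epoch
    # The threshold was never reached, so training would run to completion.
    return None


# Made-up validation losses, one per epoch (illustrative only).
simulated_val_losses = [0.90, 0.55, 0.32, 0.21, 0.19]
stop_epoch = first_epoch_at_or_below(simulated_val_losses, threshold=0.25)
print(stop_epoch)  # 4
```

In a real training loop the same check would run once per epoch, right after validation, rather than over a precomputed list.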

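In TensorFlow 2, there are three ways to implement early stopping: pass the built-in Keras callback tf.keras.callbacks.EarlyStopping to Model.fit, pass a custom callback to Model.fit, or write a custom early stopping rule in a custom training loop.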

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for loading and preprocessing the MNIST dataset, and a model definition to be used with tf.estimator.Estimator.

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook as its should_stop_fn parameter. Training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping rule that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpocmc6_bo
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpocmc6_bo', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3545606, step = 0
INFO:tensorflow:global_step/sec: 94.5711
INFO:tensorflow:loss = 1.3383636, step = 100 (1.060 sec)
INFO:tensorflow:global_step/sec: 158.428
INFO:tensorflow:loss = 0.7937969, step = 200 (0.631 sec)
INFO:tensorflow:global_step/sec: 287.334
INFO:tensorflow:loss = 0.69060934, step = 300 (0.349 sec)
INFO:tensorflow:global_step/sec: 286.658
INFO:tensorflow:loss = 0.59314424, step = 400 (0.349 sec)
INFO:tensorflow:global_step/sec: 311.591
INFO:tensorflow:loss = 0.50495726, step = 500 (0.320 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 536 vs previous value: 536. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 538.395
INFO:tensorflow:loss = 0.43083754, step = 600 (0.186 sec)
INFO:tensorflow:global_step/sec: 503.72
INFO:tensorflow:loss = 0.381118, step = 700 (0.198 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 715 vs previous value: 715. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 482.019
INFO:tensorflow:loss = 0.49349022, step = 800 (0.207 sec)
INFO:tensorflow:global_step/sec: 508.316
INFO:tensorflow:loss = 0.38730466, step = 900 (0.199 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 987 vs previous value: 987. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 452.89
INFO:tensorflow:loss = 0.44916487, step = 1000 (0.219 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1042 vs previous value: 1042. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 519.401
INFO:tensorflow:loss = 0.44320562, step = 1100 (0.192 sec)
INFO:tensorflow:global_step/sec: 510.25
INFO:tensorflow:loss = 0.3758085, step = 1200 (0.196 sec)
INFO:tensorflow:global_step/sec: 518.649
INFO:tensorflow:loss = 0.46760654, step = 1300 (0.193 sec)
INFO:tensorflow:global_step/sec: 474.056
INFO:tensorflow:loss = 0.29544568, step = 1400 (0.211 sec)
INFO:tensorflow:global_step/sec: 461.406
INFO:tensorflow:loss = 0.28616875, step = 1500 (0.217 sec)
INFO:tensorflow:global_step/sec: 486.2
INFO:tensorflow:loss = 0.4114887, step = 1600 (0.206 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1678 vs previous value: 1678. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 507.701
INFO:tensorflow:loss = 0.35298553, step = 1700 (0.197 sec)
INFO:tensorflow:global_step/sec: 490.541
INFO:tensorflow:loss = 0.3363277, step = 1800 (0.204 sec)
INFO:tensorflow:global_step/sec: 460.083
INFO:tensorflow:loss = 0.50634325, step = 1900 (0.217 sec)
INFO:tensorflow:global_step/sec: 436.782
INFO:tensorflow:loss = 0.2063987, step = 2000 (0.229 sec)
INFO:tensorflow:global_step/sec: 475.841
INFO:tensorflow:loss = 0.27246287, step = 2100 (0.210 sec)
INFO:tensorflow:global_step/sec: 483.322
INFO:tensorflow:loss = 0.31674564, step = 2200 (0.207 sec)
INFO:tensorflow:global_step/sec: 442.257
INFO:tensorflow:loss = 0.3334998, step = 2300 (0.226 sec)
INFO:tensorflow:global_step/sec: 476.38
INFO:tensorflow:loss = 0.2549953, step = 2400 (0.210 sec)
INFO:tensorflow:global_step/sec: 467.543
INFO:tensorflow:loss = 0.21111101, step = 2500 (0.214 sec)
INFO:tensorflow:global_step/sec: 497.051
INFO:tensorflow:loss = 0.15878338, step = 2600 (0.201 sec)
INFO:tensorflow:global_step/sec: 461.785
INFO:tensorflow:loss = 0.31587577, step = 2700 (0.219 sec)
INFO:tensorflow:global_step/sec: 493.743
INFO:tensorflow:loss = 0.47478187, step = 2800 (0.200 sec)
INFO:tensorflow:global_step/sec: 463.477
INFO:tensorflow:loss = 0.2499526, step = 2900 (0.216 sec)
INFO:tensorflow:global_step/sec: 538.27
INFO:tensorflow:loss = 0.34210858, step = 3000 (0.186 sec)
INFO:tensorflow:global_step/sec: 508.741
INFO:tensorflow:loss = 0.2128592, step = 3100 (0.197 sec)
INFO:tensorflow:global_step/sec: 519.319
INFO:tensorflow:loss = 0.40954083, step = 3200 (0.192 sec)
INFO:tensorflow:global_step/sec: 468.989
INFO:tensorflow:loss = 0.34270883, step = 3300 (0.213 sec)
INFO:tensorflow:global_step/sec: 479.856
INFO:tensorflow:loss = 0.26599607, step = 3400 (0.209 sec)
INFO:tensorflow:global_step/sec: 495.76
INFO:tensorflow:loss = 0.21713805, step = 3500 (0.201 sec)
INFO:tensorflow:global_step/sec: 440.282
INFO:tensorflow:loss = 0.22268976, step = 3600 (0.228 sec)
INFO:tensorflow:global_step/sec: 495.629
INFO:tensorflow:loss = 0.28974164, step = 3700 (0.201 sec)
INFO:tensorflow:global_step/sec: 468.695
INFO:tensorflow:loss = 0.37919793, step = 3800 (0.214 sec)
INFO:tensorflow:global_step/sec: 529.005
INFO:tensorflow:loss = 0.23738712, step = 3900 (0.189 sec)
INFO:tensorflow:global_step/sec: 494.809
INFO:tensorflow:loss = 0.29650036, step = 4000 (0.204 sec)
INFO:tensorflow:global_step/sec: 525.629
INFO:tensorflow:loss = 0.20826155, step = 4100 (0.188 sec)
INFO:tensorflow:global_step/sec: 509.573
INFO:tensorflow:loss = 0.26417816, step = 4200 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.845
INFO:tensorflow:loss = 0.31241363, step = 4300 (0.212 sec)
INFO:tensorflow:global_step/sec: 510.868
INFO:tensorflow:loss = 0.32773697, step = 4400 (0.195 sec)
INFO:tensorflow:global_step/sec: 492.967
INFO:tensorflow:loss = 0.28609803, step = 4500 (0.203 sec)
INFO:tensorflow:global_step/sec: 507.394
INFO:tensorflow:loss = 0.32142323, step = 4600 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.176
INFO:tensorflow:loss = 0.14882785, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 503.718
INFO:tensorflow:loss = 0.312344, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 497.659
INFO:tensorflow:loss = 0.37370217, step = 4900 (0.201 sec)
INFO:tensorflow:global_step/sec: 477.736
INFO:tensorflow:loss = 0.2663591, step = 5000 (0.209 sec)
INFO:tensorflow:global_step/sec: 496.559
INFO:tensorflow:loss = 0.34745598, step = 5100 (0.202 sec)
INFO:tensorflow:global_step/sec: 475.989
INFO:tensorflow:loss = 0.21809828, step = 5200 (0.210 sec)
INFO:tensorflow:global_step/sec: 474.464
INFO:tensorflow:loss = 0.2474105, step = 5300 (0.211 sec)
INFO:tensorflow:global_step/sec: 488.774
INFO:tensorflow:loss = 0.1611641, step = 5400 (0.204 sec)
INFO:tensorflow:global_step/sec: 504.942
INFO:tensorflow:loss = 0.2306528, step = 5500 (0.198 sec)
INFO:tensorflow:global_step/sec: 514.058
INFO:tensorflow:loss = 0.20716992, step = 5600 (0.195 sec)
INFO:tensorflow:global_step/sec: 458.899
INFO:tensorflow:loss = 0.16730343, step = 5700 (0.217 sec)
INFO:tensorflow:global_step/sec: 495.197
INFO:tensorflow:loss = 0.2906361, step = 5800 (0.202 sec)
INFO:tensorflow:global_step/sec: 482.244
INFO:tensorflow:loss = 0.24669808, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 484.946
INFO:tensorflow:loss = 0.26403594, step = 6000 (0.207 sec)
INFO:tensorflow:global_step/sec: 486.74
INFO:tensorflow:loss = 0.19804293, step = 6100 (0.206 sec)
INFO:tensorflow:global_step/sec: 436.727
INFO:tensorflow:loss = 0.25344175, step = 6200 (0.229 sec)
INFO:tensorflow:global_step/sec: 428.73
INFO:tensorflow:loss = 0.2430937, step = 6300 (0.232 sec)
INFO:tensorflow:global_step/sec: 449.706
INFO:tensorflow:loss = 0.2842306, step = 6400 (0.222 sec)
INFO:tensorflow:global_step/sec: 440.873
INFO:tensorflow:loss = 0.2641199, step = 6500 (0.227 sec)
INFO:tensorflow:global_step/sec: 424.092
INFO:tensorflow:loss = 0.19028814, step = 6600 (0.237 sec)
INFO:tensorflow:global_step/sec: 450.352
INFO:tensorflow:loss = 0.24667627, step = 6700 (0.221 sec)
INFO:tensorflow:global_step/sec: 462.774
INFO:tensorflow:loss = 0.40046322, step = 6800 (0.216 sec)
INFO:tensorflow:global_step/sec: 460.854
INFO:tensorflow:loss = 0.14105138, step = 6900 (0.217 sec)
INFO:tensorflow:Requesting early stopping at global step 6916
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6917...
INFO:tensorflow:Saving checkpoints for 6917 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6917...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:07:35
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.79520s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:07:36
INFO:tensorflow:Saving dict for global step 6917: global_step = 6917, loss = 0.227278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6917: /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Loss for final step: 0.13882703.
({'loss': 0.227278, 'global_step': 6917}, [])
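Because should_stop_fn takes no arguments, any state it needs has to be captured from the enclosing scope. The sketch below shows one way to build such a function with a closure, in plain Python so that it runs without TensorFlow. The factory make_no_decrease_stop_fn and the read_losses callable are hypothetical names introduced for this illustration; how the evaluation losses are actually collected is left to the caller.

```python
from typing import Callable, Sequence


def make_no_decrease_stop_fn(
    read_losses: Callable[[], Sequence[float]],
    max_evals_without_decrease: int,
) -> Callable[[], bool]:
    """Build a zero-argument stop function from a loss-reading callable.

    `read_losses` is a hypothetical helper supplied by the caller. It returns
    the evaluation losses recorded so far, oldest first.
    """

    def should_stop_fn() -> bool:
        losses = list(read_losses())
        if not losses:
            return False  # Nothing has been evaluated yet, so keep training.
        # Index of the first occurrence of the lowest loss seen so far.
        best_index = losses.index(min(losses))
        evals_since_best = len(losses) - 1 - best_index
        return evals_since_best >= max_evals_without_decrease

    return should_stop_fn


# Illustrative values only; a real run would append one loss per evaluation.
recorded_losses = [0.9, 0.6, 0.5]
should_stop_fn = make_no_decrease_stop_fn(lambda: recorded_losses, 2)
print(should_stop_fn())  # False: the latest evaluation is the best so far.

recorded_losses += [0.55, 0.52]
print(should_stop_fn())  # True: two evaluations without a new minimum.
```

A function built this way could be passed as should_stop_fn in place of the time-based rule, assuming the caller has a way to supply the recorded losses.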

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)

The following example uses an early stopping callback that monitors the training loss and stops training once 3 consecutive epochs show no improvement (patience=3):

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 5s 8ms/step - loss: 0.2371 - sparse_categorical_accuracy: 0.9293 - val_loss: 0.1334 - val_sparse_categorical_accuracy: 0.9611
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1028 - sparse_categorical_accuracy: 0.9686 - val_loss: 0.1062 - val_sparse_categorical_accuracy: 0.9667
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9783 - val_loss: 0.0993 - val_sparse_categorical_accuracy: 0.9707
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.1040 - val_sparse_categorical_accuracy: 0.9680
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9865 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9716
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.1167 - val_sparse_categorical_accuracy: 0.9691
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0321 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.1396 - val_sparse_categorical_accuracy: 0.9672
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1397 - val_sparse_categorical_accuracy: 0.9671
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1296 - val_sparse_categorical_accuracy: 0.9715
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9715
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0274 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9710
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0241 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9718
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0205 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1451 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1562 - val_sparse_categorical_accuracy: 0.9750
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9748
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1712 - val_sparse_categorical_accuracy: 0.9731
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1715 - val_sparse_categorical_accuracy: 0.9755
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9730
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1919 - val_sparse_categorical_accuracy: 0.9732
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1703 - val_sparse_categorical_accuracy: 0.9777
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1725 - val_sparse_categorical_accuracy: 0.9764
21
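To see why this run ended after 21 epochs, the patience rule can be replayed by hand on the losses printed above. The function below is a simplified stand-in written for this illustration (lower is better, min_delta defaults to 0); it is not the Keras implementation, and the name epochs_until_stop is made up.

```python
def epochs_until_stop(values, patience, min_delta=0.0):
    """Return how many epochs run before a patience rule stops training."""
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if value < best - min_delta:
            # A new best value resets the counter.
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    # The rule never triggered, so every epoch ran.
    return len(values)


# Training losses copied from the 21 epochs logged above.
train_loss = [
    0.2371, 0.1028, 0.0703, 0.0552, 0.0420, 0.0387, 0.0321,
    0.0285, 0.0263, 0.0250, 0.0274, 0.0241, 0.0205, 0.0196,
    0.0214, 0.0178, 0.0177, 0.0141, 0.0188, 0.0190, 0.0153,
]
# Validation losses from the first six epochs of the same log.
val_loss_first_epochs = [0.1334, 0.1062, 0.0993, 0.1040, 0.1033, 0.1167]

print(epochs_until_stop(train_loss, patience=3))             # 21
print(epochs_until_stop(val_loss_first_epochs, patience=3))  # 6
```

The best training loss (0.0141) occurs at epoch 18, and the next three epochs do not beat it, which matches the length of 21 reported above. Under the same simplified rule, monitoring the validation loss instead would have stopped this particular run after epoch 6, since the validation loss does not improve after epoch 3.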

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process stops once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

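  def on_train_begin(self, logs=None):
    # Record the wall-clock time at which training starts.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    # After every batch, request a stop once the time budget is exceeded.
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.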
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.1911 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1999 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9770
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.2279 - val_sparse_categorical_accuracy: 0.9753
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2272 - val_sparse_categorical_accuracy: 0.9755
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2352 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2421 - val_sparse_categorical_accuracy: 0.9734
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9785
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2472 - val_sparse_categorical_accuracy: 0.9752
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2166 - val_sparse_categorical_accuracy: 0.9768
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9781
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2310 - val_sparse_categorical_accuracy: 0.9777
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2617 - val_sparse_categorical_accuracy: 0.9781
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.3007 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3014 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2815 - val_sparse_categorical_accuracy: 0.9750
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9765
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9777
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0098 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.2594 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3008 - val_sparse_categorical_accuracy: 0.9755
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2662 - val_sparse_categorical_accuracy: 0.9765
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2587 - val_sparse_categorical_accuracy: 0.9797
22
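
The time-based rule in LimitTrainingTime does not depend on Keras, so it can be sketched and checked in plain Python. The snippet below is a minimal illustration under stated assumptions: TimeBudget and run_batches are hypothetical names that are not part of TensorFlow, and a fake clock stands in for time.time so the result is deterministic. It shows that, because the check happens only after a batch finishes, the run can overshoot the budget by up to one batch.

```python
import time


class TimeBudget:
    """Hypothetical helper mirroring the elapsed-time check in LimitTrainingTime."""

    def __init__(self, max_time_s, clock=time.time):
        self.max_time_s = max_time_s
        self.clock = clock  # Injectable so the rule can be tested without waiting.
        self.start_time = None

    def start(self):
        self.start_time = self.clock()

    def exceeded(self):
        return self.clock() - self.start_time > self.max_time_s


def run_batches(num_batches, seconds_per_batch, max_time_s):
    """Simulates training with a fake clock and returns the batches completed."""
    now = [0.0]
    budget = TimeBudget(max_time_s, clock=lambda: now[0])
    budget.start()
    done = 0
    for _ in range(num_batches):
        now[0] += seconds_per_batch  # Pretend one batch of work took this long.
        done += 1
        if budget.exceeded():  # Checked after each batch, like the callback.
            break
    return done


print(run_batches(num_batches=100, seconds_per_batch=1.0, max_time_s=30))  # 31
```

With one-second batches and a 30-second budget, the simulated run stops after batch 31, the first batch boundary at which the elapsed time is strictly greater than the limit.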

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you are not training and evaluating with the built-in Keras methods .

Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
# The model outputs logits, so the cross-entropy metrics need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
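
These metric objects are stateful: they accumulate values across batches until they are reset, which is why a custom loop has to read and reset them once per epoch. As a rough mental model only (RunningMean is a hypothetical stand-in, not how Keras implements its metrics), a mean-style streaming metric behaves like this:

```python
class RunningMean:
    """Hypothetical stand-in for a mean-style streaming metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate one batch of per-sample values.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over everything seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


metric = RunningMean()
metric.update_state([1, 0, 1, 1])  # Batch 1: 3 of 4 predictions correct.
metric.update_state([1, 1])        # Batch 2: 2 of 2 predictions correct.
epoch_accuracy = metric.result()   # 5 / 6 over the whole epoch.
metric.reset_state()               # Start the next epoch from zero.
```

Skipping the reset would blend the statistics of every epoch so far into one number, which would blur the per-epoch validation loss that an early stopping rule compares.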

Define the parameter update functions with tf.GradientTape and the @tf.function decorator for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss does not improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')  # Lowest validation loss seen so far.

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))
  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    # A lower loss is an improvement: remember it and reset the counter.
    best = val_loss
    wait = 0
  if wait >= patience:
    break
Start of epoch 0
Training loss at step 0: 2.3073
Seen so far: 128 samples
Training loss at step 200: 0.2164
Seen so far: 25728 samples
Training loss at step 400: 0.2186
Seen so far: 51328 samples
Training acc over epoch: 0.9321
Validation acc: 0.9644
Time taken: 1.73s

Start of epoch 1
Training loss at step 0: 0.0733
Seen so far: 128 samples
Training loss at step 200: 0.1581
Seen so far: 25728 samples
Training loss at step 400: 0.1625
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9681
Time taken: 1.23s

Start of epoch 2
Training loss at step 0: 0.0501
Seen so far: 128 samples
Training loss at step 200: 0.1389
Seen so far: 25728 samples
Training loss at step 400: 0.1495
Seen so far: 51328 samples
Training acc over epoch: 0.9779
Validation acc: 0.9703
Time taken: 1.17s

Start of epoch 3
Training loss at step 0: 0.0513
Seen so far: 128 samples
Training loss at step 200: 0.0638
Seen so far: 25728 samples
Training loss at step 400: 0.0930
Seen so far: 51328 samples
Training acc over epoch: 0.9830
Validation acc: 0.9719
Time taken: 1.20s

Start of epoch 4
Training loss at step 0: 0.0251
Seen so far: 128 samples
Training loss at step 200: 0.0482
Seen so far: 25728 samples
Training loss at step 400: 0.0872
Seen so far: 51328 samples
Training acc over epoch: 0.9849
Validation acc: 0.9672
Time taken: 1.18s

Start of epoch 5
Training loss at step 0: 0.0417
Seen so far: 128 samples
Training loss at step 200: 0.0302
Seen so far: 25728 samples
Training loss at step 400: 0.0362
Seen so far: 51328 samples
Training acc over epoch: 0.9878
Validation acc: 0.9703
Time taken: 1.21s
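
The rule described above (stop once the validation loss has not improved for a given number of epochs) can be isolated from the training code and checked on its own. The sketch below is only an illustration: PatienceTracker and epochs_run are hypothetical helpers that are not part of TensorFlow, and the loss values are made up to exercise the rule.

```python
import math


class PatienceTracker:
    """Tracks the best validation loss and signals when to stop.

    Hypothetical helper for illustration; it is not part of TensorFlow.
    """

    def __init__(self, patience):
        self.patience = patience
        self.best = math.inf  # Lower is better, so start from +infinity.
        self.wait = 0         # Epochs since the last improvement.

    def should_stop(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def epochs_run(val_losses, patience):
    """Returns how many epochs run before the rule stops training."""
    tracker = PatienceTracker(patience)
    for epoch, val_loss in enumerate(val_losses):
        if tracker.should_stop(val_loss):
            return epoch + 1
    return len(val_losses)


# Made-up validation losses: the best value so far is reached at epoch 2 (0-based).
losses = [0.30, 0.25, 0.20, 0.22, 0.21, 0.23, 0.24, 0.26, 0.19]
print(epochs_run(losses, patience=5))  # 8
```

With a patience of 5, training stops after the eighth epoch, five epochs after the best loss of 0.20, and therefore never reaches the better value in the ninth epoch. That is the trade-off the patience setting controls: a small value saves time but can stop before a late improvement.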

Next steps