This notebook demonstrates how to set up model training with early stopping. It starts with TensorFlow 1, using tf.estimator.Estimator and an early stopping hook, and then covers TensorFlow 2, using either the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when a condition is met, for example when the validation loss reaches a certain threshold.
In TensorFlow 2, there are three ways to implement early stopping:
- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
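The third option comes down to checking a stopping condition yourself after every epoch. The sketch below is framework-free and only illustrative: val_loss_fn is a hypothetical stand-in for "train one epoch with tf.GradientTape, then evaluate on the validation set", and the rule assumes you want to stop as soon as the validation loss reaches a chosen threshold.

```python
def train_with_threshold_stop(max_epochs, threshold, val_loss_fn):
    """Run up to max_epochs, stopping early once the validation loss
    reaches `threshold`.

    `val_loss_fn(epoch)` is a hypothetical stand-in for one epoch of
    training followed by a validation pass; it returns the validation loss.
    """
    history = []
    for epoch in range(max_epochs):
        val_loss = val_loss_fn(epoch)
        history.append(val_loss)
        # Custom early stopping rule: the threshold has been reached.
        if val_loss <= threshold:
            break
    return history


# Fake validation loss that halves every epoch: 1.0, 0.5, 0.25, ...
history = train_with_threshold_stop(
    max_epochs=100, threshold=0.1, val_loss_fn=lambda epoch: 0.5 ** epoch)
print(len(history))  # 5 epochs instead of 100
```

In a real loop, the same if/break check would sit after the evaluation step of each epoch; everything else about the loop stays unchanged.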
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions that load and preprocess the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:
def normalize_img(image, label):
    # Cast uint8 pixels to float32 and scale them to [0, 1].
    return tf.cast(image, tf.float32) / 255., label
def _input_fn():
    ds_train = tfds.load(
        name='mnist',
        split='train',
        shuffle_files=True,
        as_supervised=True)
    ds_train = ds_train.map(
        normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_train = ds_train.batch(128)
    # Up to 100 passes over the data; early stopping should end training sooner.
    ds_train = ds_train.repeat(100)
    return ds_train
def _eval_input_fn():
    ds_test = tfds.load(
        name='mnist',
        split='test',
        shuffle_files=True,
        as_supervised=True)
    ds_test = ds_test.map(
        normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_test = ds_test.batch(128)
    return ds_test
def _model_fn(features, labels, mode):
    # A small fully connected classifier: 784 inputs -> 128 (ReLU) -> 10 logits.
    flatten = tf1.layers.Flatten()(features)
    features = tf1.layers.Dense(128, 'relu')(flatten)
    logits = tf1.layers.Dense(10)(features)
    loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    optimizer = tf1.train.AdagradOptimizer(0.005)
    # Incrementing the global step lets the Estimator and its hooks track progress.
    train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
    return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments as the should_stop_fn parameter of make_early_stopping_hook, and then attach the resulting hook to training (in the example below, through the hooks argument of tf.estimator.TrainSpec). Training stops once should_stop_fn returns True.
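Because the hook calls should_stop_fn with no arguments, any state it needs (such as the start time) has to be captured by the function itself. A closure is one way to do that. The factory below is a hypothetical helper, not part of TensorFlow, and its injectable clock argument is an assumption added so the logic can be checked without waiting; it uses time.monotonic rather than time.time so that wall-clock adjustments do not affect the measured duration.

```python
import time


def make_time_limit_stop_fn(max_seconds, clock=time.monotonic):
    """Return a zero-argument callable usable as a `should_stop_fn`.

    The callable returns True once more than `max_seconds` have elapsed
    since this factory was called. `clock` is injectable for testing.
    """
    start = clock()

    def should_stop_fn():
        return clock() - start > max_seconds

    return should_stop_fn


# Deterministic check with a fake clock instead of real time.
fake_now = [0.0]
should_stop = make_time_limit_stop_fn(20, clock=lambda: fake_now[0])
print(should_stop())  # False: no time has passed yet
fake_now[0] = 20.5
print(should_stop())  # True: the 20-second budget is exceeded
```

With TensorFlow 1, you would then pass should_stop_fn=make_time_limit_stop_fn(20) to make_early_stopping_hook in place of the module-level function used below; that combination is a suggestion and has not been run here.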
The following example demonstrates an early stopping technique that limits training time to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
start_time = time.time()
max_train_seconds = 20
def should_stop_fn():
    # Stop once the wall-clock training budget is exhausted.
    return time.time() - start_time > max_train_seconds
early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,  # Call should_stop_fn every second.
    run_every_steps=None)
train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpocmc6_bo
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpocmc6_bo', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3545606, step = 0
INFO:tensorflow:global_step/sec: 94.5711
INFO:tensorflow:loss = 1.3383636, step = 100 (1.060 sec)
INFO:tensorflow:global_step/sec: 158.428
INFO:tensorflow:loss = 0.7937969, step = 200 (0.631 sec)
INFO:tensorflow:global_step/sec: 287.334
INFO:tensorflow:loss = 0.69060934, step = 300 (0.349 sec)
INFO:tensorflow:global_step/sec: 286.658
INFO:tensorflow:loss = 0.59314424, step = 400 (0.349 sec)
INFO:tensorflow:global_step/sec: 311.591
INFO:tensorflow:loss = 0.50495726, step = 500 (0.320 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 536 vs previous value: 536. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 538.395
INFO:tensorflow:loss = 0.43083754, step = 600 (0.186 sec)
INFO:tensorflow:global_step/sec: 503.72
INFO:tensorflow:loss = 0.381118, step = 700 (0.198 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 715 vs previous value: 715. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 482.019
INFO:tensorflow:loss = 0.49349022, step = 800 (0.207 sec)
INFO:tensorflow:global_step/sec: 508.316
INFO:tensorflow:loss = 0.38730466, step = 900 (0.199 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 987 vs previous value: 987. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 452.89
INFO:tensorflow:loss = 0.44916487, step = 1000 (0.219 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1042 vs previous value: 1042. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 519.401
INFO:tensorflow:loss = 0.44320562, step = 1100 (0.192 sec)
INFO:tensorflow:global_step/sec: 510.25
INFO:tensorflow:loss = 0.3758085, step = 1200 (0.196 sec)
INFO:tensorflow:global_step/sec: 518.649
INFO:tensorflow:loss = 0.46760654, step = 1300 (0.193 sec)
INFO:tensorflow:global_step/sec: 474.056
INFO:tensorflow:loss = 0.29544568, step = 1400 (0.211 sec)
INFO:tensorflow:global_step/sec: 461.406
INFO:tensorflow:loss = 0.28616875, step = 1500 (0.217 sec)
INFO:tensorflow:global_step/sec: 486.2
INFO:tensorflow:loss = 0.4114887, step = 1600 (0.206 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1678 vs previous value: 1678. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 507.701
INFO:tensorflow:loss = 0.35298553, step = 1700 (0.197 sec)
INFO:tensorflow:global_step/sec: 490.541
INFO:tensorflow:loss = 0.3363277, step = 1800 (0.204 sec)
INFO:tensorflow:global_step/sec: 460.083
INFO:tensorflow:loss = 0.50634325, step = 1900 (0.217 sec)
INFO:tensorflow:global_step/sec: 436.782
INFO:tensorflow:loss = 0.2063987, step = 2000 (0.229 sec)
INFO:tensorflow:global_step/sec: 475.841
INFO:tensorflow:loss = 0.27246287, step = 2100 (0.210 sec)
INFO:tensorflow:global_step/sec: 483.322
INFO:tensorflow:loss = 0.31674564, step = 2200 (0.207 sec)
INFO:tensorflow:global_step/sec: 442.257
INFO:tensorflow:loss = 0.3334998, step = 2300 (0.226 sec)
INFO:tensorflow:global_step/sec: 476.38
INFO:tensorflow:loss = 0.2549953, step = 2400 (0.210 sec)
INFO:tensorflow:global_step/sec: 467.543
INFO:tensorflow:loss = 0.21111101, step = 2500 (0.214 sec)
INFO:tensorflow:global_step/sec: 497.051
INFO:tensorflow:loss = 0.15878338, step = 2600 (0.201 sec)
INFO:tensorflow:global_step/sec: 461.785
INFO:tensorflow:loss = 0.31587577, step = 2700 (0.219 sec)
INFO:tensorflow:global_step/sec: 493.743
INFO:tensorflow:loss = 0.47478187, step = 2800 (0.200 sec)
INFO:tensorflow:global_step/sec: 463.477
INFO:tensorflow:loss = 0.2499526, step = 2900 (0.216 sec)
INFO:tensorflow:global_step/sec: 538.27
INFO:tensorflow:loss = 0.34210858, step = 3000 (0.186 sec)
INFO:tensorflow:global_step/sec: 508.741
INFO:tensorflow:loss = 0.2128592, step = 3100 (0.197 sec)
INFO:tensorflow:global_step/sec: 519.319
INFO:tensorflow:loss = 0.40954083, step = 3200 (0.192 sec)
INFO:tensorflow:global_step/sec: 468.989
INFO:tensorflow:loss = 0.34270883, step = 3300 (0.213 sec)
INFO:tensorflow:global_step/sec: 479.856
INFO:tensorflow:loss = 0.26599607, step = 3400 (0.209 sec)
INFO:tensorflow:global_step/sec: 495.76
INFO:tensorflow:loss = 0.21713805, step = 3500 (0.201 sec)
INFO:tensorflow:global_step/sec: 440.282
INFO:tensorflow:loss = 0.22268976, step = 3600 (0.228 sec)
INFO:tensorflow:global_step/sec: 495.629
INFO:tensorflow:loss = 0.28974164, step = 3700 (0.201 sec)
INFO:tensorflow:global_step/sec: 468.695
INFO:tensorflow:loss = 0.37919793, step = 3800 (0.214 sec)
INFO:tensorflow:global_step/sec: 529.005
INFO:tensorflow:loss = 0.23738712, step = 3900 (0.189 sec)
INFO:tensorflow:global_step/sec: 494.809
INFO:tensorflow:loss = 0.29650036, step = 4000 (0.204 sec)
INFO:tensorflow:global_step/sec: 525.629
INFO:tensorflow:loss = 0.20826155, step = 4100 (0.188 sec)
INFO:tensorflow:global_step/sec: 509.573
INFO:tensorflow:loss = 0.26417816, step = 4200 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.845
INFO:tensorflow:loss = 0.31241363, step = 4300 (0.212 sec)
INFO:tensorflow:global_step/sec: 510.868
INFO:tensorflow:loss = 0.32773697, step = 4400 (0.195 sec)
INFO:tensorflow:global_step/sec: 492.967
INFO:tensorflow:loss = 0.28609803, step = 4500 (0.203 sec)
INFO:tensorflow:global_step/sec: 507.394
INFO:tensorflow:loss = 0.32142323, step = 4600 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.176
INFO:tensorflow:loss = 0.14882785, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 503.718
INFO:tensorflow:loss = 0.312344, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 497.659
INFO:tensorflow:loss = 0.37370217, step = 4900 (0.201 sec)
INFO:tensorflow:global_step/sec: 477.736
INFO:tensorflow:loss = 0.2663591, step = 5000 (0.209 sec)
INFO:tensorflow:global_step/sec: 496.559
INFO:tensorflow:loss = 0.34745598, step = 5100 (0.202 sec)
INFO:tensorflow:global_step/sec: 475.989
INFO:tensorflow:loss = 0.21809828, step = 5200 (0.210 sec)
INFO:tensorflow:global_step/sec: 474.464
INFO:tensorflow:loss = 0.2474105, step = 5300 (0.211 sec)
INFO:tensorflow:global_step/sec: 488.774
INFO:tensorflow:loss = 0.1611641, step = 5400 (0.204 sec)
INFO:tensorflow:global_step/sec: 504.942
INFO:tensorflow:loss = 0.2306528, step = 5500 (0.198 sec)
INFO:tensorflow:global_step/sec: 514.058
INFO:tensorflow:loss = 0.20716992, step = 5600 (0.195 sec)
INFO:tensorflow:global_step/sec: 458.899
INFO:tensorflow:loss = 0.16730343, step = 5700 (0.217 sec)
INFO:tensorflow:global_step/sec: 495.197
INFO:tensorflow:loss = 0.2906361, step = 5800 (0.202 sec)
INFO:tensorflow:global_step/sec: 482.244
INFO:tensorflow:loss = 0.24669808, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 484.946
INFO:tensorflow:loss = 0.26403594, step = 6000 (0.207 sec)
INFO:tensorflow:global_step/sec: 486.74
INFO:tensorflow:loss = 0.19804293, step = 6100 (0.206 sec)
INFO:tensorflow:global_step/sec: 436.727
INFO:tensorflow:loss = 0.25344175, step = 6200 (0.229 sec)
INFO:tensorflow:global_step/sec: 428.73
INFO:tensorflow:loss = 0.2430937, step = 6300 (0.232 sec)
INFO:tensorflow:global_step/sec: 449.706
INFO:tensorflow:loss = 0.2842306, step = 6400 (0.222 sec)
INFO:tensorflow:global_step/sec: 440.873
INFO:tensorflow:loss = 0.2641199, step = 6500 (0.227 sec)
INFO:tensorflow:global_step/sec: 424.092
INFO:tensorflow:loss = 0.19028814, step = 6600 (0.237 sec)
INFO:tensorflow:global_step/sec: 450.352
INFO:tensorflow:loss = 0.24667627, step = 6700 (0.221 sec)
INFO:tensorflow:global_step/sec: 462.774
INFO:tensorflow:loss = 0.40046322, step = 6800 (0.216 sec)
INFO:tensorflow:global_step/sec: 460.854
INFO:tensorflow:loss = 0.14105138, step = 6900 (0.217 sec)
INFO:tensorflow:Requesting early stopping at global step 6916
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6917...
INFO:tensorflow:Saving checkpoints for 6917 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6917...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:07:35
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.79520s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:07:36
INFO:tensorflow:Saving dict for global step 6917: global_step = 6917, loss = 0.227278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6917: /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Loss for final step: 0.13882703.
({'loss': 0.227278, 'global_step': 6917}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
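"Stops improving" depends on the direction of the metric: lower is better for a loss, higher is better for an accuracy. EarlyStopping exposes this through its mode and min_delta arguments. The framework-free sketch below illustrates that comparison; the function name is hypothetical and the logic is a simplified reading of the documented meaning of those arguments, not Keras's source code.

```python
def is_improvement(current, best, mode="min", min_delta=0.0):
    """Simplified improvement test in the spirit of EarlyStopping.

    mode="min" suits losses (lower is better); mode="max" suits accuracies.
    A change smaller than `min_delta` does not count as an improvement.
    """
    if best is None:  # The first observed value always counts.
        return True
    if mode == "min":
        return current < best - min_delta
    if mode == "max":
        return current > best + min_delta
    raise ValueError(f"unknown mode: {mode!r}")


print(is_improvement(0.0177, 0.0178))                   # True
print(is_improvement(0.0177, 0.0178, min_delta=0.001))  # False: gain too small
print(is_improvement(0.9753, 0.9718, mode="max"))       # True
```

Setting a small min_delta is a common way to keep noise-level fluctuations from resetting the patience counter.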
Below is an example of an early stopping callback that monitors the training loss and stops training once the loss has not improved for 3 consecutive epochs (patience=3):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 5s 8ms/step - loss: 0.2371 - sparse_categorical_accuracy: 0.9293 - val_loss: 0.1334 - val_sparse_categorical_accuracy: 0.9611
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1028 - sparse_categorical_accuracy: 0.9686 - val_loss: 0.1062 - val_sparse_categorical_accuracy: 0.9667
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9783 - val_loss: 0.0993 - val_sparse_categorical_accuracy: 0.9707
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.1040 - val_sparse_categorical_accuracy: 0.9680
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9865 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9716
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.1167 - val_sparse_categorical_accuracy: 0.9691
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0321 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.1396 - val_sparse_categorical_accuracy: 0.9672
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1397 - val_sparse_categorical_accuracy: 0.9671
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1296 - val_sparse_categorical_accuracy: 0.9715
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9715
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0274 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9710
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0241 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9718
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0205 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1451 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1562 - val_sparse_categorical_accuracy: 0.9750
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9748
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1712 - val_sparse_categorical_accuracy: 0.9731
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1715 - val_sparse_categorical_accuracy: 0.9755
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9730
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1919 - val_sparse_categorical_accuracy: 0.9732
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1703 - val_sparse_categorical_accuracy: 0.9777
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1725 - val_sparse_categorical_accuracy: 0.9764
21
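Because monitor='loss' tracks the training loss, the stopping point can be reproduced by hand from the logged values. The framework-free sketch below (the function name is hypothetical, and the logic is a simplified version of a patience counter, not Keras's code) replays the 21 rounded training losses printed above with patience=3. The best loss, 0.0141, appears at epoch 18; epochs 19 to 21 fail to beat it, so training stops after epoch 21.

```python
def epochs_run_with_patience(losses, patience):
    """Return how many epochs run before a patience-based rule stops training.

    Training stops once `patience` consecutive epochs pass without the
    loss improving on the best value seen so far.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(losses)


# Training losses copied (rounded to 4 decimals) from the output above.
logged_losses = [0.2371, 0.1028, 0.0703, 0.0552, 0.0420, 0.0387, 0.0321,
                 0.0285, 0.0263, 0.0250, 0.0274, 0.0241, 0.0205, 0.0196,
                 0.0214, 0.0178, 0.0177, 0.0141, 0.0188, 0.0190, 0.0153]
print(epochs_run_with_patience(logged_losses, patience=3))  # 21

# The first seven logged val_loss values, replayed the same way.
logged_val_losses = [0.1334, 0.1062, 0.0993, 0.1040, 0.1033, 0.1167, 0.1396]
print(epochs_run_with_patience(logged_val_losses, patience=3))  # 6
```

The second call suggests that monitor='val_loss' would have stopped much earlier, after about 6 epochs, assuming training followed the same trajectory; the validation loss bottoms out at epoch 3 while the training loss keeps falling.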
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).
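A custom callback does not stop training directly; it sets a flag on the model, and the training loop checks that flag. The framework-free sketch below imitates that contract. FakeModel, StopAfterNBatches and fake_fit are hypothetical stand-ins, not Keras classes, and the loop is heavily simplified: it returns immediately, whereas Model.fit still completes its end-of-epoch bookkeeping before exiting.

```python
class FakeModel:
    """Hypothetical stand-in for a Keras model: only the flag matters here."""

    def __init__(self):
        self.stop_training = False


class StopAfterNBatches:
    """Hypothetical callback that requests a stop after `n` batches."""

    def __init__(self, n):
        self.n = n
        self.seen = 0
        self.model = None

    def on_train_batch_end(self, batch, logs=None):
        self.seen += 1
        if self.seen >= self.n:
            self.model.stop_training = True


def fake_fit(model, callback, epochs, steps_per_epoch):
    """Simplified training loop that checks model.stop_training after each batch."""
    callback.model = model
    batches_run = 0
    for epoch in range(epochs):
        for batch in range(steps_per_epoch):
            batches_run += 1  # A real loop would run a training step here.
            callback.on_train_batch_end(batch)
            if model.stop_training:
                return batches_run
    return batches_run


print(fake_fit(FakeModel(), StopAfterNBatches(5), epochs=100, steps_per_epoch=3))  # 5
```

Checking the flag after every batch, rather than only at epoch boundaries, is what lets a time-based rule react within one batch of the deadline.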
In this example, training stops once self.model.stop_training is set to True:
class LimitTrainingTime(tf.keras.callbacks.Callback):
    def __init__(self, max_time_s):
        super().__init__()
        self.max_time_s = max_time_s
        self.start_time = None
    def on_train_begin(self, logs=None):
        # Record when training starts.
        self.start_time = time.time()
    def on_train_batch_end(self, batch, logs=None):
        # Ask Model.fit to stop once the time budget is exceeded.
        now = time.time()
        if now - self.start_time > self.max_time_s:
            self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.1911 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1999 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9770
Epoch 4/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.2279 - val_sparse_categorical_accuracy: 0.9753
Epoch 5/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2272 - val_sparse_categorical_accuracy: 0.9755
Epoch 6/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2352 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2421 - val_sparse_categorical_accuracy: 0.9734
Epoch 8/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9785
Epoch 9/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2472 - val_sparse_categorical_accuracy: 0.9752
Epoch 10/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2166 - val_sparse_categorical_accuracy: 0.9768
Epoch 11/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9781
Epoch 12/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2310 - val_sparse_categorical_accuracy: 0.9777
Epoch 13/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2617 - val_sparse_categorical_accuracy: 0.9781
Epoch 14/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.3007 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3014 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2815 - val_sparse_categorical_accuracy: 0.9750
Epoch 17/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9765
Epoch 18/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9777
Epoch 19/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0098 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.2594 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3008 - val_sparse_categorical_accuracy: 0.9755
Epoch 21/100 469/469 [==============================] - 1s 3ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2662 - val_sparse_categorical_accuracy: 0.9765
Epoch 22/100 469/469 [==============================] - 1s 1ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2587 - val_sparse_categorical_accuracy: 0.9797
22
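The mechanism is a flag that the training loop checks after calling each callback hook. As a framework-free sketch of that control flow (Callback, ToyModel, toy_fit, and LimitTrainingSteps are hypothetical stand-ins, not Keras classes, and the real Model.fit loop does much more), the plain Python below swaps the time budget for a deterministic batch budget so the stopping point is predictable:

```python
class Callback:
    """Hypothetical minimal stand-in for tf.keras.callbacks.Callback."""

    def set_model(self, model):
        self.model = model

    def on_train_begin(self, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass


class ToyModel:
    """Hypothetical model object exposing only the stop_training flag."""

    def __init__(self):
        self.stop_training = False


def toy_fit(model, epochs, steps_per_epoch, callbacks):
    """Hypothetical fit loop; returns (epochs_run, batches_run)."""
    for cb in callbacks:
        cb.set_model(model)
        cb.on_train_begin()
    epochs_run = 0
    batches_run = 0
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            # A real loop would run one optimization step here.
            batches_run += 1
            for cb in callbacks:
                cb.on_train_batch_end(step)
            # Leave the current epoch as soon as a callback raises the flag.
            if model.stop_training:
                break
        epochs_run += 1
        # Do not start another epoch once the flag is set.
        if model.stop_training:
            break
    return epochs_run, batches_run


class LimitTrainingSteps(Callback):
    """Hypothetical deterministic variant of LimitTrainingTime: stop after N batches."""

    def __init__(self, max_batches):
        super().__init__()
        self.max_batches = max_batches
        self.seen = 0

    def on_train_begin(self, logs=None):
        self.seen = 0

    def on_train_batch_end(self, batch, logs=None):
        self.seen += 1
        if self.seen >= self.max_batches:
            self.model.stop_training = True


epochs_run, batches_run = toy_fit(ToyModel(), epochs=100, steps_per_epoch=10,
                                  callbacks=[LimitTrainingSteps(25)])
print(epochs_run, batches_run)  # prints: 3 25
```

With 10 batches per epoch and a budget of 25 batches, the toy loop stops partway through the third epoch. Counting batches instead of seconds is only a choice made here to keep the sketch reproducible.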
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, if you are not training and evaluating with the built-in Keras methods, you can implement early stopping in a custom training loop.
Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
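# The model outputs raw logits, so the cross-entropy metrics also need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)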
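Because the final Dense(10) layer has no activation, the model returns raw logits, which is why loss_fn is created with from_logits=True (tf.keras.metrics.SparseCategoricalCrossentropy accepts the same from_logits argument). As a plain-Python sketch of what such a loss computes for integer labels (the function names are hypothetical, and this is not TensorFlow's implementation), the per-example value is the negative log of the softmax probability assigned to the true class:

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example, computed from unnormalized logits.

    -log(softmax(logits)[label]) == logsumexp(logits) - logits[label]
    """
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mean_sparse_ce(batch_logits, labels):
    """Mean over the batch, analogous to averaging per-example losses."""
    losses = [sparse_ce_from_logits(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Two equal logits give probability 0.5 for each class, so the loss is log(2).
print(round(sparse_ce_from_logits([0.0, 0.0], 0), 4))  # prints: 0.6931
```

Subtracting the largest logit before exponentiating keeps the computation stable for large logits without changing the result.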
Define the parameter update functions with tf.GradientTape, and use the @tf.function decorator to speed them up:
@tf.function
def train_step(x, y):
    # Record the forward pass so the tape can compute gradients.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    # Accumulate this batch into the epoch-level training metrics.
    train_acc_metric.update_state(y, logits)
    train_loss_metric.update_state(y, logits)
    return loss_value
@tf.function
def test_step(x, y):
    # Forward pass in inference mode; only the validation metrics are updated.
    logits = model(x, training=False)
    val_acc_metric.update_state(y, logits)
    val_loss_metric.update_state(y, logits)
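train_step and test_step only call update_state; a loop then reads each metric once per epoch with result() and clears it with reset_states(), so the value it reports is an average over that epoch's examples rather than over the last batch. The plain-Python class below is a hypothetical sketch of that accumulate/read/reset pattern for a per-example mean; it is not the TensorFlow implementation:

```python
class RunningMean:
    """Hypothetical analogue of a Keras mean metric's state handling."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate the per-example values of one batch.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over every example seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        # Start a fresh accumulation, e.g. at the end of an epoch.
        self.total = 0.0
        self.count = 0


acc = RunningMean()
acc.update_state([1, 0])     # batch 1: one of two predictions correct
acc.update_state([1, 1, 1])  # batch 2: all three correct
print(acc.result())          # prints: 0.8 (4 correct out of 5 examples)
acc.reset_states()
print(acc.result())          # prints: 0.0
```

Weighting by example count means a small final batch does not distort the epoch average.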
Next, write a custom training loop in which you can implement your early stopping rule by hand.
The example below stops training when the validation loss has not improved for a given number of epochs (the patience):
epochs = 100
patience = 5
wait = 0
# Lowest validation loss seen so far.
best = float('inf')
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)
        if step % 200 == 0:
            print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
            print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))
    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
Start of epoch 0
Training loss at step 0: 2.3073
Seen so far: 128 samples
Training loss at step 200: 0.2164
Seen so far: 25728 samples
Training loss at step 400: 0.2186
Seen so far: 51328 samples
Training acc over epoch: 0.9321
Validation acc: 0.9644
Time taken: 1.73s
Start of epoch 1
Training loss at step 0: 0.0733
Seen so far: 128 samples
Training loss at step 200: 0.1581
Seen so far: 25728 samples
Training loss at step 400: 0.1625
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9681
Time taken: 1.23s
Start of epoch 2
Training loss at step 0: 0.0501
Seen so far: 128 samples
Training loss at step 200: 0.1389
Seen so far: 25728 samples
Training loss at step 400: 0.1495
Seen so far: 51328 samples
Training acc over epoch: 0.9779
Validation acc: 0.9703
Time taken: 1.17s
Start of epoch 3
Training loss at step 0: 0.0513
Seen so far: 128 samples
Training loss at step 200: 0.0638
Seen so far: 25728 samples
Training loss at step 400: 0.0930
Seen so far: 51328 samples
Training acc over epoch: 0.9830
Validation acc: 0.9719
Time taken: 1.20s
Start of epoch 4
Training loss at step 0: 0.0251
Seen so far: 128 samples
Training loss at step 200: 0.0482
Seen so far: 25728 samples
Training loss at step 400: 0.0872
Seen so far: 51328 samples
Training acc over epoch: 0.9849
Validation acc: 0.9672
Time taken: 1.18s
Start of epoch 5
Training loss at step 0: 0.0417
Seen so far: 128 samples
Training loss at step 200: 0.0302
Seen so far: 25728 samples
Training loss at step 400: 0.0362
Seen so far: 51328 samples
Training acc over epoch: 0.9878
Validation acc: 0.9703
Time taken: 1.21s
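Because the stopping rule only depends on the sequence of per-epoch validation losses, it can be checked in isolation. The plain-Python function below (the name early_stopping_epoch is hypothetical) replays the rule stated in the loop's comment, stopping once val_loss has not decreased for patience consecutive epochs. The min_delta threshold is an added assumption, similar in spirit to the min_delta argument of tf.keras.callbacks.EarlyStopping:

```python
def early_stopping_epoch(val_losses, patience, min_delta=0.0):
    """Replay a patience-based early stopping rule over per-epoch losses.

    Returns (epochs_run, best_epoch, best_loss). An epoch counts as an
    improvement only if its loss is below best_loss - min_delta.
    """
    best = float("inf")
    best_epoch = -1
    wait = 0
    for epoch, val_loss in enumerate(val_losses):
        wait += 1
        if val_loss < best - min_delta:
            # New best: remember it and restart the patience counter.
            best = val_loss
            best_epoch = epoch
            wait = 0
        if wait >= patience:
            # Too many epochs without improvement: stop after this one.
            return epoch + 1, best_epoch, best
    return len(val_losses), best_epoch, best


losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.9, 0.99, 0.7]
print(early_stopping_epoch(losses, patience=5))  # prints: (7, 1, 0.8)
```

Here the best loss, 0.8, is reached at epoch 1; the next five epochs fail to beat it, so training ends after seven epochs and the later 0.7 is never seen. Tracking best_epoch is what you would need if you also wanted to restore the best weights, similar to the restore_best_weights option of tf.keras.callbacks.EarlyStopping.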
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn how to write custom Keras callbacks, including early stopping at minimum loss.
- Learn about Training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.