Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar libreta |
En TensorFlow 1, usa tf.estimator.LoggingTensorHook
para monitorear y registrar tensores, mientras que tf.estimator.StopAtStepHook
ayuda a detener el entrenamiento en un paso específico cuando se entrena con tf.estimator.Estimator
. Este cuaderno demuestra cómo migrar de estas API a sus equivalentes en TensorFlow 2 mediante devoluciones de llamadas personalizadas de Keras ( tf.keras.callbacks.Callback
) con Model.fit
.
Las devoluciones de llamada de Keras son objetos que se llaman en diferentes puntos durante el entrenamiento/evaluación/predicción en las API Model.evaluate
de Model.predict
Model.fit
. Puede obtener más información sobre las devoluciones de llamada en los documentos de tf.keras.callbacks.Callback
API, así como en las guías Escribir sus propias devoluciones de llamada y Capacitación y evaluación con los métodos integrados (la sección Uso de devoluciones de llamada). Para migrar de SessionRunHook
en TensorFlow 1 a las devoluciones de llamada de Keras en TensorFlow 2, consulta la guía Migrar capacitación con lógica asistida .
Configuración
Comience con importaciones y un conjunto de datos simple para fines de demostración:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1: registra tensores y deja de entrenar con las API de tf.estimator
En TensorFlow 1, define varios ganchos para controlar el comportamiento de entrenamiento. Luego, pasa estos ganchos a tf.estimator.EstimatorSpec
.
En el siguiente ejemplo:
- Para monitorear/registrar tensores, por ejemplo, modelar pesos o pérdidas, usa
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
es su alias). - Para detener el entrenamiento en un paso específico, usa
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
es su alias).
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.025395721, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143] [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 0.025395721 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082] [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 0.09248222. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>
TensorFlow 2: registre tensores y detenga el entrenamiento con devoluciones de llamada personalizadas y Model.fit
En TensorFlow 2, cuando usa Keras Model.fit
(o Model.evaluate
) integrado para entrenamiento/evaluación, puede configurar el monitoreo de tensor y la detención del entrenamiento definiendo Keras tf.keras.callbacks.Callback
personalizados. Luego, los pasa al parámetro de callbacks
de llamada de Model.fit
(o Model.evaluate
). (Obtenga más información en la guía Cómo escribir sus propias devoluciones de llamada).
En el siguiente ejemplo:
- Para recrear las funcionalidades de
StopAtStepHook
, defina una devolución de llamada personalizada (denominadaStopAtStepCallback
continuación) donde anula el métodoon_batch_end
para detener el entrenamiento después de una cierta cantidad de pasos. - Para recrear el comportamiento de
LoggingTensorHook
, defina una devolución de llamada personalizada (LoggingTensorCallback
) donde registre y genere los tensores registrados manualmente, ya que no se admite el acceso a los tensores por nombres. También puede implementar la frecuencia de registro dentro de la devolución de llamada personalizada. El siguiente ejemplo imprimirá los pesos cada dos pasos. También son posibles otras estrategias, como iniciar sesión cada N segundos.
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
Cuando haya terminado, pase las nuevas devoluciones de llamada StopAtStepCallback
y LoggingTensorCallback
al parámetro de callbacks
de llamada de Model.fit
:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.27049014], [-0.73790836]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)> Logging Tensor Callback loss: 3.2473244667053223 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.22285421], [-0.6911988 ]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)> Logging Tensor Callback loss: 5.644947052001953 3/3 [==============================] - 0s 4ms/step - loss: 5.6449 <keras.callbacks.History at 0x7f053022be90>
Próximos pasos
Obtenga más información sobre las devoluciones de llamadas en:
- Documentos API:
tf.keras.callbacks.Callback
- Guía: escribir sus propias devoluciones de llamada
- Guía: Capacitación y evaluación con los métodos integrados (la sección Uso de devoluciones de llamada)
También puede encontrar útiles los siguientes recursos relacionados con la migración:
- La guía de migración de detención anticipada:
tf.keras.callbacks.EarlyStopping
es una devolución de llamada de detención anticipada integrada - La guía de migración de TensorBoard: TensorBoard permite rastrear y mostrar métricas
- La guía de migración de capacitación con lógica asistida : de
SessionRunHook
en TensorFlow 1 a devoluciones de llamada de Keras en TensorFlow 2