Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar notebook |
Un callback personalizado es una herramienta poderosa para personalizar el comportamiento de un modelo de Keras durante el entrenamiento, evaluacion o inferencia, incluyendo la lectura/cambio del modelo de Keras. Ejemplos incluyen tf.keras.callbacks.TensorBoard
, donde se pueden exportar y visualizar el progreso del entrenamiento y los resultados con TensorBoard, o tf.keras.callbacks.ModelCheckpoint
donde el modelo es automaticamente guardado durante el entrenamiento, entre otros. En esta guia aprenderas que es un callback de Keras, cuando se llama, que puede hacer y como puedes construir una propia. Al final de la guia habra demos para la creacion de aplicaciones simples de callback para ayudarte a empezar tu propio callback personalizados.
Setup
import tensorflow as tf
Introduccion a los callbacks de Keras
En Keras 'Callback' es una clase de python destinada a ser subclase para proporcionar una funcionalidad específica, con un conjunto de métodos llamados en varias etapas de entrenamiento (incluyendo el inicio y fin de los batch/epoch), pruebas y predicciones. Los Callbacks son útiles para tener visibilidad de los estados internos y las estadísticas del modelo durante el entrenamiento. Puedes pasar una lista de callbacks (como argumento de palabra clave callbacks
) a cualquiera de los siguientes metodos tf.keras.Model.fit ()
,tf.keras.Model.evaluate ()
ytf.keras.Model .predict ()
. Los metodos de los callbacks se llamaran en diferentes etapas del entrenamiento/evaluación/inferencia.
Para comenzar, importemos TensorDlow y definamos un modelo secuencial sencillo en Keras:
# Definir el modelo de Keras model al que se le agregaran los callbacks
def get_model():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
return model
Luego, cara el dataset de MNIST para entrenamiento y pruebas de la APLI de datasetws de Keras:
# Cargar los datos de ejemplo de MNIST data y preprocesarlos
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step
Ahora, define un callback simple y personalizado para rastrear el inicio y fin de cada batch de datos. Durante esas llamadas, imprime el indice del batch actual.
import datetime
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_begin(self, batch, logs=None):
print('Entrenamiento: batch {} comienza en {}'.format(batch, datetime.datetime.now().time()))
def on_train_batch_end(self, batch, logs=None):
print('Entrenamiento: batch {} termina en {}'.format(batch, datetime.datetime.now().time()))
def on_test_batch_begin(self, batch, logs=None):
print('Evaluacion: batch {} comienza en {}'.format(batch, datetime.datetime.now().time()))
def on_test_batch_end(self, batch, logs=None):
print('Evaluacion: batch {} termina en {}'.format(batch, datetime.datetime.now().time()))
Dar un callback mara los metodos del modelo tales como tf.keras.Model.fit()
aseguran que los metodos son llamados en dichas etapas:
model = get_model()
_ = model.fit(x_train, y_train,
batch_size=64,
epochs=1,
steps_per_epoch=5,
verbose=0,
callbacks=[MyCustomCallback()])
Entrenamiento: batch 0 comienza en 00:09:02.873100 Entrenamiento: batch 0 termina en 00:09:03.419566 Entrenamiento: batch 1 comienza en 00:09:03.419743 Entrenamiento: batch 1 termina en 00:09:03.422108 Entrenamiento: batch 2 comienza en 00:09:03.422228 Entrenamiento: batch 2 termina en 00:09:03.423979 Entrenamiento: batch 3 comienza en 00:09:03.424081 Entrenamiento: batch 3 termina en 00:09:03.425804 Entrenamiento: batch 4 comienza en 00:09:03.425909 Entrenamiento: batch 4 termina en 00:09:03.427571
Metodos del Modelo que aceptan callbacks
Los usuarios pueden dar una lista de callbacks para los siguientes metodos de tf.keras.Model
:
fit()
, fit_generator()
Entrena el modelo por una cantidad determinada de epochs (iteraciones en un dataset, o para los datos determinados por un generador de Python que va batch-por-batch).
evaluate()
, evaluate_generator()
Evalua el modelo para determinados datos o generador de datos. Regresa la perdida (loss) y valores metricos para la evaluacion.
predict()
, predict_generator()
Genera las predicciones a regresar para los datos ingresados o el generador de datos. NOTA: Toda la documentacion esta en ingles.
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
callbacks=[MyCustomCallback()])
Evaluacion: batch 0 comienza en 00:09:03.491773 Evaluacion: batch 0 termina en 00:09:03.571752 Evaluacion: batch 1 comienza en 00:09:03.571900 Evaluacion: batch 1 termina en 00:09:03.573569 Evaluacion: batch 2 comienza en 00:09:03.573676 Evaluacion: batch 2 termina en 00:09:03.575211 Evaluacion: batch 3 comienza en 00:09:03.575316 Evaluacion: batch 3 termina en 00:09:03.576784 Evaluacion: batch 4 comienza en 00:09:03.576907 Evaluacion: batch 4 termina en 00:09:03.578332
Una revision de los metodos de callback
Metodos comunes para entrenamiento/pruebas/prediccion
Para entrenamiento, pruebas y prediccion, los siguientes metodos se han previsto para ser sobreescritos.
on_(train|test|predict)_begin(self, logs=None)
Llamado al inicio de fit
/evaluate
/predict
.
on_(train|test|predict)_end(self, logs=None)
Llamado al fin de fit
/evaluate
/predict
.
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Llamado justo antes de procesar un batch durante entrenamiento/pruebas/prediccion. Dentro de este metodo, logs
es un diccionario con las llaves batch
y size
disponibles, representando el numero de batch actual y las dimensiones del mismo.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Llamado al final del entrenamiento/pruebas/prediccion de un batch. dentro de este metodo, logs
es un diccionario que contiene resultados metricos con estado.
Entrenamiento de metodos especificos
Adicionalmente, para el entrenamiento, los siguientes metodos son provistos.
on_epoch_begin(self, epoch, logs=None)
Llamado al inicio de una epoch durante el entrenamiento.
on_epoch_end(self, epoch, logs=None)
Llamado al final de una epoch durante el entrenamiento.
Uso del diccionario logs
El diccionario logs
contiene el valor de perdida (loss), y todas las metricas pertinentes al final de un batch o epoch. El ejemplo a continuacion incluye la perdidad (loss) y el MAE (Mean Absolute Error).
class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
print('Para el batch {}, la perdida (loss) es {:7.2f}.'.format(batch, logs['loss']))
def on_test_batch_end(self, batch, logs=None):
print('Para el batch {}, la perdida (loss) es {:7.2f}.'.format(batch, logs['loss']))
def on_epoch_end(self, epoch, logs=None):
print('La perdida promedio para la epoch {} es {:7.2f} y el MAE es {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))
model = get_model()
_ = model.fit(x_train, y_train,
batch_size=64,
steps_per_epoch=5,
epochs=3,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()])
Para el batch 0, la perdida (loss) es 29.35. Para el batch 1, la perdida (loss) es 413.19. Para el batch 2, la perdida (loss) es 282.48. Para el batch 3, la perdida (loss) es 214.16. Para el batch 4, la perdida (loss) es 173.17. La perdida promedio para la epoch 0 es 173.17 y el MAE es 8.02. Para el batch 0, la perdida (loss) es 6.95. Para el batch 1, la perdida (loss) es 6.50. Para el batch 2, la perdida (loss) es 6.06. Para el batch 3, la perdida (loss) es 6.25. Para el batch 4, la perdida (loss) es 5.70. La perdida promedio para la epoch 1 es 5.70 y el MAE es 1.98. Para el batch 0, la perdida (loss) es 4.78. Para el batch 1, la perdida (loss) es 4.86. Para el batch 2, la perdida (loss) es 5.66. Para el batch 3, la perdida (loss) es 5.81. Para el batch 4, la perdida (loss) es 6.02. La perdida promedio para la epoch 2 es 6.02 y el MAE es 1.97.
De manera similar, uno puede proveer callbacks en las llamadas a evaluate()
.
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20,
callbacks=[LossAndErrorPrintingCallback()])
Para el batch 0, la perdida (loss) es 7.77. Para el batch 1, la perdida (loss) es 7.51. Para el batch 2, la perdida (loss) es 7.58. Para el batch 3, la perdida (loss) es 7.67. Para el batch 4, la perdida (loss) es 7.76. Para el batch 5, la perdida (loss) es 7.83. Para el batch 6, la perdida (loss) es 7.89. Para el batch 7, la perdida (loss) es 7.85. Para el batch 8, la perdida (loss) es 7.87. Para el batch 9, la perdida (loss) es 7.98. Para el batch 10, la perdida (loss) es 7.95. Para el batch 11, la perdida (loss) es 7.97. Para el batch 12, la perdida (loss) es 7.98. Para el batch 13, la perdida (loss) es 8.08. Para el batch 14, la perdida (loss) es 8.07. Para el batch 15, la perdida (loss) es 7.99. Para el batch 16, la perdida (loss) es 8.05. Para el batch 17, la perdida (loss) es 8.05. Para el batch 18, la perdida (loss) es 8.12. Para el batch 19, la perdida (loss) es 8.13.
Ejemplos de aplicaciones de callbacks de Keras
La siguiente seccion te guiara en la creacion de una aplicacion de callback simple.
Detencion anticipada con perdida minima.
El primer ejemplo muestra la creacion de un Callback
que detiene el entrenamiento de Keras cuando se alcanza el minimo de perdida mutando el atributomodel.stop_training
(boolean). Opcionalmente, el usuario puede proporcionar el argumento patience
para especificar cuantas epochs debe esperar el entrenamiento antes de detenerse.
tf.keras.callbacks.EarlyStopping
proporciona una implementación mas completa y general.
import numpy as np
class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
"""Detener el entrenamiento cuando la perdida (loss) esta en su minimo, i.e. la perdida (loss) deja de disminuir.
Arguments:
patience: Numero de epochs a esperar despues de que el min ha sido alcanzaado. Despues de este numero
de no mejoras, el entrenamiento para.
"""
def __init__(self, patience=0):
super(EarlyStoppingAtMinLoss, self).__init__()
self.patience = patience
# best_weights para almacenar los pesos en los cuales ocurre la perdida minima.
self.best_weights = None
def on_train_begin(self, logs=None):
# El numero de epoch que ha esperado cuando la perdida ya no es minima.
self.wait = 0
# El epoch en el que en entrenamiento se detiene.
self.stopped_epoch = 0
# Initialize el best como infinito.
self.best = np.inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get('loss')
if np.less(current, self.best):
self.best = current
self.wait = 0
# Guardar los mejores pesos si el resultado actual es mejor (menos).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print('Restaurando los pesos del modelo del final de la mejor epoch.')
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print('Epoch %05d: Detencion anticipada' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train,
batch_size=64,
steps_per_epoch=5,
epochs=30,
verbose=0,
callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
Para el batch 0, la perdida (loss) es 29.01. Para el batch 1, la perdida (loss) es 469.72. Para el batch 2, la perdida (loss) es 318.63. Para el batch 3, la perdida (loss) es 241.80. Para el batch 4, la perdida (loss) es 194.70. La perdida promedio para la epoch 0 es 194.70 y el MAE es 8.03. Para el batch 0, la perdida (loss) es 8.98. Para el batch 1, la perdida (loss) es 7.03. Para el batch 2, la perdida (loss) es 6.32. Para el batch 3, la perdida (loss) es 6.01. Para el batch 4, la perdida (loss) es 5.88. La perdida promedio para la epoch 1 es 5.88 y el MAE es 1.96. Para el batch 0, la perdida (loss) es 4.63. Para el batch 1, la perdida (loss) es 4.54. Para el batch 2, la perdida (loss) es 4.39. Para el batch 3, la perdida (loss) es 4.79. Para el batch 4, la perdida (loss) es 4.98. La perdida promedio para la epoch 2 es 4.98 y el MAE es 1.79. Para el batch 0, la perdida (loss) es 5.49. Para el batch 1, la perdida (loss) es 4.72. Para el batch 2, la perdida (loss) es 4.87. Para el batch 3, la perdida (loss) es 5.56. Para el batch 4, la perdida (loss) es 8.84. La perdida promedio para la epoch 3 es 8.84 y el MAE es 2.35. Restaurando los pesos del modelo del final de la mejor epoch. Epoch 00004: Detencion anticipada
Programacion del Learning Rate
Algo que es hecho comunmente en el entrenamiento de un modelo es cambiar el learning rate conforme pasan mas epochs. El backend de Keras expone la API get_value
la cual puede ser usada para definir las variables. En este ejemplo estamos mostrando como un Callback personalizado puede ser usado para cambiar dinamicamente el learning rate.
Nota: este es solo una implementacion de ejemplo, callbacks.LearningRateScheduler
y keras.optimizers.schedules
contienen implementaciones mas generales.
class LearningRateScheduler(tf.keras.callbacks.Callback):
"""Planificador de Learning rate que define el learning rate deacuerdo a lo programado.
Arguments:
schedule: una funcion que toma el indice del epoch
(entero, indexado desde 0) y el learning rate actual
como entradas y regresa un nuevo learning rate como salida (float).
"""
def __init__(self, schedule):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
# Obtener el learning rate actua del optimizer del modelo.
lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
# Llamar la funcion schedule para obtener el learning rate programado.
scheduled_lr = self.schedule(epoch, lr)
# Definir el valor en el optimized antes de que la epoch comience
tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))
LR_SCHEDULE = [
# (epoch a comenzar, learning rate) tupla
(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]
def lr_schedule(epoch, lr):
"""Funcion de ayuda para recuperar el learning rate programado basado en la epoch."""
if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
return lr
for i in range(len(LR_SCHEDULE)):
if epoch == LR_SCHEDULE[i][0]:
return LR_SCHEDULE[i][1]
return lr
model = get_model()
_ = model.fit(x_train, y_train,
batch_size=64,
steps_per_epoch=5,
epochs=15,
verbose=0,
callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])
Epoch 00000: Learning rate is 0.1000. Para el batch 0, la perdida (loss) es 31.79. Para el batch 1, la perdida (loss) es 365.90. Para el batch 2, la perdida (loss) es 254.34. Para el batch 3, la perdida (loss) es 194.62. Para el batch 4, la perdida (loss) es 157.31. La perdida promedio para la epoch 0 es 157.31 y el MAE es 7.92. Epoch 00001: Learning rate is 0.1000. Para el batch 0, la perdida (loss) es 6.23. Para el batch 1, la perdida (loss) es 6.23. Para el batch 2, la perdida (loss) es 6.31. Para el batch 3, la perdida (loss) es 6.04. Para el batch 4, la perdida (loss) es 6.08. La perdida promedio para la epoch 1 es 6.08 y el MAE es 2.06. Epoch 00002: Learning rate is 0.1000. Para el batch 0, la perdida (loss) es 10.44. Para el batch 1, la perdida (loss) es 10.94. Para el batch 2, la perdida (loss) es 12.18. Para el batch 3, la perdida (loss) es 14.24. Para el batch 4, la perdida (loss) es 16.86. La perdida promedio para la epoch 2 es 16.86 y el MAE es 3.41. Epoch 00003: Learning rate is 0.0500. Para el batch 0, la perdida (loss) es 35.25. Para el batch 1, la perdida (loss) es 20.86. Para el batch 2, la perdida (loss) es 15.10. Para el batch 3, la perdida (loss) es 12.15. Para el batch 4, la perdida (loss) es 10.47. La perdida promedio para la epoch 3 es 10.47 y el MAE es 2.31. Epoch 00004: Learning rate is 0.0500. Para el batch 0, la perdida (loss) es 4.03. Para el batch 1, la perdida (loss) es 4.65. Para el batch 2, la perdida (loss) es 4.61. Para el batch 3, la perdida (loss) es 4.65. Para el batch 4, la perdida (loss) es 4.74. La perdida promedio para la epoch 4 es 4.74 y el MAE es 1.72. Epoch 00005: Learning rate is 0.0500. Para el batch 0, la perdida (loss) es 4.02. Para el batch 1, la perdida (loss) es 3.93. Para el batch 2, la perdida (loss) es 3.99. Para el batch 3, la perdida (loss) es 3.83. Para el batch 4, la perdida (loss) es 4.04. La perdida promedio para la epoch 5 es 4.04 y el MAE es 1.59. Epoch 00006: Learning rate is 0.0100. Para el batch 0, la perdida (loss) es 4.38. Para el batch 1, la perdida (loss) es 4.90. Para el batch 2, la perdida (loss) es 4.83. Para el batch 3, la perdida (loss) es 5.02. Para el batch 4, la perdida (loss) es 5.04. La perdida promedio para la epoch 6 es 5.04 y el MAE es 1.77. Epoch 00007: Learning rate is 0.0100. Para el batch 0, la perdida (loss) es 5.85. Para el batch 1, la perdida (loss) es 5.87. Para el batch 2, la perdida (loss) es 5.11. Para el batch 3, la perdida (loss) es 5.38. Para el batch 4, la perdida (loss) es 4.97. La perdida promedio para la epoch 7 es 4.97 y el MAE es 1.76. Epoch 00008: Learning rate is 0.0100. Para el batch 0, la perdida (loss) es 4.39. Para el batch 1, la perdida (loss) es 3.87. Para el batch 2, la perdida (loss) es 3.69. Para el batch 3, la perdida (loss) es 3.95. Para el batch 4, la perdida (loss) es 4.16. La perdida promedio para la epoch 8 es 4.16 y el MAE es 1.63. Epoch 00009: Learning rate is 0.0050. Para el batch 0, la perdida (loss) es 4.89. Para el batch 1, la perdida (loss) es 4.67. Para el batch 2, la perdida (loss) es 4.33. Para el batch 3, la perdida (loss) es 4.21. Para el batch 4, la perdida (loss) es 4.04. La perdida promedio para la epoch 9 es 4.04 y el MAE es 1.61. Epoch 00010: Learning rate is 0.0050. Para el batch 0, la perdida (loss) es 4.09. Para el batch 1, la perdida (loss) es 4.51. Para el batch 2, la perdida (loss) es 4.56. Para el batch 3, la perdida (loss) es 4.31. Para el batch 4, la perdida (loss) es 4.45. La perdida promedio para la epoch 10 es 4.45 y el MAE es 1.69. Epoch 00011: Learning rate is 0.0050. Para el batch 0, la perdida (loss) es 3.31. Para el batch 1, la perdida (loss) es 3.34. Para el batch 2, la perdida (loss) es 4.03. Para el batch 3, la perdida (loss) es 4.05. Para el batch 4, la perdida (loss) es 4.02. La perdida promedio para la epoch 11 es 4.02 y el MAE es 1.53. Epoch 00012: Learning rate is 0.0010. Para el batch 0, la perdida (loss) es 3.62. Para el batch 1, la perdida (loss) es 4.47. Para el batch 2, la perdida (loss) es 4.52. Para el batch 3, la perdida (loss) es 4.33. Para el batch 4, la perdida (loss) es 4.42. La perdida promedio para la epoch 12 es 4.42 y el MAE es 1.67. Epoch 00013: Learning rate is 0.0010. Para el batch 0, la perdida (loss) es 4.11. Para el batch 1, la perdida (loss) es 4.96. Para el batch 2, la perdida (loss) es 4.48. Para el batch 3, la perdida (loss) es 4.53. Para el batch 4, la perdida (loss) es 4.36. La perdida promedio para la epoch 13 es 4.36 y el MAE es 1.67. Epoch 00014: Learning rate is 0.0010. Para el batch 0, la perdida (loss) es 3.64. Para el batch 1, la perdida (loss) es 3.94. Para el batch 2, la perdida (loss) es 3.97. Para el batch 3, la perdida (loss) es 4.12. Para el batch 4, la perdida (loss) es 4.00. La perdida promedio para la epoch 14 es 4.00 y el MAE es 1.57.
Callbacks de Keras estandar
Asegurate de revisar los callbacks de Keras preexistentes visitando la documentacion de la api. Las aplicaciones incluyen el registro a CSV, guardar el modelo, visualizar en TensorBoard y mucho mas.
NOTA: La documentacion aun esta en ingles