Escribir callbacks de Keras personalizados

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar notebook

Un callback personalizado es una herramienta poderosa para personalizar el comportamiento de un modelo de Keras durante el entrenamiento, evaluacion o inferencia, incluyendo la lectura/cambio del modelo de Keras. Ejemplos incluyen tf.keras.callbacks.TensorBoard, donde se pueden exportar y visualizar el progreso del entrenamiento y los resultados con TensorBoard, o tf.keras.callbacks.ModelCheckpoint donde el modelo es automaticamente guardado durante el entrenamiento, entre otros. En esta guia aprenderas que es un callback de Keras, cuando se llama, que puede hacer y como puedes construir una propia. Al final de la guia habra demos para la creacion de aplicaciones simples de callback para ayudarte a empezar tu propio callback personalizados.

Setup

import tensorflow as tf

Introduccion a los callbacks de Keras

En Keras 'Callback' es una clase de python destinada a ser subclase para proporcionar una funcionalidad específica, con un conjunto de métodos llamados en varias etapas de entrenamiento (incluyendo el inicio y fin de los batch/epoch), pruebas y predicciones. Los Callbacks son útiles para tener visibilidad de los estados internos y las estadísticas del modelo durante el entrenamiento. Puedes pasar una lista de callbacks (como argumento de palabra clave callbacks) a cualquiera de los siguientes metodos tf.keras.Model.fit (),tf.keras.Model.evaluate ()ytf.keras.Model .predict (). Los metodos de los callbacks se llamaran en diferentes etapas del entrenamiento/evaluación/inferencia.

Para comenzar, importemos TensorDlow y definamos un modelo secuencial sencillo en Keras:

# Definir el modelo de Keras model al que se le agregaran los callbacks
def get_model():
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))
  model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
  return model

Luego, cara el dataset de MNIST para entrenamiento y pruebas de la APLI de datasetws de Keras:

# Cargar los datos de ejemplo de MNIST data y preprocesarlos
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Ahora, define un callback simple y personalizado para rastrear el inicio y fin de cada batch de datos. Durante esas llamadas, imprime el indice del batch actual.

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Entrenamiento: batch {} comienza en {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Entrenamiento: batch {} termina en {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluacion: batch {} comienza en {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluacion: batch {} termina en {}'.format(batch, datetime.datetime.now().time()))

Dar un callback mara los metodos del modelo tales como tf.keras.Model.fit() aseguran que los metodos son llamados en dichas etapas:

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[MyCustomCallback()])
Entrenamiento: batch 0 comienza en 00:09:02.873100
Entrenamiento: batch 0 termina en 00:09:03.419566
Entrenamiento: batch 1 comienza en 00:09:03.419743
Entrenamiento: batch 1 termina en 00:09:03.422108
Entrenamiento: batch 2 comienza en 00:09:03.422228
Entrenamiento: batch 2 termina en 00:09:03.423979
Entrenamiento: batch 3 comienza en 00:09:03.424081
Entrenamiento: batch 3 termina en 00:09:03.425804
Entrenamiento: batch 4 comienza en 00:09:03.425909
Entrenamiento: batch 4 termina en 00:09:03.427571

Metodos del Modelo que aceptan callbacks

Los usuarios pueden dar una lista de callbacks para los siguientes metodos de tf.keras.Model:

fit(), fit_generator()

Entrena el modelo por una cantidad determinada de epochs (iteraciones en un dataset, o para los datos determinados por un generador de Python que va batch-por-batch).

evaluate(), evaluate_generator()

Evalua el modelo para determinados datos o generador de datos. Regresa la perdida (loss) y valores metricos para la evaluacion.

predict(), predict_generator()

Genera las predicciones a regresar para los datos ingresados o el generador de datos. NOTA: Toda la documentacion esta en ingles.

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
          callbacks=[MyCustomCallback()])
Evaluacion: batch 0 comienza en 00:09:03.491773
Evaluacion: batch 0 termina en 00:09:03.571752
Evaluacion: batch 1 comienza en 00:09:03.571900
Evaluacion: batch 1 termina en 00:09:03.573569
Evaluacion: batch 2 comienza en 00:09:03.573676
Evaluacion: batch 2 termina en 00:09:03.575211
Evaluacion: batch 3 comienza en 00:09:03.575316
Evaluacion: batch 3 termina en 00:09:03.576784
Evaluacion: batch 4 comienza en 00:09:03.576907
Evaluacion: batch 4 termina en 00:09:03.578332

Una revision de los metodos de callback

Metodos comunes para entrenamiento/pruebas/prediccion

Para entrenamiento, pruebas y prediccion, los siguientes metodos se han previsto para ser sobreescritos.

on_(train|test|predict)_begin(self, logs=None)

Llamado al inicio de fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Llamado al fin de fit/evaluate/predict.

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Llamado justo antes de procesar un batch durante entrenamiento/pruebas/prediccion. Dentro de este metodo, logs es un diccionario con las llaves batch y size disponibles, representando el numero de batch actual y las dimensiones del mismo.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Llamado al final del entrenamiento/pruebas/prediccion de un batch. dentro de este metodo, logs es un diccionario que contiene resultados metricos con estado.

Entrenamiento de metodos especificos

Adicionalmente, para el entrenamiento, los siguientes metodos son provistos.

on_epoch_begin(self, epoch, logs=None)

Llamado al inicio de una epoch durante el entrenamiento.

on_epoch_end(self, epoch, logs=None)

Llamado al final de una epoch durante el entrenamiento.

Uso del diccionario logs

El diccionario logs contiene el valor de perdida (loss), y todas las metricas pertinentes al final de un batch o epoch. El ejemplo a continuacion incluye la perdidad (loss) y el MAE (Mean Absolute Error).

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

  def on_train_batch_end(self, batch, logs=None):
    print('Para el batch {}, la perdida (loss) es {:7.2f}.'.format(batch, logs['loss']))

  def on_test_batch_end(self, batch, logs=None):
    print('Para el  batch {}, la perdida (loss) es {:7.2f}.'.format(batch, logs['loss']))

  def on_epoch_end(self, epoch, logs=None):
    print('La perdida promedio para la epoch {} es {:7.2f} y el MAE es {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=3,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback()])
Para el batch 0, la perdida (loss) es   29.35.
Para el batch 1, la perdida (loss) es  413.19.
Para el batch 2, la perdida (loss) es  282.48.
Para el batch 3, la perdida (loss) es  214.16.
Para el batch 4, la perdida (loss) es  173.17.
La perdida promedio para la epoch 0 es  173.17 y el MAE es    8.02.
Para el batch 0, la perdida (loss) es    6.95.
Para el batch 1, la perdida (loss) es    6.50.
Para el batch 2, la perdida (loss) es    6.06.
Para el batch 3, la perdida (loss) es    6.25.
Para el batch 4, la perdida (loss) es    5.70.
La perdida promedio para la epoch 1 es    5.70 y el MAE es    1.98.
Para el batch 0, la perdida (loss) es    4.78.
Para el batch 1, la perdida (loss) es    4.86.
Para el batch 2, la perdida (loss) es    5.66.
Para el batch 3, la perdida (loss) es    5.81.
Para el batch 4, la perdida (loss) es    6.02.
La perdida promedio para la epoch 2 es    6.02 y el MAE es    1.97.

De manera similar, uno puede proveer callbacks en las llamadas a evaluate().

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20,
          callbacks=[LossAndErrorPrintingCallback()])
Para el  batch 0, la perdida (loss) es    7.77.
Para el  batch 1, la perdida (loss) es    7.51.
Para el  batch 2, la perdida (loss) es    7.58.
Para el  batch 3, la perdida (loss) es    7.67.
Para el  batch 4, la perdida (loss) es    7.76.
Para el  batch 5, la perdida (loss) es    7.83.
Para el  batch 6, la perdida (loss) es    7.89.
Para el  batch 7, la perdida (loss) es    7.85.
Para el  batch 8, la perdida (loss) es    7.87.
Para el  batch 9, la perdida (loss) es    7.98.
Para el  batch 10, la perdida (loss) es    7.95.
Para el  batch 11, la perdida (loss) es    7.97.
Para el  batch 12, la perdida (loss) es    7.98.
Para el  batch 13, la perdida (loss) es    8.08.
Para el  batch 14, la perdida (loss) es    8.07.
Para el  batch 15, la perdida (loss) es    7.99.
Para el  batch 16, la perdida (loss) es    8.05.
Para el  batch 17, la perdida (loss) es    8.05.
Para el  batch 18, la perdida (loss) es    8.12.
Para el  batch 19, la perdida (loss) es    8.13.

Ejemplos de aplicaciones de callbacks de Keras

La siguiente seccion te guiara en la creacion de una aplicacion de callback simple.

Detencion anticipada con perdida minima.

El primer ejemplo muestra la creacion de un Callback que detiene el entrenamiento de Keras cuando se alcanza el minimo de perdida mutando el atributomodel.stop_training (boolean). Opcionalmente, el usuario puede proporcionar el argumento patience para especificar cuantas epochs debe esperar el entrenamiento antes de detenerse.

tf.keras.callbacks.EarlyStopping proporciona una implementación mas completa y general.

import numpy as np

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
  """Detener el entrenamiento cuando la perdida (loss) esta en su minimo, i.e. la perdida (loss) deja de disminuir.

  Arguments:
      patience: Numero de epochs a esperar despues de que el min ha sido alcanzaado. Despues de este numero
      de no mejoras, el entrenamiento para.
  """

  def __init__(self, patience=0):
    super(EarlyStoppingAtMinLoss, self).__init__()

    self.patience = patience

    # best_weights para almacenar los pesos en los cuales ocurre la perdida minima.
    self.best_weights = None

  def on_train_begin(self, logs=None):
    # El numero de epoch que ha esperado cuando la perdida ya no es minima.
    self.wait = 0
    # El epoch en el que en entrenamiento se detiene.
    self.stopped_epoch = 0
    # Initialize el best como infinito.
    self.best = np.inf

  def on_epoch_end(self, epoch, logs=None):
    current = logs.get('loss')
    if np.less(current, self.best):
      self.best = current
      self.wait = 0
      # Guardar los mejores pesos si el resultado actual es mejor (menos).
      self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        print('Restaurando los pesos del modelo del final de la mejor epoch.')
        self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
      print('Epoch %05d: Detencion anticipada' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=30,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
Para el batch 0, la perdida (loss) es   29.01.
Para el batch 1, la perdida (loss) es  469.72.
Para el batch 2, la perdida (loss) es  318.63.
Para el batch 3, la perdida (loss) es  241.80.
Para el batch 4, la perdida (loss) es  194.70.
La perdida promedio para la epoch 0 es  194.70 y el MAE es    8.03.
Para el batch 0, la perdida (loss) es    8.98.
Para el batch 1, la perdida (loss) es    7.03.
Para el batch 2, la perdida (loss) es    6.32.
Para el batch 3, la perdida (loss) es    6.01.
Para el batch 4, la perdida (loss) es    5.88.
La perdida promedio para la epoch 1 es    5.88 y el MAE es    1.96.
Para el batch 0, la perdida (loss) es    4.63.
Para el batch 1, la perdida (loss) es    4.54.
Para el batch 2, la perdida (loss) es    4.39.
Para el batch 3, la perdida (loss) es    4.79.
Para el batch 4, la perdida (loss) es    4.98.
La perdida promedio para la epoch 2 es    4.98 y el MAE es    1.79.
Para el batch 0, la perdida (loss) es    5.49.
Para el batch 1, la perdida (loss) es    4.72.
Para el batch 2, la perdida (loss) es    4.87.
Para el batch 3, la perdida (loss) es    5.56.
Para el batch 4, la perdida (loss) es    8.84.
La perdida promedio para la epoch 3 es    8.84 y el MAE es    2.35.
Restaurando los pesos del modelo del final de la mejor epoch.
Epoch 00004: Detencion anticipada

Programacion del Learning Rate

Algo que es hecho comunmente en el entrenamiento de un modelo es cambiar el learning rate conforme pasan mas epochs. El backend de Keras expone la API get_value la cual puede ser usada para definir las variables. En este ejemplo estamos mostrando como un Callback personalizado puede ser usado para cambiar dinamicamente el learning rate.

Nota: este es solo una implementacion de ejemplo, callbacks.LearningRateScheduler y keras.optimizers.schedules contienen implementaciones mas generales.

class LearningRateScheduler(tf.keras.callbacks.Callback):
  """Planificador de Learning rate que define el learning rate deacuerdo a lo programado.

  Arguments:
      schedule: una funcion que toma el indice del epoch
          (entero, indexado desde 0) y el learning rate actual
          como entradas y regresa un nuevo learning rate como salida (float).
  """

  def __init__(self, schedule):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    # Obtener el learning rate actua del optimizer del modelo.
    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
    # Llamar la funcion schedule para obtener el learning rate programado.
    scheduled_lr = self.schedule(epoch, lr)
    # Definir el valor en el optimized antes de que la epoch comience
    tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
    print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))
LR_SCHEDULE = [
    # (epoch a comenzar, learning rate) tupla
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
  """Funcion de ayuda para recuperar el learning rate programado basado en la epoch."""
  if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
    return lr
  for i in range(len(LR_SCHEDULE)):
    if epoch == LR_SCHEDULE[i][0]:
      return LR_SCHEDULE[i][1]
  return lr

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=15,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])
Epoch 00000: Learning rate is 0.1000.
Para el batch 0, la perdida (loss) es   31.79.
Para el batch 1, la perdida (loss) es  365.90.
Para el batch 2, la perdida (loss) es  254.34.
Para el batch 3, la perdida (loss) es  194.62.
Para el batch 4, la perdida (loss) es  157.31.
La perdida promedio para la epoch 0 es  157.31 y el MAE es    7.92.

Epoch 00001: Learning rate is 0.1000.
Para el batch 0, la perdida (loss) es    6.23.
Para el batch 1, la perdida (loss) es    6.23.
Para el batch 2, la perdida (loss) es    6.31.
Para el batch 3, la perdida (loss) es    6.04.
Para el batch 4, la perdida (loss) es    6.08.
La perdida promedio para la epoch 1 es    6.08 y el MAE es    2.06.

Epoch 00002: Learning rate is 0.1000.
Para el batch 0, la perdida (loss) es   10.44.
Para el batch 1, la perdida (loss) es   10.94.
Para el batch 2, la perdida (loss) es   12.18.
Para el batch 3, la perdida (loss) es   14.24.
Para el batch 4, la perdida (loss) es   16.86.
La perdida promedio para la epoch 2 es   16.86 y el MAE es    3.41.

Epoch 00003: Learning rate is 0.0500.
Para el batch 0, la perdida (loss) es   35.25.
Para el batch 1, la perdida (loss) es   20.86.
Para el batch 2, la perdida (loss) es   15.10.
Para el batch 3, la perdida (loss) es   12.15.
Para el batch 4, la perdida (loss) es   10.47.
La perdida promedio para la epoch 3 es   10.47 y el MAE es    2.31.

Epoch 00004: Learning rate is 0.0500.
Para el batch 0, la perdida (loss) es    4.03.
Para el batch 1, la perdida (loss) es    4.65.
Para el batch 2, la perdida (loss) es    4.61.
Para el batch 3, la perdida (loss) es    4.65.
Para el batch 4, la perdida (loss) es    4.74.
La perdida promedio para la epoch 4 es    4.74 y el MAE es    1.72.

Epoch 00005: Learning rate is 0.0500.
Para el batch 0, la perdida (loss) es    4.02.
Para el batch 1, la perdida (loss) es    3.93.
Para el batch 2, la perdida (loss) es    3.99.
Para el batch 3, la perdida (loss) es    3.83.
Para el batch 4, la perdida (loss) es    4.04.
La perdida promedio para la epoch 5 es    4.04 y el MAE es    1.59.

Epoch 00006: Learning rate is 0.0100.
Para el batch 0, la perdida (loss) es    4.38.
Para el batch 1, la perdida (loss) es    4.90.
Para el batch 2, la perdida (loss) es    4.83.
Para el batch 3, la perdida (loss) es    5.02.
Para el batch 4, la perdida (loss) es    5.04.
La perdida promedio para la epoch 6 es    5.04 y el MAE es    1.77.

Epoch 00007: Learning rate is 0.0100.
Para el batch 0, la perdida (loss) es    5.85.
Para el batch 1, la perdida (loss) es    5.87.
Para el batch 2, la perdida (loss) es    5.11.
Para el batch 3, la perdida (loss) es    5.38.
Para el batch 4, la perdida (loss) es    4.97.
La perdida promedio para la epoch 7 es    4.97 y el MAE es    1.76.

Epoch 00008: Learning rate is 0.0100.
Para el batch 0, la perdida (loss) es    4.39.
Para el batch 1, la perdida (loss) es    3.87.
Para el batch 2, la perdida (loss) es    3.69.
Para el batch 3, la perdida (loss) es    3.95.
Para el batch 4, la perdida (loss) es    4.16.
La perdida promedio para la epoch 8 es    4.16 y el MAE es    1.63.

Epoch 00009: Learning rate is 0.0050.
Para el batch 0, la perdida (loss) es    4.89.
Para el batch 1, la perdida (loss) es    4.67.
Para el batch 2, la perdida (loss) es    4.33.
Para el batch 3, la perdida (loss) es    4.21.
Para el batch 4, la perdida (loss) es    4.04.
La perdida promedio para la epoch 9 es    4.04 y el MAE es    1.61.

Epoch 00010: Learning rate is 0.0050.
Para el batch 0, la perdida (loss) es    4.09.
Para el batch 1, la perdida (loss) es    4.51.
Para el batch 2, la perdida (loss) es    4.56.
Para el batch 3, la perdida (loss) es    4.31.
Para el batch 4, la perdida (loss) es    4.45.
La perdida promedio para la epoch 10 es    4.45 y el MAE es    1.69.

Epoch 00011: Learning rate is 0.0050.
Para el batch 0, la perdida (loss) es    3.31.
Para el batch 1, la perdida (loss) es    3.34.
Para el batch 2, la perdida (loss) es    4.03.
Para el batch 3, la perdida (loss) es    4.05.
Para el batch 4, la perdida (loss) es    4.02.
La perdida promedio para la epoch 11 es    4.02 y el MAE es    1.53.

Epoch 00012: Learning rate is 0.0010.
Para el batch 0, la perdida (loss) es    3.62.
Para el batch 1, la perdida (loss) es    4.47.
Para el batch 2, la perdida (loss) es    4.52.
Para el batch 3, la perdida (loss) es    4.33.
Para el batch 4, la perdida (loss) es    4.42.
La perdida promedio para la epoch 12 es    4.42 y el MAE es    1.67.

Epoch 00013: Learning rate is 0.0010.
Para el batch 0, la perdida (loss) es    4.11.
Para el batch 1, la perdida (loss) es    4.96.
Para el batch 2, la perdida (loss) es    4.48.
Para el batch 3, la perdida (loss) es    4.53.
Para el batch 4, la perdida (loss) es    4.36.
La perdida promedio para la epoch 13 es    4.36 y el MAE es    1.67.

Epoch 00014: Learning rate is 0.0010.
Para el batch 0, la perdida (loss) es    3.64.
Para el batch 1, la perdida (loss) es    3.94.
Para el batch 2, la perdida (loss) es    3.97.
Para el batch 3, la perdida (loss) es    4.12.
Para el batch 4, la perdida (loss) es    4.00.
La perdida promedio para la epoch 14 es    4.00 y el MAE es    1.57.

Callbacks de Keras estandar

Asegurate de revisar los callbacks de Keras preexistentes visitando la documentacion de la api. Las aplicaciones incluyen el registro a CSV, guardar el modelo, visualizar en TensorBoard y mucho mas.

NOTA: La documentacion aun esta en ingles