Entrenamiento distribuido con Keras

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar libreta

Descripción general

La API tf.distribute.Strategy proporciona una abstracción para distribuir su capacitación en varias unidades de procesamiento. Le permite realizar un entrenamiento distribuido utilizando modelos existentes y código de entrenamiento con cambios mínimos.

Este tutorial demuestra cómo usar tf.distribute.MirroredStrategy para realizar la replicación en el gráfico con entrenamiento síncrono en muchas GPU en una máquina . La estrategia esencialmente copia todas las variables del modelo a cada procesador. Luego, usa all-reduce para combinar los gradientes de todos los procesadores y aplica el valor combinado a todas las copias del modelo.

Utilizará las API de tf.keras para crear el modelo y Model.fit para entrenarlo. (Para obtener más información sobre el entrenamiento distribuido con un bucle de entrenamiento personalizado y MirroredStrategy , consulte este tutorial ).

MirroredStrategy entrena su modelo en múltiples GPU en una sola máquina. Para el entrenamiento síncrono en muchas GPU en varios trabajadores , use tf.distribute.MultiWorkerMirroredStrategy con Keras Model.fit o un ciclo de entrenamiento personalizado . Para conocer otras opciones, consulte la guía de capacitación distribuida .

Para conocer otras estrategias, está la guía de capacitación distribuida con TensorFlow .

Configuración

import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard
print(tf.__version__)
2.8.0-rc1

Descargar el conjunto de datos

Cargue el conjunto de datos MNIST desde TensorFlow Datasets . Esto devuelve un conjunto de datos en formato tf.data .

Establecer el argumento with_info en True incluye los metadatos de todo el conjunto de datos, que se guarda aquí en info . Entre otras cosas, este objeto de metadatos incluye el número de ejemplos de entrenamiento y prueba.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

Definir la estrategia de distribución.

Cree un objeto MirroredStrategy . Esto manejará la distribución y proporcionará un administrador de contexto ( MirroredStrategy.scope ) para construir su modelo en el interior.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Configurar la canalización de entrada

Al entrenar un modelo con varias GPU, puede usar la potencia informática adicional de manera eficaz aumentando el tamaño del lote. En general, utilice el tamaño de lote más grande que se ajuste a la memoria de la GPU y ajuste la tasa de aprendizaje en consecuencia.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Defina una función que normalice los valores de píxeles de la imagen del rango [0, 255] al rango [0, 1] ( escalado de características ):

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Aplique esta función de scale a los datos de prueba y entrenamiento, y luego use las API de tf.data.Dataset para mezclar los datos de entrenamiento ( Dataset.shuffle ) y procesarlos por lotes ( Dataset.batch ). Tenga en cuenta que también mantiene un caché en memoria de los datos de entrenamiento para mejorar el rendimiento ( Dataset.cache ).

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Crear el modelo

Cree y compile el modelo Keras en el contexto de Strategy.scope :

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Definir las devoluciones de llamada

Defina los siguientes tf.keras.callbacks :

Para fines ilustrativos, agregue una devolución de llamada personalizada llamada PrintLR para mostrar la tasa de aprendizaje en el cuaderno.

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

Formar y evaluar

Ahora, entrene el modelo de la manera habitual llamando a Model.fit en el modelo y pasando el conjunto de datos creado al comienzo del tutorial. Este paso es el mismo ya sea que esté distribuyendo la capacitación o no.

EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2022-01-26 05:38:28.865380: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
933/938 [============================>.] - ETA: 0s - loss: 0.2029 - accuracy: 0.9399
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 10s 4ms/step - loss: 0.2022 - accuracy: 0.9401 - lr: 0.0010
Epoch 2/12
930/938 [============================>.] - ETA: 0s - loss: 0.0654 - accuracy: 0.9813
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0652 - accuracy: 0.9813 - lr: 0.0010
Epoch 3/12
931/938 [============================>.] - ETA: 0s - loss: 0.0453 - accuracy: 0.9864
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0453 - accuracy: 0.9864 - lr: 0.0010
Epoch 4/12
923/938 [============================>.] - ETA: 0s - loss: 0.0246 - accuracy: 0.9933
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0244 - accuracy: 0.9934 - lr: 1.0000e-04
Epoch 5/12
929/938 [============================>.] - ETA: 0s - loss: 0.0211 - accuracy: 0.9944
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0212 - accuracy: 0.9944 - lr: 1.0000e-04
Epoch 6/12
930/938 [============================>.] - ETA: 0s - loss: 0.0192 - accuracy: 0.9950
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0194 - accuracy: 0.9950 - lr: 1.0000e-04
Epoch 7/12
927/938 [============================>.] - ETA: 0s - loss: 0.0179 - accuracy: 0.9953
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0179 - accuracy: 0.9953 - lr: 1.0000e-04
Epoch 8/12
938/938 [==============================] - ETA: 0s - loss: 0.0153 - accuracy: 0.9966
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0153 - accuracy: 0.9966 - lr: 1.0000e-05
Epoch 9/12
927/938 [============================>.] - ETA: 0s - loss: 0.0151 - accuracy: 0.9966
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0150 - accuracy: 0.9966 - lr: 1.0000e-05
Epoch 10/12
935/938 [============================>.] - ETA: 0s - loss: 0.0148 - accuracy: 0.9966
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0148 - accuracy: 0.9966 - lr: 1.0000e-05
Epoch 11/12
937/938 [============================>.] - ETA: 0s - loss: 0.0146 - accuracy: 0.9967
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0146 - accuracy: 0.9967 - lr: 1.0000e-05
Epoch 12/12
926/938 [============================>.] - ETA: 0s - loss: 0.0145 - accuracy: 0.9967
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0144 - accuracy: 0.9967 - lr: 1.0000e-05
<keras.callbacks.History at 0x7fad70067c10>

Compruebe los puntos de control guardados:

# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

Para verificar qué tan bien funciona el modelo, cargue el último punto de control y llame a Model.evaluate en los datos de prueba:

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2022-01-26 05:39:15.260539: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
157/157 [==============================] - 2s 4ms/step - loss: 0.0373 - accuracy: 0.9879
Eval loss: 0.03732967749238014, Eval accuracy: 0.9879000186920166

Para visualizar el resultado, inicie TensorBoard y vea los registros:

%tensorboard --logdir=logs

ls -sh ./logs
total 4.0K
4.0K train

Exportar a modelo guardado

Exporte el gráfico y las variables al formato de modelo guardado independiente de la plataforma mediante Model.save . Después de guardar su modelo, puede cargarlo con o sin Strategy.scope .

path = 'saved_model/'
model.save(path, save_format='tf')
2022-01-26 05:39:18.012847: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: saved_model/assets
INFO:tensorflow:Assets written to: saved_model/assets

Ahora, cargue el modelo sin Strategy.scope :

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 2ms/step - loss: 0.0373 - accuracy: 0.9879
Eval loss: 0.03732967749238014, Eval Accuracy: 0.9879000186920166

Cargue el modelo con Strategy.scope :

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2022-01-26 05:39:19.489971: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
157/157 [==============================] - 3s 3ms/step - loss: 0.0373 - accuracy: 0.9879
Eval loss: 0.03732967749238014, Eval Accuracy: 0.9879000186920166

Recursos adicionales

Más ejemplos que utilizan diferentes estrategias de distribución con la API Keras Model.fit :

  1. El tutorial Resolver tareas de GLUE con BERT en TPU utiliza tf.distribute.MirroredStrategy para entrenar en GPU y tf.distribute.TPUStrategy en TPU.
  2. El tutorial Guardar y cargar un modelo usando una estrategia de distribución demuestra cómo usar las API de modelo guardado con tf.distribute.Strategy .
  3. Los modelos oficiales de TensorFlow se pueden configurar para ejecutar múltiples estrategias de distribución.

Para obtener más información sobre las estrategias de distribución de TensorFlow:

  1. El tutorial Entrenamiento personalizado con tf.distribute.Strategy muestra cómo usar tf.distribute.MirroredStrategy para el entrenamiento de un solo trabajador con un ciclo de entrenamiento personalizado.
  2. El tutorial de capacitación para varios trabajadores con Keras muestra cómo usar MultiWorkerMirroredStrategy con Model.fit .
  3. El tutorial de bucle de entrenamiento personalizado con Keras y MultiWorkerMirroredStrategy muestra cómo usar MultiWorkerMirroredStrategy con Keras y un bucle de entrenamiento personalizado.
  4. La guía de capacitación distribuida en TensorFlow proporciona una descripción general de las estrategias de distribución disponibles.
  5. La guía Mejor rendimiento con tf.function brinda información sobre otras estrategias y herramientas, como TensorFlow Profiler , que puede usar para optimizar el rendimiento de sus modelos de TensorFlow.