Distributed training with Keras

Overview

The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

This tutorial uses the tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.

MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about other strategies in the distribution strategy guide.
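
As a sketch of the options described in that guide (not used in this tutorial), MirroredStrategy can also be given an explicit device list and a cross-device communication implementation. The two-GPU device names below are assumptions for illustration:

# Illustrative only: mirror variables on two specific GPUs and use
# hierarchical-copy all-reduce instead of the default implementation.
# This assumes two GPUs are actually available on the machine.
two_gpu_strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"],
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())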

Keras API

This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.

Import dependencies

from __future__ import absolute_import, division, print_function, unicode_literals

# Attempt to install the TensorFlow nightly build (intended for notebook
# environments such as Colab).
try:
  !pip install -q tf-nightly
except Exception:
  pass

# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os
print(tf.__version__)
2.1.0-dev20191004

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...

/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Define distribution strategy

Create a MirroredStrategy object. This handles distribution and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
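
Because each synchronized step now processes BATCH_SIZE_PER_REPLICA examples on every replica, one common heuristic is to scale the learning rate with the number of replicas. This is shown only as a sketch; the tutorial itself keeps the default Adam learning rate, and the base rate below is an assumption:

# Sketch of linear learning-rate scaling (not used in the rest of this tutorial).
BASE_LEARNING_RATE = 1e-3  # assumed single-replica learning rate
scaled_lr = BASE_LEARNING_RATE * strategy.num_replicas_in_sync
scaled_optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr)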

Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scaling in a function.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Define the callbacks

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
  • Model Checkpoint: This callback saves the model after every epoch.
  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch or batch (a per-batch sketch follows the callback definitions below).

For illustrative purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
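
The LearningRateScheduler callback above changes the rate once per epoch. To change it every batch instead, you can pass a LearningRateSchedule directly to the optimizer. This is a sketch only, and the decay values are assumptions for illustration:

# Sketch only: decay the learning rate every optimizer step rather than every epoch.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=1000,
    decay_rate=0.9)
per_step_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)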

Train and evaluate

Now, train the model in the usual way by calling fit on the model and passing in the dataset created earlier in the tutorial. This step is the same whether or not you are distributing the training.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

    938/Unknown - 10s 10ms/step - loss: 0.1922 - accuracy: 0.9441
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 10s 10ms/step - loss: 0.1922 - accuracy: 0.9441
Epoch 2/12
923/938 [============================>.] - ETA: 0s - loss: 0.0662 - accuracy: 0.9803
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0658 - accuracy: 0.9804
Epoch 3/12
935/938 [============================>.] - ETA: 0s - loss: 0.0456 - accuracy: 0.9863
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - loss: 0.0455 - accuracy: 0.9863
Epoch 4/12
935/938 [============================>.] - ETA: 0s - loss: 0.0264 - accuracy: 0.9926
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0264 - accuracy: 0.9926
Epoch 5/12
932/938 [============================>.] - ETA: 0s - loss: 0.0226 - accuracy: 0.9939
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0225 - accuracy: 0.9940
Epoch 6/12
931/938 [============================>.] - ETA: 0s - loss: 0.0207 - accuracy: 0.9945
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0207 - accuracy: 0.9945
Epoch 7/12
932/938 [============================>.] - ETA: 0s - loss: 0.0191 - accuracy: 0.9951
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0191 - accuracy: 0.9951
Epoch 8/12
931/938 [============================>.] - ETA: 0s - loss: 0.0165 - accuracy: 0.9963
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0165 - accuracy: 0.9962
Epoch 9/12
922/938 [============================>.] - ETA: 0s - loss: 0.0162 - accuracy: 0.9963
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0162 - accuracy: 0.9963
Epoch 10/12
927/938 [============================>.] - ETA: 0s - loss: 0.0160 - accuracy: 0.9965
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0160 - accuracy: 0.9965
Epoch 11/12
934/938 [============================>.] - ETA: 0s - loss: 0.0158 - accuracy: 0.9965
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0158 - accuracy: 0.9965
Epoch 12/12
928/938 [============================>.] - ETA: 0s - loss: 0.0157 - accuracy: 0.9964
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - loss: 0.0156 - accuracy: 0.9965

<tensorflow.python.keras.callbacks.History at 0x7f92ad88d278>

As you can see below, the checkpoints are being saved.

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_4.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_7.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index

To see how the model performs, load the latest checkpoint and call evaluate on the test data.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

    157/Unknown - 2s 14ms/step - loss: 0.0363 - accuracy: 0.9875Eval loss: 0.03630251882658607, Eval Accuracy: 0.987500011920929

To see the output in TensorBoard, you can download the logs and launch TensorBoard from a terminal:

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 4.0K
4.0K train
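
Alternatively, if you are running this tutorial in a notebook such as Colab, you can display TensorBoard inline using the notebook extension. This is a sketch and assumes the tensorboard package is installed in the environment:

# Load the TensorBoard notebook extension and point it at the log directory.
%load_ext tensorboard
%tensorboard --logdir ./logs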

Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Assets written to: saved_model/assets

Load the model without strategy.scope.

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 7ms/step - loss: 0.0363 - accuracy: 0.9875Eval loss: 0.036302936971143104, Eval Accuracy: 0.987500011920929

Load the model with strategy.scope.

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 2s 14ms/step - loss: 0.0363 - accuracy: 0.9875Eval loss: 0.036302936971143104, Eval Accuracy: 0.987500011920929

Examples and tutorials

Here are some examples that use distribution strategy with the Keras fit/compile API:

  1. Transformer example trained using tf.distribute.MirroredStrategy
  2. NCF example trained using tf.distribute.MirroredStrategy

More examples are listed in the Distribution strategy guide.

Next steps