
Multi-worker Training with Keras



This tutorial demonstrates multi-worker distributed training with a Keras model using the tf.distribute.Strategy API. Strategies designed specifically for multi-worker training let a Keras model built for a single worker run on multiple workers with minimal code changes.

For a deeper understanding of the tf.distribute.Strategy APIs, the Distributed Training in TensorFlow guide gives an overview of the distribution strategies TensorFlow supports.


First, set up TensorFlow and the necessary imports.

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import os

try:
  # Install the TensorFlow 2.0 nightly preview (notebook `!` shell syntax, e.g. in Colab).
  !pip install -q tf-nightly-gpu-2.0-preview
except Exception:
  pass
import tensorflow_datasets as tfds
import tensorflow as tf

Preparing dataset

Now, let's prepare the MNIST dataset from TensorFlow Datasets. The MNIST dataset comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0–9, formatted as 28x28-pixel monochrome images.


BUFFER_SIZE = 10000  # shuffle buffer size
BATCH_SIZE = 64

# Scale MNIST pixel values from [0, 255] to [0., 1.]
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label

datasets, info = tfds.load(name='mnist',
                           with_info=True,
                           as_supervised=True)

train_datasets_unbatched = datasets['train'].map(scale).shuffle(BUFFER_SIZE)
train_datasets = train_datasets_unbatched.batch(BATCH_SIZE)

Build the Keras model

Here we use the tf.keras.Sequential API to build and compile a simple convolutional neural network to train on our MNIST dataset.

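def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),  # flatten feature maps before the dense layers
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')])
  # Integer labels with a softmax output -> sparse categorical cross-entropy.
  model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
                optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                metrics=['accuracy'])
  return model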

Let's first train the model for a few epochs on a single worker and check the results to make sure everything works. You should see the loss drop and the accuracy approach 1.0 as the epochs advance.

single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3)

Epoch 1/3
938/938 [==============================] - 12s 13ms/step - loss: 2.0466 - accuracy: 0.4304
Epoch 2/3
938/938 [==============================] - 9s 10ms/step - loss: 1.1292 - accuracy: 0.7853
Epoch 3/3
938/938 [==============================] - 9s 9ms/step - loss: 0.6138 - accuracy: 0.8527

<tensorflow.python.keras.callbacks.History at 0x7f082c58b390>

Multi-worker Configuration

Now let's move on to multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which may have a different role. TF_CONFIG specifies the cluster configuration on each worker in the cluster.

TF_CONFIG has two components: cluster and task. cluster describes the training cluster. It is a dict of job types, such as worker. In multi-worker training, one worker usually takes on extra responsibilities, such as saving checkpoints and writing summary files for TensorBoard, in addition to a regular worker's job. This worker is called the 'chief' worker. By convention, the worker with index 0 is appointed as chief, and this is how tf.distribute.Strategy is implemented. task, on the other hand, describes the current task.

In this example, we set the task type to "worker" and the task index to 0. The machine with this setting is therefore the first worker. It is appointed as chief and does more work than the other workers. The other machines also need the TF_CONFIG environment variable. They must use the same cluster dict, with a task type or task index that matches their role.
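The chief-selection convention described above can be sketched in plain Python. This is a simplified illustration, not TensorFlow's internal code. The helper name `is_chief_task` is hypothetical, and the rule that an explicit 'chief' job takes precedence over worker 0 is an assumption stated for completeness.

```python
import json


def is_chief_task(tf_config_json):
    """Return True if the task described by a TF_CONFIG string acts as chief.

    Simplified sketch (hypothetical helper): an explicit 'chief' job is assumed
    to be chief; otherwise the worker with index 0 takes on the chief's duties.
    """
    config = json.loads(tf_config_json)
    cluster = config["cluster"]
    task_type = config["task"]["type"]
    task_index = config["task"]["index"]
    if task_type == "chief":
        return True
    # With no dedicated 'chief' job, worker 0 is appointed as chief.
    return "chief" not in cluster and task_type == "worker" and task_index == 0


cluster = {"worker": ["localhost:12345", "localhost:23456"]}
print(is_chief_task(json.dumps({"cluster": cluster, "task": {"type": "worker", "index": 0}})))  # True
print(is_chief_task(json.dumps({"cluster": cluster, "task": {"type": "worker", "index": 1}})))  # False
```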

For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports, and set TF_CONFIG on each worker appropriately.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
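Every machine needs the same `cluster` dict but its own `task` entry. As a minimal sketch, a hypothetical helper `make_tf_configs` could generate one TF_CONFIG string per worker from a list of addresses. You would then export each string as the TF_CONFIG environment variable on the matching machine before training starts there.

```python
import json


def make_tf_configs(worker_addresses):
    """Build one TF_CONFIG JSON string per worker (hypothetical helper).

    All workers share the same 'cluster' dict; only the task index differs.
    """
    cluster = {"worker": list(worker_addresses)}
    return [
        json.dumps({"cluster": cluster, "task": {"type": "worker", "index": i}})
        for i in range(len(worker_addresses))
    ]


configs = make_tf_configs(["localhost:12345", "localhost:23456"])
for config in configs:
    print(config)
```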

Note that while the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
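One common heuristic for this adjustment is the linear scaling rule: scale the learning rate in proportion to the global batch size. This tutorial does not use it. The sketch below illustrates the arithmetic. The base values (a rate of 0.001 tuned for a batch of 64) are illustrative assumptions, and the helper name `scaled_learning_rate` is hypothetical.

```python
def scaled_learning_rate(base_lr, base_batch_size, global_batch_size):
    """Linear scaling rule (a heuristic): grow the learning rate with the global batch size."""
    return base_lr * global_batch_size / base_batch_size


# Illustrative values: a rate tuned for batch 64 on one worker, now run on 2 workers.
print(scaled_learning_rate(0.001, 64, 64 * 2))  # 0.002
```

Whether linear scaling helps depends on the model and the size of the scale-up, so treat the result as a starting point for tuning.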

Choose the right strategy

In TensorFlow, distributed training can be synchronous or asynchronous. In synchronous training, the training steps are synced across workers and replicas. In asynchronous training, they are not strictly synced.

MultiWorkerMirroredStrategy, which is the recommended strategy for synchronous multi-worker training, will be demonstrated in this guide. To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.
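To make "aggregate gradients and keep the variables in sync" concrete, here is a toy, framework-free simulation of synchronous data-parallel steps. It is a conceptual sketch only and does not use any TensorFlow API. Each simulated replica computes a gradient on its own slice of data for a one-parameter least-squares model, y = w * x. The gradients are averaged, which is the job the all-reduce does. Every replica then applies the same update, so the mirrored copies of the variable never diverge.

```python
def local_gradient(w, batch):
    """Gradient of mean squared error for the model y = w * x on one replica's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def synchronous_step(replica_weights, replica_batches, learning_rate):
    """One synchronous data-parallel SGD step across simulated replicas."""
    grads = [local_gradient(w, b) for w, b in zip(replica_weights, replica_batches)]
    # All-reduce stand-in: every replica receives the same averaged gradient.
    mean_grad = sum(grads) / len(grads)
    return [w - learning_rate * mean_grad for w in replica_weights]


# Two replicas start from the same mirrored value and see different data (true w = 3).
weights = [0.0, 0.0]
batches = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
for _ in range(50):
    weights = synchronous_step(weights, batches, learning_rate=0.02)
print(weights)
```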

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

MultiWorkerMirroredStrategy provides multiple implementations via the CollectiveCommunication parameter. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
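For intuition about what a ring-based collective does, the following plain-Python sketch simulates a classic ring all-reduce over in-memory lists. The algorithm has a reduce-scatter phase followed by an all-gather phase. This illustrates the general algorithm only and is not TensorFlow's CollectiveOps implementation. The function name `ring_allreduce` is hypothetical.

```python
def ring_allreduce(buffers):
    """Sum equal-length lists held by n simulated workers arranged in a ring.

    Returns one fully reduced copy per worker.
    """
    n = len(buffers)
    length = len(buffers[0])
    # Split every buffer into n contiguous chunks (some may be empty if length < n).
    bounds = [(k * length // n, (k + 1) * length // n) for k in range(n)]
    chunks = [[list(buf[a:b]) for a, b in bounds] for buf in buffers]

    # Phase 1, reduce-scatter: at each step worker i sends one chunk to its right-hand
    # neighbor, which adds it to its own copy. After n - 1 steps, worker i holds the
    # complete sum for chunk (i + 1) % n.
    for step in range(n - 1):
        outgoing = [(i, (i - step) % n) for i in range(n)]
        payloads = [list(chunks[i][k]) for i, k in outgoing]  # snapshot: sends are simultaneous
        for (i, k), data in zip(outgoing, payloads):
            dst = (i + 1) % n
            chunks[dst][k] = [x + y for x, y in zip(chunks[dst][k], data)]

    # Phase 2, all-gather: each worker forwards a completed chunk to its neighbor, which
    # overwrites its partial copy. After n - 1 steps, every worker has every summed chunk.
    for step in range(n - 1):
        outgoing = [(i, (i + 1 - step) % n) for i in range(n)]
        payloads = [list(chunks[i][k]) for i, k in outgoing]
        for (i, k), data in zip(outgoing, payloads):
            chunks[(i + 1) % n][k] = data

    # Reassemble each worker's chunks into a flat list.
    return [[x for chunk in worker_chunks for x in chunk] for worker_chunks in chunks]


print(ring_allreduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]]))
# [[111, 222, 333], [111, 222, 333], [111, 222, 333]]
```

Each worker sends only one chunk per step, for 2 * (n - 1) steps. Its total traffic is therefore about 2 * (n - 1) / n times its buffer size, almost independent of the number of workers. This is why ring-based collectives make good use of network bandwidth.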

Train the model with MultiWorkerMirroredStrategy

Because the tf.distribute.Strategy API is integrated into tf.keras, the only change needed to distribute training across multiple workers is to enclose the model building and the model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where variables are created. With MultiWorkerMirroredStrategy, the variables are MirroredVariables, replicated on each of the workers.

# Here the batch size scales up by the number of workers since
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64,
# and now this becomes 128.
NUM_WORKERS = 2  # matches the two workers listed in TF_CONFIG
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE)
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3)
Epoch 1/3
469/469 [==============================] - 9s 20ms/step - loss: 2.2403 - accuracy: 0.1778
Epoch 2/3
469/469 [==============================] - 7s 15ms/step - loss: 2.0074 - accuracy: 0.4766
Epoch 3/3
469/469 [==============================] - 7s 15ms/step - loss: 1.5697 - accuracy: 0.7200

<tensorflow.python.keras.callbacks.History at 0x7f08dfc1d7b8>

Dataset sharding and batch size

In multi-worker training, the data must be sharded into multiple parts to ensure convergence and performance. Note, however, that the code snippet above passes the datasets directly to model.fit() without sharding them. This works because the tf.distribute.Strategy API shards the dataset automatically in multi-worker training.

If you prefer to shard manually, you can turn off automatic sharding via the tf.data.experimental.DistributeOptions API. Concretely:

options = tf.data.Options()
options.experimental_distribute.auto_shard = False
train_datasets_no_auto_shard = train_datasets.with_options(options)
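When sharding manually, each worker usually keeps a disjoint subset of the data. One way to do this is tf.data.Dataset.shard(num_shards, index), which keeps the elements whose position modulo num_shards equals index. The plain-Python sketch below mimics that rule to show which examples each worker would get. The helper `shard` is hypothetical and works on a list rather than a tf.data.Dataset.

```python
def shard(elements, num_shards, index):
    """Keep every element whose position modulo num_shards equals index."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


examples = list(range(10))
for worker_index in range(2):
    print(worker_index, shard(examples, num_shards=2, index=worker_index))
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]
```

In a real input pipeline you would usually shard early, before any randomizing step such as shuffle. That way every worker partitions the same element order and the shards stay disjoint.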

Also note the batch size used for the datasets. The code snippet above uses GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS, which is NUM_WORKERS times the single-worker batch size. The effective per-worker batch size is the global batch size (the argument to tf.data.Dataset.batch()) divided by the number of workers. This change therefore keeps the per-worker batch size the same as before.
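You can check this arithmetic directly. The sketch below assumes the 60,000 MNIST training examples and the values used above: 64 examples per worker and 2 workers. It computes the global batch size, the per-worker share of each step, and the number of steps per epoch. The step counts match the 938 and 469 steps shown in the training logs.

```python
import math

NUM_EXAMPLES = 60000        # MNIST training set size
PER_WORKER_BATCH_SIZE = 64  # batch size used in the single-worker run
NUM_WORKERS = 2             # two workers in the TF_CONFIG above

GLOBAL_BATCH_SIZE = PER_WORKER_BATCH_SIZE * NUM_WORKERS


def steps_per_epoch(num_examples, batch_size):
    # The final partial batch still counts as a step (Dataset.batch keeps it by default).
    return math.ceil(num_examples / batch_size)


print(GLOBAL_BATCH_SIZE)                                     # 128
print(GLOBAL_BATCH_SIZE // NUM_WORKERS)                      # 64 examples per worker per step
print(steps_per_epoch(NUM_EXAMPLES, PER_WORKER_BATCH_SIZE))  # 938
print(steps_per_epoch(NUM_EXAMPLES, GLOBAL_BATCH_SIZE))      # 469
```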


You now have a Keras model that is set up to run on multiple workers with MultiWorkerMirroredStrategy. You can try the following techniques to tune the performance of multi-worker training.

  • MultiWorkerMirroredStrategy provides several collective communication implementations (RING, NCCL, and AUTO, described above). To override the automatic choice, pass a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
  • Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.

Fault tolerance

In synchronous training, the whole cluster fails if one worker fails and there is no failure-recovery mechanism. Using Keras with tf.distribute.Strategy provides fault tolerance when workers die or are otherwise unstable. It works by preserving the training state in a distributed file system of your choice. When an instance that failed or was preempted restarts, its training state is recovered.

Since all the workers are kept in sync in terms of training epochs and steps, other workers would need to wait for the failed or preempted worker to restart to continue.

ModelCheckpoint callback

To take advantage of fault tolerance in multi-worker training, pass an instance of tf.keras.callbacks.ModelCheckpoint to the tf.keras.Model.fit() call. The callback stores the checkpoint and training state in the directory given by the filepath argument of ModelCheckpoint.

# Replace the `filepath` argument with a path in the file system
# accessible by all workers.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3, callbacks=callbacks)
Epoch 1/3
469/469 [==============================] - 9s 19ms/step - loss: 2.2566 - accuracy: 0.1654
Epoch 2/3
469/469 [==============================] - 8s 16ms/step - loss: 2.1010 - accuracy: 0.4973
Epoch 3/3
469/469 [==============================] - 8s 16ms/step - loss: 1.7733 - accuracy: 0.6656

<tensorflow.python.keras.callbacks.History at 0x7f08df99df98>

If a worker gets preempted, the whole cluster pauses until the preempted worker is restarted. Once the worker rejoins the cluster, other workers will also restart. Now, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then the training continues.
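The recovery behavior can be illustrated with a framework-free toy. The sketch stands in for what ModelCheckpoint does but is not its implementation. A hypothetical `train` function writes its state to a JSON file after each epoch. When restarted after a simulated preemption (via the hypothetical `fail_at_epoch` parameter), it reloads that file and continues from the next unfinished epoch instead of starting over.

```python
import json
import os
import tempfile


def train(ckpt_path, total_epochs, fail_at_epoch=None):
    """Toy training loop that checkpoints after every epoch and resumes from the last one.

    `fail_at_epoch` (hypothetical) simulates a preemption just before that epoch starts.
    """
    state = {"epoch": 0, "weight": 0.0}
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)  # recover the state saved before the failure
    completed = []
    for epoch in range(state["epoch"], total_epochs):
        if epoch == fail_at_epoch:
            raise RuntimeError("simulated preemption at epoch %d" % epoch)
        state["weight"] += 1.0  # stand-in for one epoch of parameter updates
        state["epoch"] = epoch + 1  # next epoch to run after a restart
        with open(ckpt_path, "w") as f:
            json.dump(state, f)  # persist state at the end of every epoch
        completed.append(epoch)
    return state, completed


ckpt = os.path.join(tempfile.mkdtemp(), "ckpt.json")
try:
    train(ckpt, total_epochs=3, fail_at_epoch=2)  # epochs 0 and 1 finish, then "preemption"
except RuntimeError as err:
    print(err)
state, completed = train(ckpt, total_epochs=3)  # restart: resumes at epoch 2
print(state, completed)
```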

If you inspect the directory containing the filepath you specified in ModelCheckpoint, you may notice some temporary checkpoint files. These files are needed to recover previously lost instances. The library removes them at the end of tf.keras.Model.fit() once your multi-worker training exits successfully.

See also

  1. Distributed Training in TensorFlow guide provides an overview of the available distribution strategies.
  2. Official ResNet50 model, which can be trained using either MirroredStrategy or MultiWorkerMirroredStrategy.