
Distributed training with TensorFlow


Overview

tf.distribute.Strategy is a TensorFlow API to distribute training across multiple GPUs, multiple machines or TPUs. Using this API, you can distribute your existing models and training code with minimal code changes.

tf.distribute.Strategy has been designed with these key goals in mind:

  • Be easy to use and support multiple user segments, including researchers, ML engineers, etc.
  • Provide good performance out of the box.
  • Make it easy to switch between strategies.

tf.distribute.Strategy can be used with a high-level API like Keras, and can also be used to distribute custom training loops (and, in general, any computation using TensorFlow).

In TensorFlow 2.0, you can execute your programs eagerly, or in a graph using tf.function. tf.distribute.Strategy intends to support both these modes of execution. Although we discuss training most of the time in this guide, this API can also be used for distributing evaluation and prediction on different platforms.

You can use tf.distribute.Strategy with very few changes to your code, because we have changed the underlying components of TensorFlow to become strategy-aware. This includes variables, layers, models, optimizers, metrics, summaries, and checkpoints.

In this guide, we explain various types of strategies and how you can use them in different situations.

# Import TensorFlow
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

Types of strategies

tf.distribute.Strategy intends to cover a number of use cases along different axes. Some of these combinations are currently supported and others will be added in the future. Some of these axes are:

  • Synchronous vs asynchronous training: These are two common ways of distributing training with data parallelism. In sync training, all workers train over different slices of the input data in lockstep and aggregate gradients at each step. In async training, all workers train over the input data independently and update variables asynchronously. Typically, sync training is supported via all-reduce and async training via a parameter server architecture.
  • Hardware platform: You may want to scale your training onto multiple GPUs on one machine, or multiple machines in a network (with 0 or more GPUs each), or on Cloud TPUs.

In order to support these use cases, there are six strategies available. In the next section we explain which of these are supported in which scenarios in TF 2.0 at this time. Here is a quick overview:

Training API         | MirroredStrategy     | TPUStrategy          | MultiWorkerMirroredStrategy | CentralStorageStrategy   | ParameterServerStrategy  | OneDeviceStrategy
Keras API            | Supported            | Experimental support | Experimental support        | Experimental support     | Support planned post 2.0 | Supported
Custom training loop | Experimental support | Experimental support | Support planned post 2.0    | Support planned post 2.0 | No support yet           | Supported
Estimator API        | Limited support      | Not supported        | Limited support             | Limited support          | Limited support          | Limited support

MirroredStrategy

tf.distribute.MirroredStrategy supports synchronous distributed training on multiple GPUs on one machine. It creates one replica per GPU device. Each variable in the model is mirrored across all the replicas. Together, these variables form a single conceptual variable called MirroredVariable. These variables are kept in sync with each other by applying identical updates.

Efficient all-reduce algorithms are used to communicate the variable updates across the devices. All-reduce aggregates tensors across all the devices by adding them up, and makes them available on each device. It’s a fused algorithm that is very efficient and can reduce the overhead of synchronization significantly. There are many all-reduce algorithms and implementations available, depending on the type of communication available between devices. By default, it uses NVIDIA NCCL as the all-reduce implementation. You can choose from a few other options we provide, or write your own.
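To make the all-reduce semantics concrete, here is a pure-Python sketch (no TensorFlow, no real devices) of what the two paragraphs above describe: each replica holds its own copy of a variable, computes a local gradient, an all-reduce sums the gradients so every replica sees the same total, and identical updates keep the copies in sync. The helper names and gradient values are hypothetical, and this is not how NCCL or TensorFlow actually implement it.

```python
from typing import List


def all_reduce_sum(per_replica: List[List[float]]) -> List[List[float]]:
    """Sum equally shaped vectors across replicas and give every replica its own copy of the total."""
    total = [sum(values) for values in zip(*per_replica)]
    return [list(total) for _ in per_replica]


def mirrored_sgd_step(copies: List[List[float]], local_grads: List[List[float]],
                      learning_rate: float) -> List[List[float]]:
    """Apply one synchronous SGD update to every replica's copy of a variable."""
    # Every replica receives the same summed gradient from the all-reduce.
    summed = all_reduce_sum(local_grads)
    # Each replica applies an identical update to its own copy.
    return [[v - learning_rate * g for v, g in zip(var, grad)]
            for var, grad in zip(copies, summed)]


# Two replicas start with identical copies of a 2-element variable.
copies = [[1.0, 2.0], [1.0, 2.0]]
# Each replica computed a different gradient from its own slice of the batch
# (assumed to be pre-scaled by the global batch size, so a plain sum is used).
local_grads = [[0.2, 0.4], [0.6, 0.0]]
copies = mirrored_sgd_step(copies, local_grads, learning_rate=0.5)
print(copies)  # both copies are still identical after the update
```

Because every replica starts from the same values and applies the same summed gradient, the copies never drift apart, which is what lets them act as one conceptual MirroredVariable.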

Here is the simplest way of creating MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

This will create a MirroredStrategy instance which will use all the GPUs that are visible to TensorFlow, and use NCCL as the cross device communication.

If you wish to use only some of the GPUs on your machine, you can do so like this:

mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
WARNING:tensorflow:Some requested devices in `tf.distribute.Strategy` are not visible to TensorFlow: /job:localhost/replica:0/task:0/device:GPU:0,/job:localhost/replica:0/task:0/device:GPU:1

If you wish to override the cross device communication, you can do so using the cross_device_ops argument by supplying an instance of tf.distribute.CrossDeviceOps. Currently, tf.distribute.HierarchicalCopyAllReduce and tf.distribute.ReductionToOneDevice are the two options besides tf.distribute.NcclAllReduce, which is the default.

mirrored_strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

CentralStorageStrategy

tf.distribute.experimental.CentralStorageStrategy does synchronous training as well. Variables are not mirrored; instead, they are placed on the CPU, and operations are replicated across all local GPUs. If there is only one GPU, all variables and operations will be placed on that GPU.

Create an instance of CentralStorageStrategy by:

central_storage_strategy = tf.distribute.experimental.CentralStorageStrategy()
INFO:tensorflow:ParameterServerStrategy with compute_devices = ('/device:GPU:0',), variable_device = '/device:GPU:0'

This will create a CentralStorageStrategy instance which will use all visible GPUs and CPU. Updates to variables on the replicas will be aggregated before being applied to the variables.

MultiWorkerMirroredStrategy

tf.distribute.experimental.MultiWorkerMirroredStrategy is very similar to MirroredStrategy. It implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to MirroredStrategy, it creates copies of all variables in the model on each device across all workers.

It uses CollectiveOps as the multi-worker all-reduce communication method used to keep variables in sync. A collective op is a single op in the TensorFlow graph which can automatically choose an all-reduce algorithm in the TensorFlow runtime according to hardware, network topology and tensor sizes.

It also implements additional performance optimizations. For example, it includes a static optimization that converts multiple all-reductions on small tensors into fewer all-reductions on larger tensors. In addition, we are designing it to have a plugin architecture, so that in the future you will be able to plug in algorithms that are better tuned for your hardware. Note that collective ops also implement other collective operations such as broadcast and all-gather.
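The small-tensor optimization can be pictured as bucketing: concatenate many small tensors into a few larger buffers, all-reduce each buffer once, then split the result back into the original tensors. The sketch below is plain Python, uses a hypothetical `max_bucket_size` limit counted in elements, and simulates the all-reduce with an element-wise sum; it illustrates the idea rather than TensorFlow's actual packing heuristics.

```python
from typing import List


def pack(tensors: List[List[float]], max_bucket_size: int) -> List[List[int]]:
    """Group tensor indices into buckets whose total element count stays within max_bucket_size."""
    buckets, current, current_size = [], [], 0
    for i, t in enumerate(tensors):
        # Start a new bucket if adding this tensor would exceed the limit.
        if current and current_size + len(t) > max_bucket_size:
            buckets.append(current)
            current, current_size = [], 0
        current.append(i)
        current_size += len(t)
    if current:
        buckets.append(current)
    return buckets


def fused_all_reduce(per_replica: List[List[List[float]]], max_bucket_size: int):
    """Sum a list of tensors across replicas using one reduction per bucket instead of per tensor."""
    num_tensors = len(per_replica[0])
    buckets = pack(per_replica[0], max_bucket_size)
    reduced = [None] * num_tensors
    for bucket in buckets:
        # Concatenate this bucket's tensors on every replica into one flat buffer.
        flat = [[x for i in bucket for x in replica[i]] for replica in per_replica]
        # A single simulated all-reduce over the whole buffer.
        summed = [sum(vals) for vals in zip(*flat)]
        # Split the reduced buffer back into the original tensors.
        offset = 0
        for i in bucket:
            n = len(per_replica[0][i])
            reduced[i] = summed[offset:offset + n]
            offset += n
    return reduced, len(buckets)


replica_a = [[1.0], [2.0, 3.0], [4.0], [5.0, 6.0, 7.0]]
replica_b = [[10.0], [20.0, 30.0], [40.0], [50.0, 60.0, 70.0]]
reduced, num_reductions = fused_all_reduce([replica_a, replica_b], max_bucket_size=4)
print(num_reductions, reduced)  # 2 reductions instead of 4
```

Fewer, larger reductions amortize the fixed per-call cost of each collective, which is the motivation the paragraph above gives for the optimization.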

Here is the simplest way of creating MultiWorkerMirroredStrategy:

multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker CollectiveAllReduceStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy currently allows you to choose between two different implementations of collective ops. CollectiveCommunication.RING implements ring-based collectives using gRPC as the communication layer. CollectiveCommunication.NCCL uses NVIDIA's NCCL to implement collectives. CollectiveCommunication.AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. You can specify one as follows:

multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker CollectiveAllReduceStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.NCCL
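CollectiveCommunication.RING refers to ring-based collectives. As a rough illustration of the general ring all-reduce technique (not TensorFlow's gRPC implementation), the sketch below simulates N workers arranged in a ring, each holding a vector split into N chunks. A reduce-scatter phase of N-1 steps leaves each worker owning one fully summed chunk, and an all-gather phase of N-1 more steps circulates those chunks so every worker ends with the full sum. Workers are simulated as Python lists, and the vector length is assumed to divide evenly by the number of workers.

```python
from typing import List


def ring_all_reduce(data: List[List[float]]) -> List[List[float]]:
    """Simulate a ring all-reduce (sum) over len(data) workers.

    Assumes every worker's vector has the same length, divisible by the number of workers.
    """
    n = len(data)
    size = len(data[0])
    if size % n != 0:
        raise ValueError("vector length must divide evenly into n chunks")
    chunk = size // n
    # Split each worker's vector into n chunks.
    bufs = [[list(vec[k * chunk:(k + 1) * chunk]) for k in range(n)] for vec in data]

    # Phase 1: reduce-scatter. At step s, worker i sends chunk (i - s) % n to its
    # right neighbor, which adds it into its own copy of that chunk.
    for s in range(n - 1):
        sends = [(i, (i - s) % n, list(bufs[i][(i - s) % n])) for i in range(n)]
        for i, k, payload in sends:
            dst = (i + 1) % n
            bufs[dst][k] = [a + b for a, b in zip(bufs[dst][k], payload)]

    # Worker i now holds the fully reduced chunk (i + 1) % n.
    # Phase 2: all-gather. At step s, worker i forwards chunk (i + 1 - s) % n to its
    # right neighbor, which overwrites its copy.
    for s in range(n - 1):
        sends = [(i, (i + 1 - s) % n, list(bufs[i][(i + 1 - s) % n])) for i in range(n)]
        for i, k, payload in sends:
            bufs[(i + 1) % n][k] = payload

    # Flatten the chunks back into one vector per worker.
    return [[x for part in worker for x in part] for worker in bufs]


workers = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0], [100.0, 200.0, 300.0]]
print(ring_all_reduce(workers))  # every worker ends with [111.0, 222.0, 333.0]
```

In this scheme each worker only ever talks to its neighbor and sends about 2(N-1)/N of its data in total, which is why ring collectives scale well when bandwidth between neighbors is the bottleneck.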

One key difference between multi-worker training and multi-GPU training is the multi-worker setup. The TF_CONFIG environment variable is the standard way in TensorFlow to specify the cluster configuration to each worker that is part of the cluster. Learn more about setting up TF_CONFIG.
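As a sketch of what TF_CONFIG contains, the snippet below builds the JSON for a hypothetical two-worker cluster. The host names, ports, and the `make_tf_config` helper are placeholders for illustration. Every worker receives the same `cluster` section but its own `task` entry, and the variable is set in each worker's environment before the strategy is created.

```python
import json
import os

# Hypothetical cluster: two workers. Replace the addresses with your own hosts and ports.
CLUSTER = {"worker": ["worker0.example.com:12345", "worker1.example.com:12345"]}


def make_tf_config(task_type: str, task_index: int) -> str:
    """Return the TF_CONFIG JSON string for one task in CLUSTER."""
    if task_index >= len(CLUSTER.get(task_type, [])):
        raise ValueError(f"no {task_type} task with index {task_index}")
    return json.dumps({
        "cluster": CLUSTER,
        "task": {"type": task_type, "index": task_index},
    })


# On the first worker; the second machine would use index 1 instead.
os.environ["TF_CONFIG"] = make_tf_config("worker", 0)
print(os.environ["TF_CONFIG"])
```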

TPUStrategy

tf.distribute.experimental.TPUStrategy lets you run your TensorFlow training on Tensor Processing Units (TPUs). TPUs are Google's specialized ASICs designed to dramatically accelerate machine learning workloads. They are available on Google Colab, the TensorFlow Research Cloud and Cloud TPU.

In terms of distributed training architecture, TPUStrategy is the same as MirroredStrategy: it implements synchronous distributed training. TPUs provide their own implementation of efficient all-reduce and other collective operations across multiple TPU cores, which are used in TPUStrategy.

Here is how you would instantiate TPUStrategy:

# tpu_address is a placeholder for the name or address of your TPU;
# in Colab, TPUClusterResolver() can be created without arguments.
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu=tpu_address)
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

The TPUClusterResolver instance helps locate the TPUs. In Colab, you don't need to specify any arguments to it.

If you want to use this for Cloud TPUs:

  • You must specify the name of your TPU resource in the tpu argument.
  • You must initialize the TPU system explicitly at the start of the program. This is required before TPUs can be used for computation. Initializing the TPU system also wipes out the TPU memory, so it's important to complete this step first in order to avoid losing state.

ParameterServerStrategy

tf.distribute.experimental.ParameterServerStrategy supports parameter server training on multiple machines. In this setup, some machines are designated as workers and some as parameter servers. Each variable of the model is placed on one parameter server. Computation is replicated across all GPUs of all the workers.

In terms of code, it looks similar to other strategies:

ps_strategy = tf.distribute.experimental.ParameterServerStrategy()

For multi-worker training, TF_CONFIG needs to specify the configuration of the parameter servers and workers in your cluster, which you can read more about in TF_CONFIG below.
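To picture "each variable of the model is placed on one parameter server," here is a plain-Python sketch that assigns variables to parameter-server tasks round-robin. Round-robin is an assumption chosen for illustration; ParameterServerStrategy's own placement may differ. The variable names and the `round_robin_placement` helper are hypothetical.

```python
from typing import Dict, List


def round_robin_placement(var_names: List[str], num_ps: int) -> Dict[str, str]:
    """Map each variable name to a parameter-server device string, cycling through the ps tasks."""
    return {name: f"/job:ps/task:{i % num_ps}" for i, name in enumerate(var_names)}


# Hypothetical variables of a small two-layer model.
variables = ["dense/kernel", "dense/bias", "dense_1/kernel", "dense_1/bias"]
placement = round_robin_placement(variables, num_ps=2)
for name, device in placement.items():
    print(f"{name:16s} -> {device}")
```

With plain round-robin, a large kernel and a small bias count the same, so the load on each parameter server can be uneven; that trade-off is one reason a real system might use a different policy.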

OneDeviceStrategy

tf.distribute.OneDeviceStrategy runs on a single device. This strategy will place any variables created in its scope on the specified device. Input distributed through this strategy will be prefetched to the specified device. Moreover, any functions called via strategy.experimental_run_v2 will also be placed on the specified device.

You can use this strategy to test your code before switching to other strategies which actually distribute to multiple devices/machines.

strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

So far we've covered the different strategies available and how you can instantiate them. In the next few sections, we will talk about the different ways in which you can use them to distribute your training. We will show short code snippets in this guide and link off to full tutorials which you can run end to end.

Using tf.distribute.Strategy with Keras

We've integrated tf.distribute.Strategy into tf.keras which is TensorFlow's implementation of the Keras API specification. tf.keras is a high-level API to build and train models. By integrating into tf.keras backend, we've made it seamless for you to distribute your training written in the Keras training framework.

Here's what you need to change in your code:

  1. Create an instance of the appropriate tf.distribute.Strategy
  2. Move the creation and compiling of the Keras model inside strategy.scope.

We support all types of Keras models - sequential, functional and subclassed.

Here is a snippet of code to do this for a very simple Keras model with one dense layer:

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  model.compile(loss='mse', optimizer='sgd')

In this example we used MirroredStrategy so we can run this on a machine with multiple GPUs. strategy.scope() indicates which parts of the code to run distributed. Creating a model inside this scope allows us to create mirrored variables instead of regular variables. Compiling under the scope lets us know that the user intends to train this model using this strategy. Once this is set up, you can fit your model as you normally would. MirroredStrategy takes care of replicating the model's training on the available GPUs, aggregating gradients, and more.

dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)
Epoch 1/2
10/10 [==============================] - 2s 157ms/step - loss: 0.1410
Epoch 2/2
10/10 [==============================] - 0s 2ms/step - loss: 0.0644
10/10 [==============================] - 1s 111ms/step - loss: 0.0387

0.038715437054634094

Here we used a tf.data.Dataset to provide the training and eval input. You can also use numpy arrays:

import numpy as np
inputs, targets = np.ones((100, 1)), np.ones((100, 1))
model.fit(inputs, targets, epochs=2, batch_size=10)
Train on 100 samples
Epoch 1/2
100/100 [==============================] - 0s 2ms/sample - loss: 0.0276
Epoch 2/2
100/100 [==============================] - 0s 142us/sample - loss: 0.0122

<tensorflow.python.keras.callbacks.History at 0x7f6de295a630>

In both cases (dataset or numpy), each batch of the given input is divided equally among the multiple replicas. For instance, if using MirroredStrategy with 2 GPUs, each batch of size 10 will get divided among the 2 GPUs, with each receiving 5 input examples in each step. Each epoch will then train faster as you add more GPUs. Typically, you would want to increase your batch size as you add more accelerators so as to make effective use of the extra computing power. You will also need to re-tune your learning rate, depending on the model. You can use strategy.num_replicas_in_sync to get the number of replicas.

# Compute global batch size using number of replicas.
BATCH_SIZE_PER_REPLICA = 5
global_batch_size = (BATCH_SIZE_PER_REPLICA *
                     mirrored_strategy.num_replicas_in_sync)
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)
dataset = dataset.batch(global_batch_size)

LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}
learning_rate = LEARNING_RATES_BY_BATCH_SIZE[global_batch_size]
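The even split described above can be sketched in plain Python: a global batch of 10 examples with 2 replicas gives each replica 5 examples per step, matching the example in the text. The `split_batch` helper is hypothetical and only mirrors that arithmetic; it does not reproduce how tf.distribute shards a dataset, and it assumes the batch divides evenly.

```python
from typing import List, Sequence


def split_batch(batch: Sequence, num_replicas: int) -> List[list]:
    """Split one global batch evenly into num_replicas per-replica batches."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by the number of replicas")
    per_replica = len(batch) // num_replicas
    return [list(batch[r * per_replica:(r + 1) * per_replica]) for r in range(num_replicas)]


global_batch = list(range(10))            # one batch of 10 examples
shards = split_batch(global_batch, num_replicas=2)
print([len(s) for s in shards])           # [5, 5]: each of the 2 replicas gets 5 examples

# Going the other way: keep 5 examples per replica and grow the global batch with more replicas.
BATCH_SIZE_PER_REPLICA = 5
for num_replicas in (1, 2, 4, 8):
    print(num_replicas, BATCH_SIZE_PER_REPLICA * num_replicas)
```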

What's supported now?

In the TF 2.0 release, MirroredStrategy, TPUStrategy, CentralStorageStrategy and MultiWorkerMirroredStrategy are supported in Keras. Apart from MirroredStrategy, these are currently experimental and subject to change. Support for other strategies will be coming soon. The API and its usage will be exactly the same as above.

Training API         | MirroredStrategy     | TPUStrategy          | MultiWorkerMirroredStrategy | CentralStorageStrategy   | ParameterServerStrategy  | OneDeviceStrategy
Keras API            | Supported            | Experimental support | Experimental support        | Experimental support     | Support planned post 2.0 | Supported

Examples and Tutorials

Here is a list of tutorials and examples that illustrate the above integration end to end with Keras:

  1. Tutorial to train MNIST with MirroredStrategy.
  2. Official ResNet50 training with ImageNet data using MirroredStrategy.
  3. ResNet50 trained with ImageNet data on Cloud TPUs with TPUStrategy.
  4. Tutorial to train MNIST using MultiWorkerMirroredStrategy.
  5. NCF trained using MirroredStrategy.
  6. Transformer trained using MirroredStrategy.

Using tf.distribute.Strategy with custom training loops

As you've seen, using tf.distribute.Strategy with high-level APIs (Estimator and Keras) requires changing only a couple lines of your code. With a little more effort, you can also use tf.distribute.Strategy with custom training loops.

If you need more flexibility and control over your training loops than is possible with Estimator or Keras, you can write custom training loops. For instance, when using a GAN, you may want to take a different number of generator or discriminator steps each round. Similarly, the high-level frameworks are not well suited to reinforcement learning training.

To support custom training loops, we provide a core set of methods through the tf.distribute.Strategy classes. Using these may require minor restructuring of the code initially, but once that is done, you should be able to switch between GPUs, TPUs, and multiple machines simply by changing the strategy instance.

Here we will show a brief snippet illustrating this use case for a simple training example using the same Keras model as before.

First, we create the model and optimizer inside the strategy's scope. This ensures that any variables created with the model and optimizer are mirrored variables.

with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  optimizer = tf.keras.optimizers.SGD()

Next, we create the input dataset and call tf.distribute.Strategy.experimental_distribute_dataset to distribute the dataset based on the strategy.

dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1000).batch(
    global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

Then, we define one step of the training. We will use tf.GradientTape to compute gradients and the optimizer to apply those gradients to update our model's variables. To distribute this training step, we put it in a function step_fn and pass it to tf.distribute.Strategy.experimental_run_v2 along with the dataset inputs that we get from the dist_dataset created before:

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    features, labels = inputs

    with tf.GradientTape() as tape:
      logits = model(features)
      cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
      loss = tf.reduce_sum(cross_entropy) * (1.0 / global_batch_size)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    return cross_entropy

  per_example_losses = mirrored_strategy.experimental_run_v2(
      step_fn, args=(dist_inputs,))
  mean_loss = mirrored_strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
  return mean_loss

A few other things to note in the code above:

  1. We used tf.nn.softmax_cross_entropy_with_logits to compute the loss, and then scaled the total loss by the global batch size. This is important because all the replicas are training in sync and the number of examples in each step of training is the global batch. So the loss needs to be divided by the global batch size and not by the replica (local) batch size.
  2. We used the tf.distribute.Strategy.reduce API to aggregate the results returned by tf.distribute.Strategy.experimental_run_v2. tf.distribute.Strategy.experimental_run_v2 returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can reduce them to get an aggregated value. You can also call tf.distribute.Strategy.experimental_local_results to get the list of values contained in the result, one per local replica.
  3. When apply_gradients is called within a distribution strategy scope, its behavior is modified. Specifically, before applying gradients on each parallel instance during synchronous training, it performs a sum-over-all-replicas of the gradients.
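Why divide by the global batch size? A small worked example in plain Python, with made-up per-example losses, shows the effect. With 2 replicas of 2 examples each, scaling each replica's summed loss by the global batch size and then summing across replicas (as the gradient all-reduce does) reproduces the mean loss over the whole global batch. Dividing by the local batch size instead overcounts by a factor equal to the number of replicas. Since gradients are linear in the loss, the same holds for the summed gradients.

```python
# Hypothetical per-example losses for a global batch of 4, split over 2 replicas.
per_replica_losses = [[1.0, 3.0], [2.0, 6.0]]
global_batch_size = sum(len(losses) for losses in per_replica_losses)  # 4
local_batch_size = len(per_replica_losses[0])                            # 2

# What each replica computes in step_fn: sum of its losses * (1 / global_batch_size).
scaled_by_global = [sum(losses) / global_batch_size for losses in per_replica_losses]
# Scaling by the local batch size instead (the mistake the text warns about).
scaled_by_local = [sum(losses) / local_batch_size for losses in per_replica_losses]

# The mean loss over the entire global batch, computed directly.
true_mean = sum(sum(losses) for losses in per_replica_losses) / global_batch_size

# Contributions are summed across replicas, so compare the sums.
print(sum(scaled_by_global), true_mean)  # 3.0 3.0
print(sum(scaled_by_local))              # 6.0, i.e. num_replicas times too large
```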

Finally, once we have defined the training step, we can iterate over dist_dataset and run the training in a loop:

with mirrored_strategy.scope():
  for inputs in dist_dataset:
    print(train_step(inputs))
tf.Tensor(0.0, shape=(), dtype=float32)

In the example above, we iterated over the dist_dataset to provide input to your training. We also provide the tf.distribute.Strategy.make_experimental_numpy_dataset to support numpy inputs. You can use this API to create a dataset before calling tf.distribute.Strategy.experimental_distribute_dataset.

Another way of iterating over your data is to explicitly use iterators. You may want to do this when you want to run for a given number of steps as opposed to iterating over the entire dataset. The above iteration would now be modified to first create an iterator and then explicitly call next on it to get the input data.

with mirrored_strategy.scope():
  iterator = iter(dist_dataset)
  for _ in range(10):
    print(train_step(next(iterator)))
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)

This covers the simplest case of using tf.distribute.Strategy API to distribute custom training loops. We are in the process of improving these APIs. Since this use case requires more work to adapt your code, we will be publishing a separate detailed guide in the future.

What's supported now?

In the TF 2.0 release, training with custom training loops is supported using MirroredStrategy, as shown above, and TPUStrategy. MultiWorkerMirroredStrategy support will be coming in the future.

Training API         | MirroredStrategy     | TPUStrategy          | MultiWorkerMirroredStrategy | CentralStorageStrategy   | ParameterServerStrategy  | OneDeviceStrategy
Custom training loop | Experimental support | Experimental support | Support planned post 2.0    | Support planned post 2.0 | No support yet           | Supported

Examples and Tutorials

Here are some examples for using distribution strategy with custom training loops:

  1. Tutorial to train MNIST using MirroredStrategy.
  2. DenseNet example using MirroredStrategy.
  3. BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training etc.
  4. NCF example trained using MirroredStrategy and TPUStrategy that can be enabled using the keras_use_ctl flag.
  5. NMT example trained using MirroredStrategy.

Using tf.distribute.Strategy with Estimator (Limited support)

tf.estimator is a distributed training TensorFlow API that originally supported the async parameter server approach. Like with Keras, we've integrated tf.distribute.Strategy into tf.estimator. If you're using Estimator for your training, you can easily change to distributed training with very few changes to your code. With this, Estimator users can now do synchronous distributed training on multiple GPUs and multiple workers, as well as use TPUs. This support in Estimator is, however, limited. See the What's supported now section below for more details.

The usage of tf.distribute.Strategy with Estimator is slightly different from the Keras case. Instead of using strategy.scope, we pass the strategy object into the RunConfig for the Estimator.

Here is a snippet of code that shows this with a premade Estimator LinearRegressor and MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpre_reshs
INFO:tensorflow:Using config: {'_is_chief': True, '_distribute_coordinator_mode': None, '_save_checkpoints_secs': 600, '_service': None, '_experimental_max_worker_delay_secs': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f6d50231c18>, '_task_id': 0, '_master': '', '_global_id_in_cluster': 0, '_num_ps_replicas': 0, '_protocol': None, '_log_step_count_steps': 100, '_keep_checkpoint_max': 5, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_evaluation_master': '', '_session_creation_timeout_secs': 7200, '_device_fn': None, '_task_type': 'worker', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f6df85b7a58>, '_save_summary_steps': 100, '_experimental_distribute': None, '_save_checkpoints_steps': None, '_tf_random_seed': None, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f6d50231c18>, '_keep_checkpoint_every_n_hours': 10000, '_model_dir': '/tmp/tmpre_reshs', '_num_worker_replicas': 1}

We use a premade Estimator here, but the same code works with a custom Estimator as well. train_distribute determines how training will be distributed, and eval_distribute determines how evaluation will be distributed. This is another difference from Keras, where the same strategy is used for both training and evaluation.

Now we can train and evaluate this Estimator with an input function:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:518: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:308: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpre_reshs/model.ckpt.
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpre_reshs/model.ckpt.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-11-02T01:22:52Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpre_reshs/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2019-11-02-01:22:52
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpre_reshs/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'global_step': 10,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994}

Another difference to highlight between Estimator and Keras is input handling. In Keras, we mentioned that each batch of the dataset is split automatically across the multiple replicas. In Estimator, however, batches are neither split automatically nor sharded automatically across workers. You have full control over how your data is distributed across workers and devices, and you must provide an input_fn to specify it.

Your input_fn is called once per worker, giving one dataset per worker. One batch from that dataset is fed to each replica on that worker, so a step consumes N batches for N replicas on one worker. In other words, the dataset returned by the input_fn should provide batches of size PER_REPLICA_BATCH_SIZE, and the global batch size for a step is PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync.
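The batch arithmetic above can be illustrated with a small pure-Python simulation. This is a sketch, not TensorFlow code: make_worker_batches and run_step are hypothetical helpers, and num_replicas stands in for strategy.num_replicas_in_sync. Each list yielded by the per-worker iterator plays the role of one batch of size PER_REPLICA_BATCH_SIZE, and one synchronous step hands one batch to each replica.

```python
from itertools import islice
from typing import Iterator, List

PER_REPLICA_BATCH_SIZE = 10  # the batch size the input_fn's dataset should use


def make_worker_batches(num_examples: int, batch_size: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for the dataset returned by input_fn on one worker."""
    examples = list(range(num_examples))
    for start in range(0, num_examples, batch_size):
        yield examples[start:start + batch_size]


def run_step(batches: Iterator[List[int]], num_replicas: int) -> List[List[int]]:
    """Take one batch per replica for a single synchronous step."""
    return list(islice(batches, num_replicas))


num_replicas = 4  # stands in for strategy.num_replicas_in_sync
batches = make_worker_batches(num_examples=1000, batch_size=PER_REPLICA_BATCH_SIZE)
step_batches = run_step(batches, num_replicas)

# Each replica sees PER_REPLICA_BATCH_SIZE examples; the step as a whole
# processes the global batch of PER_REPLICA_BATCH_SIZE * num_replicas examples.
global_batch_size = sum(len(b) for b in step_batches)
print(len(step_batches), global_batch_size)  # 4 40
```

In Keras, by contrast, you would batch the dataset with the global batch size (40 here) and let the strategy split each batch into per-replica pieces of 10.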

When doing multi-worker training, you should either split your data across the workers or shuffle with a different random seed on each worker. You can see an example of how to do this in the Multi-worker Training with Estimator tutorial.
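Both options can be sketched in plain Python. The shard helper below mimics the element-selection rule of tf.data's Dataset.shard(num_shards, index), which keeps every num_shards-th element starting at index. shuffled_for_worker is a hypothetical helper showing per-worker seeded shuffling; deriving the seed as base_seed + worker_index is an assumption made for illustration, and any scheme that gives each worker its own seed would do.

```python
import random
from typing import List, Sequence


def shard(examples: Sequence[int], num_workers: int, worker_index: int) -> List[int]:
    """Keep every num_workers-th example starting at worker_index
    (the same selection rule as tf.data.Dataset.shard)."""
    return [x for i, x in enumerate(examples) if i % num_workers == worker_index]


def shuffled_for_worker(examples: Sequence[int], worker_index: int,
                        base_seed: int = 42) -> List[int]:
    """Shuffle the full dataset with a seed that differs per worker (hypothetical scheme)."""
    rng = random.Random(base_seed + worker_index)
    order = list(examples)
    rng.shuffle(order)
    return order


data = list(range(12))
num_workers = 3

# Option 1: disjoint shards; together they cover every example exactly once.
shards = [shard(data, num_workers, i) for i in range(num_workers)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]

# Option 2: every worker sees all of the data, but in its own order.
orders = [shuffled_for_worker(data, i) for i in range(num_workers)]
```

Sharding gives each worker a disjoint subset of every epoch, while seeded shuffling keeps the full dataset on every worker and only avoids identical example orders.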

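We showed an example of using MirroredStrategy with Estimator. A TPUStrategy is passed to RunConfig in exactly the same way, although the What's supported now? table below lists TPUStrategy as not supported with Estimator in the TF 2.0 release: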

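# Assumes tpu_strategy was created earlier, e.g. as
# tf.distribute.experimental.TPUStrategy(resolver) with a TPUClusterResolver.
config = tf.estimator.RunConfig(
    train_distribute=tpu_strategy, eval_distribute=tpu_strategy)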

Similarly, you can use the multi-worker and parameter server strategies. The code remains the same, but you need to use tf.estimator.train_and_evaluate and set the TF_CONFIG environment variable for each binary running in your cluster.
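Because each binary needs its own TF_CONFIG, it can help to generate all of them from one cluster description. The sketch below uses a hypothetical helper, tf_config_for_task, that serializes the same JSON layout shown in the Setting up TF_CONFIG environment variable section below; the host names are placeholders, not real addresses.

```python
import json
from typing import Dict, List

# Placeholder cluster description; replace hosts and ports with real ones.
CLUSTER: Dict[str, List[str]] = {
    "worker": ["host1:port", "host2:port", "host3:port"],
    "ps": ["host4:port", "host5:port"],
}


def tf_config_for_task(cluster: Dict[str, List[str]], task_type: str, index: int) -> str:
    """Return the TF_CONFIG JSON string for one task in the cluster."""
    if index < 0 or index >= len(cluster.get(task_type, [])):
        raise ValueError(f"no {task_type} task with index {index}")
    return json.dumps({"cluster": cluster, "task": {"type": task_type, "index": index}})


# One TF_CONFIG value per binary. Each binary would set its own value at startup
# (for example os.environ["TF_CONFIG"] = value) before creating the RunConfig.
all_configs = {
    (task_type, i): tf_config_for_task(CLUSTER, task_type, i)
    for task_type, addresses in CLUSTER.items()
    for i in range(len(addresses))
}
print(len(all_configs))  # 5
```

Every task receives the same "cluster" block and differs only in its "task" entry, which is what lets each binary discover its own role.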

What's supported now?

In the TF 2.0 release, there is limited support for training with Estimator using all strategies except TPUStrategy. Basic training and evaluation should work, but a number of advanced features, such as scaffold, do not yet work. There may also be bugs in this integration. At this time, we do not plan to actively improve this support; instead, we are focused on Keras and custom training loop support. If at all possible, you should prefer to use tf.distribute with those APIs instead.

| Training API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy | OneDeviceStrategy |
|---|---|---|---|---|---|---|
| Estimator API | Limited Support | Not supported | Limited Support | Limited Support | Limited Support | Limited Support |

Examples and Tutorials

Here are some examples that show end-to-end usage of various strategies with Estimator:

  1. Multi-worker Training with Estimator to train MNIST with multiple workers using MultiWorkerMirroredStrategy.
  2. End-to-end example for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
  3. Official ResNet50 model, which can be trained using either MirroredStrategy or MultiWorkerMirroredStrategy.

Other topics

In this section, we will cover some topics that are relevant to multiple use cases.

Setting up TF_CONFIG environment variable

For multi-worker training, as mentioned before, you need to set the TF_CONFIG environment variable for each binary running in your cluster. The TF_CONFIG environment variable is a JSON string that specifies which tasks constitute a cluster, their addresses, and each task's role in the cluster. We provide a Kubernetes template in the tensorflow/ecosystem repo that sets TF_CONFIG for your training tasks.

One example of TF_CONFIG is:

import json
import os

# Three workers and two parameter servers; this process is worker 1.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"]
    },
    "task": {"type": "worker", "index": 1}
})

This TF_CONFIG specifies that there are three workers and two ps tasks in the cluster, along with their hosts and ports. The "task" part specifies the role of the current task in the cluster: worker 1 (the second worker). Valid roles in a cluster are "chief", "worker", "ps", and "evaluator". There should be no "ps" job except when using tf.distribute.experimental.ParameterServerStrategy.
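The rules in the paragraph above can be checked mechanically before launching a job. The hypothetical check_tf_config function below is an illustrative validator, not a TensorFlow API: it parses a TF_CONFIG string, rejects roles outside "chief", "worker", "ps", and "evaluator", rejects a "ps" job unless the caller states that ParameterServerStrategy is in use, and returns the current task's own address (or None when its role has no entry in "cluster").

```python
import json
from typing import Optional, Tuple

VALID_ROLES = {"chief", "worker", "ps", "evaluator"}


def check_tf_config(tf_config: str,
                    uses_parameter_server: bool = False) -> Tuple[str, int, Optional[str]]:
    """Validate a TF_CONFIG string; return (task type, task index, address or None)."""
    config = json.loads(tf_config)
    cluster = config["cluster"]
    unknown = set(cluster) - VALID_ROLES
    if unknown:
        raise ValueError(f"unknown roles in cluster: {sorted(unknown)}")
    if "ps" in cluster and not uses_parameter_server:
        raise ValueError('a "ps" job is only expected with ParameterServerStrategy')

    task_type = config["task"]["type"]
    index = config["task"]["index"]
    if task_type not in VALID_ROLES:
        raise ValueError(f"unknown task type: {task_type}")

    addresses = cluster.get(task_type)
    if addresses is None:
        # The task's role has no entry in "cluster", so there is no address to report.
        return task_type, index, None
    if not 0 <= index < len(addresses):
        raise ValueError(f"{task_type} index {index} out of range")
    return task_type, index, addresses[index]


example = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"],
    },
    "task": {"type": "worker", "index": 1},
})
print(check_tf_config(example, uses_parameter_server=True))
# ('worker', 1, 'host2:port')
```

Applied to the example above, the check confirms that index 1 refers to the second worker, host2:port.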

What's next?

tf.distribute.Strategy is under active development. We welcome you to try it out and provide your feedback using GitHub issues.