Module: tf.distribute

Library for running a computation across multiple devices.

See the guide for overview and examples: TensorFlow v2.x, TensorFlow v1.x.

The intent of this library is that you can write an algorithm in a stylized way and it will be usable with a variety of different tf.distribute.Strategy implementations. Each subclass implements a different strategy for distributing the algorithm across multiple devices or machines. Furthermore, the distribution-specific changes can be hidden inside the layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged. The tf.distribute.Strategy API works the same way with eager and graph execution.
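Here is a minimal sketch of this "write once, run under any strategy" idea. It assumes TensorFlow 2.2 or later, where Strategy.run is available, and the helper name scaled_sum is hypothetical. The helper creates a variable under strategy.scope(), runs a per-replica function with strategy.run, and sums the per-replica results with strategy.reduce. The same helper is then used unchanged with tf.distribute.MirroredStrategy and tf.distribute.OneDeviceStrategy.

```python
import tensorflow as tf


def scaled_sum(strategy, x):
    """Hypothetical helper: multiply `x` by a variable on every replica, then sum."""
    # Variables created inside the scope are placed according to the strategy
    # (for example, mirrored across devices under MirroredStrategy).
    with strategy.scope():
        scale = tf.Variable(3.0)

    def replica_fn(value):
        # Runs once per replica; nothing here refers to specific devices.
        return value * scale

    per_replica = strategy.run(replica_fn, args=(x,))
    # Aggregate the per-replica results into a single tensor.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)


for strategy in (tf.distribute.MirroredStrategy(),
                 tf.distribute.OneDeviceStrategy("/cpu:0")):
    result = scaled_sum(strategy, tf.constant(2.0))
    print(type(strategy).__name__, strategy.num_replicas_in_sync, float(result))
```

Under MirroredStrategy the summed result scales with the number of replicas, because each replica contributes its own product to the SUM. OneDeviceStrategy always has exactly one replica.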


Glossary

  • Data parallelism is where we run multiple copies of the model on different slices of the input data. This is in contrast to model parallelism where we divide up a single copy of a model across multiple devices. Note: we only support data parallelism for now, but hope to add support for model parallelism in the future.
  • A device is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that TensorFlow can run operations on (see e.g. tf.device). You may have multiple devices on a single machine, or be connected to devices on multiple machines. Devices used to run computations are called worker devices. Devices used to store variables are parameter devices. For some strategies, such as tf.distribute.MirroredStrategy, the worker and parameter devices will be the same (see mirrored variables below). For others they will be different. For example, tf.distribute.experimental.CentralStorageStrategy puts the variables on a single device (which may be a worker device or may be the CPU), and tf.distribute.experimental.ParameterServerStrategy puts the variables on separate machines called parameter servers (see below).
  • A replica is one copy of the model, running on one slice of the input data. Right now each replica is executed on its own worker device, but once we add support for model parallelism a replica may span multiple worker devices.
  • A host is the CPU device on a machine with worker devices, typically used for running input pipelines.
  • A worker is defined to be the physical machine(s) containing the physical devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A worker contains one or more replicas. Typically one worker corresponds to one machine, but in the case of very large models with model parallelism, one worker may span multiple machines. We typically run one input pipeline per worker, feeding all the replicas on that worker.
  • Synchronous, or more commonly sync, training is where the updates from each replica are aggregated together before updating the model variables. This is in contrast to asynchronous, or async, training, where each replica updates the model variables independently. You may also have replicas partitioned into groups which are in sync within each group but async between groups.
  • Parameter servers: These are machines that hold a single copy of parameters/variables, used by some strategies (right now just tf.distribute.experimental.ParameterServerStrategy). All replicas that want to operate on a variable retrieve it at the beginning of a step and send an update to be applied at the end of the step. These can in principle support either sync or async training, but right now we only support async training with parameter servers. Compare to tf.distribute.experimental.CentralStorageStrategy, which puts all variables on a single device on the same machine (and does sync training), and tf.distribute.MirroredStrategy, which mirrors variables to multiple devices (see below).
  • Mirrored variables: These are variables that are copied to multiple devices, where we keep the copies in sync by applying the same updates to every copy. They are normally used only with sync training.
  • Reductions and all-reduce: A reduction is some method of aggregating multiple values into one value, like "sum" or "mean". If a strategy is doing sync training, it reduces the gradients for each parameter across all replicas before applying the update. All-reduce is an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices.
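To make the glossary concrete, here is a framework-free sketch in plain Python (no TensorFlow) of one synchronous, data-parallel training step. Each replica holds a mirrored copy of a single weight and computes a gradient on its own slice of the batch. The gradients are combined with an all-reduce (mean), and every replica applies the identical update, so the copies stay in sync. The toy model (fitting y = w * x with squared error), the learning rate, and all function names are illustrative assumptions. Real strategies perform the all-reduce with device collectives such as NCCL, not Python lists.

```python
def local_gradient(w, xs, ys):
    """Gradient of mean squared error for y = w * x on one replica's data slice."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_mean(values):
    """Reduce with "mean" and make the result available to every replica."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def sync_step(replica_weights, batch_x, batch_y, lr=0.1):
    num_replicas = len(replica_weights)
    # Data parallelism: split the global batch into one equal slice per replica.
    slice_size = len(batch_x) // num_replicas
    grads = []
    for i, w in enumerate(replica_weights):
        lo, hi = i * slice_size, (i + 1) * slice_size
        grads.append(local_gradient(w, batch_x[lo:hi], batch_y[lo:hi]))
    # Sync training: aggregate gradients from all replicas before updating.
    reduced = all_reduce_mean(grads)
    # Mirrored variables: every copy gets the same update, so copies stay equal.
    return [w - lr * g for w, g in zip(replica_weights, reduced)]


weights = [0.0, 0.0]          # one mirrored copy of w per replica (2 replicas)
xs = [1.0, 2.0, 3.0, 4.0]     # global batch; the true relationship is y = 2 * x
ys = [2.0, 4.0, 6.0, 8.0]
for _ in range(50):
    weights = sync_step(weights, xs, ys)
print(weights)                # both copies converge to (approximately) 2.0
```

Because the slices are equal in size, the mean of the per-replica gradients equals the gradient over the whole batch. This is why sync data parallelism reproduces single-device training on the same global batch.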

Note that we provide a default version of tf.distribute.Strategy that is used when no other strategy is in scope; it provides the same API with reasonable default behavior.
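The sketch below illustrates the default strategy. It assumes TensorFlow 2.x with eager execution and no strategy set globally. Outside any scope, tf.distribute.get_strategy() still returns a usable strategy with a single replica, while tf.distribute.has_strategy() reports that no non-default strategy is current. Entering a MirroredStrategy scope changes both answers.

```python
import tensorflow as tf

# Outside any strategy scope, get_strategy() returns the default strategy.
default = tf.distribute.get_strategy()
print(tf.distribute.has_strategy())   # False: only the default strategy is active
print(default.num_replicas_in_sync)   # 1: the default strategy has a single replica

# The same API is available, so distribution-aware code still works.
total = default.reduce(tf.distribute.ReduceOp.SUM, tf.constant(5.0), axis=None)
print(float(total))                   # 5.0: reducing over one replica returns the value

# Inside a non-default strategy's scope, that strategy becomes current.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inside_has_strategy = tf.distribute.has_strategy()
    inside_is_current = tf.distribute.get_strategy() is strategy
print(inside_has_strategy, inside_is_current)  # True True
```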


Modules

cluster_resolver module: Library imports for ClusterResolvers.

experimental module: Experimental Distribution Strategy library.


Classes

class CrossDeviceOps: Base class for cross-device reduction and broadcasting algorithms.

class DistributedValues: Base class for representing distributed values.

class HierarchicalCopyAllReduce: Reduction using hierarchical copy all-reduce.

class InputContext: A class wrapping information needed by an input function.

class InputReplicationMode: Replication mode for input function.

class MirroredStrategy: Synchronous training across multiple replicas on one machine.

class NcclAllReduce: Reduction using NCCL all-reduce.

class OneDeviceStrategy: A distribution strategy for running on a single device.

class ReduceOp: Indicates how a set of values should be reduced.

class ReductionToOneDevice: Always do reduction to one device first and then do broadcasting.

class ReplicaContext: tf.distribute.Strategy API when in a replica context.

class RunOptions: Run options for

class Server: An in-process TensorFlow server, for use in distributed training.

class Strategy: A state & compute distribution policy on a list of devices.

class StrategyExtended: Additional APIs for algorithms that need to be distribution-aware.
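Several of these classes cooperate within a single training step. The sketch below assumes TensorFlow 2.2 or later and uses MirroredStrategy, which creates one replica per visible GPU, or a single CPU replica if there is no GPU. strategy.run enters a replica context, tf.distribute.get_replica_context() returns the ReplicaContext, and ReplicaContext.all_reduce combines a value across replicas using a ReduceOp. The per-replica values derived from replica_id_in_sync_group are chosen only for illustration.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def replica_fn():
    # Inside strategy.run we are in a replica context.
    ctx = tf.distribute.get_replica_context()  # a tf.distribute.ReplicaContext
    # Each replica contributes a different value: 1.0, 2.0, ... (illustrative only).
    value = tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0
    # All-reduce with SUM: every replica receives the total over all replicas.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)


@tf.function
def step():
    return strategy.run(replica_fn)


per_replica = step()
# run() returns a DistributedValues; unpack its per-replica components.
results = strategy.experimental_local_results(per_replica)
print(strategy.num_replicas_in_sync, [float(r) for r in results])
```

With N replicas, each replica should receive N(N+1)/2. With a single CPU replica, that is just 1.0.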


Functions

experimental_set_strategy(...): Set a tf.distribute.Strategy as current without using a with strategy.scope() block.

get_replica_context(...): Returns the current tf.distribute.ReplicaContext or None.

get_strategy(...): Returns the current tf.distribute.Strategy object.

has_strategy(...): Returns True if there is a current non-default tf.distribute.Strategy.

in_cross_replica_context(...): Returns True if in a cross-replica context.
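The context functions above can be exercised together. This sketch assumes TensorFlow 2.2 or later in eager mode; describe and replica_fn are hypothetical helpers. Outside any scope, get_replica_context() returns the default ReplicaContext rather than None. It returns None only in a cross-replica context, such as the body of a strategy.scope() block.

```python
import tensorflow as tf


def describe():
    """Hypothetical helper: snapshot the current distribution context."""
    return {
        "has_strategy": tf.distribute.has_strategy(),
        "cross_replica": tf.distribute.in_cross_replica_context(),
        "replica_context_is_none": tf.distribute.get_replica_context() is None,
    }


strategy = tf.distribute.MirroredStrategy()
replica_snapshots = []


def replica_fn():
    # Runs once per replica; record what each replica sees.
    replica_snapshots.append(describe())
    return tf.constant(0.0)


outside = describe()        # default strategy, default replica context
with strategy.scope():
    in_scope = describe()   # cross-replica context: get_replica_context() is None
strategy.run(replica_fn)    # replica context inside run()

# experimental_set_strategy makes `strategy` current without a with-block;
# passing None restores the default strategy.
tf.distribute.experimental_set_strategy(strategy)
globally_set = tf.distribute.get_strategy() is strategy
tf.distribute.experimental_set_strategy(None)

print(outside, in_scope, replica_snapshots, globally_set, sep="\n")
```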