

A multi-worker tf.distribute strategy with parameter servers.

Inherits From: Strategy


Parameter server training is a common data-parallel method to scale up a machine learning model on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers, and workers read and update them in each step. By default, workers read and update these variables independently, without synchronizing with each other. This configuration is known as asynchronous training.
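To make the asynchronous update pattern concrete, here is a toy, single-process sketch in plain Python rather than TensorFlow: threads stand in for workers, and a hypothetical locked `ToyParameterServer` object stands in for a parameter server. Each worker reads the variable, computes a gradient of the illustrative loss (w - 3)^2, and pushes an update without coordinating with the other workers, so some updates are computed from stale reads. This is only an illustration of the idea, not how TensorFlow implements it.

```python
import threading


class ToyParameterServer:
    """Hypothetical stand-in for a parameter server holding one scalar variable."""

    def __init__(self, value: float) -> None:
        self._value = value
        self._lock = threading.Lock()

    def read(self) -> float:
        with self._lock:
            return self._value

    def apply_update(self, delta: float) -> None:
        # Reading and updating are separate atomic operations, so another
        # worker may have changed the value in between (a stale update).
        with self._lock:
            self._value += delta


def worker(ps: ToyParameterServer, steps: int, lr: float) -> None:
    # Each worker runs independently and never talks to other workers.
    for _ in range(steps):
        w = ps.read()
        grad = 2.0 * (w - 3.0)  # gradient of the toy loss (w - 3)^2
        ps.apply_update(-lr * grad)


ps = ToyParameterServer(value=0.0)
threads = [threading.Thread(target=worker, args=(ps, 200, 0.05)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"final w = {ps.read():.4f}")
```

With a small learning rate the occasional stale update only perturbs the value slightly, and the variable still converges toward the minimum at 3.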

In TensorFlow 2, we recommend an architecture based on central coordination for parameter server training. Each worker and parameter server runs a tf.distribute.Server, and on top of that, a coordinator task is responsible for creating resources on workers and parameter servers, dispatching functions, and coordinating the training. The coordinator uses a tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the cluster, and a tf.distribute.experimental.ParameterServerStrategy to define variables on parameter servers and computation on workers.

For the training to work, the coordinator dispatches tf.functions to be executed on remote workers. Upon receiving requests from the coordinator, a worker executes the tf.function by reading the variables from parameter servers, executing the ops, and updating the variables on the parameter servers. Each worker only processes requests from the coordinator and communicates with parameter servers, without directly interacting with other workers in the cluster.

As a result, the failure of some workers does not prevent the cluster from continuing its work, which allows the cluster to train with instances that can be occasionally unavailable (e.g. preemptible or spot instances). The coordinator and parameter servers, however, must be available at all times for the cluster to make progress.
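The fault-tolerance property above can be illustrated with a toy dispatcher in plain Python (not TensorFlow): each scheduled function is sent to some worker, and if that worker is unavailable the function is simply re-queued and later runs on another worker. The `dispatch` function, the `WorkerUnavailable` exception, and the `is_down` availability oracle are all hypothetical names used only to simulate preemption; the sketch does not model coordinator or parameter server failures, which, as stated above, stop progress.

```python
from collections import deque
from typing import Callable, Dict, List


class WorkerUnavailable(Exception):
    """Hypothetical error for a worker that cannot run a task (e.g. preempted)."""


def dispatch(closures: List[Callable[[str], int]], workers: List[str],
             is_down: Callable[[str, int], bool]) -> Dict[int, int]:
    """Run every closure on some worker, re-queuing it when its worker fails.

    Loops forever if every worker stays down; fine for a toy, not for production.
    """
    pending = deque(enumerate(closures))
    results: Dict[int, int] = {}
    attempt = 0
    while pending:
        idx, fn = pending.popleft()
        worker = workers[attempt % len(workers)]  # pick workers round-robin
        attempt += 1
        try:
            if is_down(worker, attempt):
                raise WorkerUnavailable(worker)
            results[idx] = fn(worker)
        except WorkerUnavailable:
            pending.append((idx, fn))  # retry later, usually on another worker
    return results


closures = [lambda w, i=i: i * i for i in range(5)]
# Simulate "worker1" being preempted during the first 6 dispatch attempts.
results = dispatch(closures, ["worker0", "worker1", "worker2"],
                   is_down=lambda w, attempt: w == "worker1" and attempt <= 6)
print(sorted(results.items()))  # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
```

Every closure finishes even though one worker was unavailable for part of the run, because no closure is tied to a specific worker.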

Note that the coordinator is not one of the training workers. Instead, it creates resources such as variables and datasets, dispatches tf.functions, saves checkpoints, and so on. In addition to workers, parameter servers, and the coordinator, an optional evaluator can be run on the side that periodically reads the checkpoints saved by the coordinator and runs evaluations against each checkpoint.

ParameterServerStrategy supports two training APIs: Custom Training Loop (CTL) and the Keras Training API, also known as Model.fit. CTL is recommended when users prefer to define the details of their training loop, and Model.fit is recommended when users prefer a high-level abstraction and handling of training.

When using a CTL, ParameterServerStrategy has to work in conjunction with a tf.distribute.experimental.coordinator.ClusterCoordinator object.

When using Model.fit, currently only the tf.keras.utils.experimental.DatasetCreator input type is supported.

Example code for coordinator

This section provides code snippets intended to be run on the single task designated as the coordinator. Note that the cluster_resolver, variable_partitioner, and dataset_fn arguments are explained in the "Cluster setup", "Variable partitioning", and "Dataset preparation" sections below.

With a CTL,

import tensorflow as tf

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Prepare a distribute dataset that will place datasets on the workers.
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

with strategy.scope():
  model = ...
  optimizer, metrics = ...  # Keras optimizer/metrics are great choices
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, checkpoint_dir, max_to_keep=2)
  # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
  initial_epoch = load_checkpoint(checkpoint_manager) or 0

# `coordinator.schedule` expects a `tf.function`.
@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # calculate gradient, applying gradient, metrics update etc.

  strategy.run(replica_fn, args=(next(iterator),))

for epoch in range(initial_epoch, num_epoch):
  distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
  for step in range(steps_per_epoch):

    # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
    # worker. This call returns immediately.
    coordinator.schedule(worker_fn, args=(distributed_iterator,))

  # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
  # returns, we can read the metrics and save checkpoints as needed.
  coordinator.join()
  print('Metric result: %r' % metrics.result())
  checkpoint_manager.save()
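The contract the loop above relies on is that schedule returns immediately while join blocks until all scheduled work is done. The following plain-Python sketch mimics only that contract with a thread pool; `ToyCoordinator` is a hypothetical class, not TensorFlow's ClusterCoordinator (whose schedule returns a RemoteValue rather than a concurrent.futures.Future).

```python
import concurrent.futures
import time


class ToyCoordinator:
    """Hypothetical stand-in that mimics only the schedule/join contract."""

    def __init__(self, num_workers: int) -> None:
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
        self._futures = []

    def schedule(self, fn, args=()):
        # Returns immediately; `fn` runs later on whichever pool thread is free.
        future = self._pool.submit(fn, *args)
        self._futures.append(future)
        return future

    def join(self):
        # Blocks until every function scheduled so far has finished.
        concurrent.futures.wait(self._futures)
        self._futures.clear()


def worker_fn(step):
    time.sleep(0.01)  # stand-in for one training step on a remote worker
    return step


coordinator = ToyCoordinator(num_workers=3)
handles = [coordinator.schedule(worker_fn, args=(step,)) for step in range(6)]
coordinator.join()
print([h.result() for h in handles])  # [0, 1, 2, 3, 4, 5]
```

Because join is the only synchronization point, reading metrics or saving a checkpoint right after it sees the effect of every step scheduled in that epoch.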


With Model.fit,

# Prepare a strategy to use with the cluster and variable partitioning info.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=...)

# A dataset function takes an `input_context` and returns a `Dataset`.
def dataset_fn(input_context):
  dataset = ...  # Build a `tf.data.Dataset` here.
  return dataset.repeat().shard(...).batch(...).prefetch(...)

# With `Model.fit`, a `DatasetCreator` needs to be used.
input = tf.keras.utils.experimental.DatasetCreator(dataset_fn)

with strategy.scope():
  model = ...  # Make sure the `Model` is created within scope.
model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...)

# Optional callbacks to checkpoint the model, back up the progress, etc.
callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]

# `steps_per_epoch` is required with `ParameterServerStrategy`.
model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks)
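The `.shard(...)` call in dataset_fn gives each input pipeline a disjoint slice of the data; its arguments are typically taken from `input_context.num_input_pipelines` and `input_context.input_pipeline_id`. The plain-Python sketch below reproduces the selection rule of tf.data.Dataset.shard(num_shards, index), which keeps the elements whose position i satisfies i % num_shards == index; the `shard` helper itself is a hypothetical name.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def shard(elements: Iterable[T], num_shards: int, index: int) -> List[T]:
    """Keep every element whose position i satisfies i % num_shards == index."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [x for i, x in enumerate(elements) if i % num_shards == index]


# Three hypothetical input pipelines reading the same 10 examples.
shards = [shard(range(10), num_shards=3, index=i) for i in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Together the shards cover every example exactly once, which is why sharding is applied to the same source dataset in every pipeline.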

Example code for worker and parameter servers

In addition to the coordinator, there should be tasks designated as "worker" or "ps". They should run the following code to start a TensorFlow server and wait for the coordinator's requests:

# Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
# the cluster information. See below "Cluster setup" section.
cluster_resolver = ...

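server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol="grpc")

# Block the process that started the server from exiting.
server.join()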

Cluster setup

In order for the tasks in the cluster to know each other's addresses, a tf.distribute.cluster_resolver.ClusterResolver is required in the coordinator, the workers, and the parameter servers. The tf.distribute.cluster_resolver.ClusterResolver is responsible for providing the cluster information, as well as the task type and id of the current task. See tf.distribute.cluster_resolver.ClusterResolver for more information.

If the TF_CONFIG environment variable is set, a tf.distribute.cluster_resolver.TFConfigClusterResolver should be used as well.

Because tf.distribute.experimental.ParameterServerStrategy makes assumptions about the naming of task types, "chief", "ps", and "worker" should be used in the tf.distribute.cluster_resolver.ClusterResolver to refer to the coordinator, parameter servers, and workers, respectively.

The following example demonstrates setting TF_CONFIG for the task designated as a parameter server (task type "ps") with index 1 (the second task), in a cluster with 1 chief, 2 parameter servers, and 3 workers. Note that TF_CONFIG needs to be set before any use of tf.distribute.cluster_resolver.TFConfigClusterResolver.

Example code for cluster setup:

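import os

# The task addresses are left empty here; fill in each task's "host:port".
os.environ['TF_CONFIG'] = '''
{
  "cluster": {
    "chief": [""],
    "ps": ["", ""],
    "worker": ["", "", ""]
  },
  "task": {
    "type": "ps",
    "index": 1
  }
}
'''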

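Hand-written JSON inside a string is easy to get wrong (a missing bracket or comma only shows up when the resolver parses it). A sketch of an alternative that builds the same TF_CONFIG structure with json.dumps and then parses it back as a sanity check; the localhost addresses are hypothetical values chosen only for illustration.

```python
import json
import os

# Hypothetical local addresses, for illustration only.
cluster = {
    "chief": ["localhost:2222"],
    "ps": ["localhost:2223", "localhost:2224"],
    "worker": ["localhost:2225", "localhost:2226", "localhost:2227"],
}
os.environ["TF_CONFIG"] = json.dumps(
    {"cluster": cluster, "task": {"type": "ps", "index": 1}})

# Parse it back to check what a TF_CONFIG-based resolver would read.
config = json.loads(os.environ["TF_CONFIG"])
task = config["task"]
print(task["type"], task["index"], config["cluster"]["ps"][task["index"]])
# ps 1 localhost:2224
```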
If you prefer to run the same binary for all tasks, you will need to let the binary branch into different roles at the beginning of the program:

if cluster_resolver.task_type == 'chief':
  # If coordinator, create a strategy and start the training program.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  ...

elif cluster_resolver.task_type in ("worker", "ps"):
  # If worker/ps, create a server and wait for the coordinator's requests.
  server = tf.distribute.Server(...)
  ...
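The branching decision itself only depends on the task type. As a minimal, TensorFlow-free sketch, the hypothetical helper below reads the task type from a TF_CONFIG string and returns the role the binary should play; it treats any other task type (for example an optional evaluator) as an error, which is an assumption of this sketch rather than a rule of the strategy.

```python
import json


def role_for(tf_config: str) -> str:
    """Map the task type in a TF_CONFIG string to the role this binary plays."""
    task_type = json.loads(tf_config)["task"]["type"]
    if task_type == "chief":
        return "coordinator"  # create the strategy and run the training program
    if task_type in ("worker", "ps"):
        return "server"  # start tf.distribute.Server and block on join()
    raise ValueError(f"unexpected task type: {task_type!r}")


print(role_for('{"cluster": {}, "task": {"type": "chief", "index": 0}}'))  # coordinator
print(role_for('{"cluster": {}, "task": {"type": "ps", "index": 1}}'))  # server
```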

Alternatively, you can start a set of TensorFlow servers in advance and connect to them later. The coordinator can be in the same cluster or on any machine that has connectivity to the workers and parameter servers. This is covered in the parameter server training guide and tutorial.

Variable creation with strategy.scope()

tf.distribute.experimental.ParameterServerStrategy follows the tf.distribute API contract where variable creation is expected to be inside the context manager returned by strategy.scope(), in order to be correctly placed on parameter servers in a round-robin manner:

# This example assumes 3 parameter servers.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Variables should be created inside scope to be placed on parameter servers.
# If created outside scope such as `v1` here, it would be placed on the
# coordinator.
v1 = tf.Variable(initial_value=0.0)

with strategy.scope():
  v2 = tf.Variable(initial_value=1.0)
  v3 = tf.Variable(initial_value=2.0)
  v4 = tf.Variable(initial_value=3.0)
  v5 = tf.Variable(initial_value=4.0)

# v2 through v5 are created in scope and are distributed on parameter servers.
# Default placement is round-robin but the order should not be relied on.
assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
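The device strings in the assertions follow a simple round-robin rule: the i-th variable created in scope goes to parameter server task i % num_ps. The hypothetical helper below reproduces only that default pattern in plain Python; as the comment in the example says, real placement order should not be relied on.

```python
from typing import List


def round_robin_devices(num_variables: int, num_ps: int) -> List[str]:
    """Assign the i-th variable to parameter server task i % num_ps."""
    return [f"/job:ps/replica:0/task:{i % num_ps}/device:CPU:0"
            for i in range(num_variables)]


# v2..v5 from the example above, with 3 parameter servers.
for name, device in zip(["v2", "v3", "v4", "v5"], round_robin_devices(4, 3)):
    print(name, device)
```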

See tf.distribute.Strategy.scope for more information.

Variable partitioning

Having dedicated servers to store variables makes it possible to divide up, or "shard", the variables across the parameter servers. Partitioning large variables across parameter servers is a commonly used technique to boost training throughput and mitigate memory constraints. It enables parallel computations and updates on different shards of a variable, and often yields better load balancing across parameter servers. Without sharding, models with large variables (e.g., embeddings) that can't fit into one machine's memory would be unable to train.

With tf.distribute.experimental.ParameterServerStrategy, if a variable_partitioner is provided to __init__ and certain conditions are satisfied, the variables created in scope are sharded across the parameter servers in a round-robin fashion. The variable reference returned from tf.Variable becomes a type that serves as a container of the sharded variables, and its variables attribute holds the actual variable components. If you build the model with tf.Module or Keras, the variable components are collected in variables-like attributes.

class Dense(tf.Module):
  def __init__(self, name=None):
    self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

  def __call__(self, x):
    return x * self.w

# Partition the dense layer into 2 shards.
variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
        num_shards=2))
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
with strategy.scope():
  dense = Dense()
assert len(dense.variables) == 2
assert isinstance(dense.variables[0], tf.Variable)
assert isinstance(dense.variables[1], tf.Variable)
assert dense.variables[0].shape == (50, 10)
assert dense.variables[1].shape == (50, 10)
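The (50, 10) shapes come from splitting the first dimension of the [100, 10] variable into 2 contiguous blocks. The plain-Python sketch below computes per-shard shapes for such a split; how the remainder is distributed when rows don't divide evenly (first shards get one extra row) is an assumption of this sketch, not documented TensorFlow behavior.

```python
from typing import List, Tuple


def shard_shapes(shape: Tuple[int, ...], num_shards: int) -> List[Tuple[int, ...]]:
    """Split dimension 0 of `shape` into at most `num_shards` contiguous blocks.

    Assumption: when rows don't divide evenly, the first shards get one extra row.
    """
    rows, rest = shape[0], tuple(shape[1:])
    num_shards = min(num_shards, rows)  # can't have more shards than rows
    base, extra = divmod(rows, num_shards)
    return [(base + (1 if i < extra else 0),) + rest for i in range(num_shards)]


print(shard_shapes((100, 10), 2))  # [(50, 10), (50, 10)]
print(shard_shapes((10, 4), 3))  # [(4, 4), (3, 4), (3, 4)]
```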

The sharded variable container can be converted to a Tensor via tf.convert_to_tensor. This means the container can be directly used in most Python Ops where such Tensor conversion automatically happens. For example, in the above code snippet, x * self.w would implicitly apply the said tensor conversion. Note that such conversion can be expensive, as the variable components need to be transferred from multiple parameter servers to where the value is used.

tf.nn.embedding_lookup, on the other hand, doesn't apply the tensor conversion and instead performs parallel lookups on the variable components. This is crucial to scale up embedding lookups when the embedding table variable is large.
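The reason a lookup can avoid the conversion is that each requested row lives in exactly one shard. The sketch below shows that routing in plain Python, assuming the table is stored as contiguous row blocks (consistent with the (50, 10) shards above); the `sharded_lookup` helper is hypothetical and performs the per-shard gathers sequentially, whereas TensorFlow performs them in parallel.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence


def sharded_lookup(shards: Sequence[Sequence[List[float]]],
                   ids: Sequence[int]) -> List[List[float]]:
    """Look up global row ids in a table stored as contiguous row blocks.

    Each id is routed to the single shard that owns its row, so the full table
    never needs to be concatenated in one place.
    """
    # starts[k] is the global index of the first row held by shard k.
    starts = [0] + list(accumulate(len(s) for s in shards))[:-1]
    out = []
    for i in ids:
        k = bisect_right(starts, i) - 1  # shard that owns row i
        out.append(shards[k][i - starts[k]])  # local offset within that shard
    return out


table = [[float(r), float(r) * 10] for r in range(10)]  # 10 rows, 2 columns
shards = [table[0:4], table[4:7], table[7:10]]  # 3 row blocks
print(sharded_lookup(shards, [5, 0, 9]))
# [[5.0, 50.0], [0.0, 0.0], [9.0, 90.0]]
```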

When a partitioned variable is saved to a SavedModel, it is saved as if it were one single variable. This improves serving efficiency by eliminating a number of Ops that handle the partitioning.

Known limitations of variable partitioning:

  • The number of partitions must not change across Checkpoint saving/loading.

  • After saving partitioned variables to a SavedModel, the SavedModel can't be loaded via tf.saved_model.load.

  • Partitioned variables don't work directly with tf.GradientTape; use the variables attribute to get the actual variable components and pass those to gradient APIs instead.

Dataset preparation

With tf.distribute.experimental.ParameterServerStrategy, a dataset is created on each of the workers to be used for training. This is done by creating a dataset_fn that takes no arguments and returns a tf.data.Dataset, and passing the dataset_fn into tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset. We recommend shuffling and repeating the dataset so that the examples run through training as evenly as possible.

def dataset_fn():
  filenames = ...
  dataset = ...  # Build a `tf.data.Dataset` from `filenames` here.

  # Dataset is recommended to be shuffled, and repeated.
  return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
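The shuffle/repeat/batch pipeline can be modeled in plain Python to see what each step contributes: repeating means a worker's iterator does not run out partway through the scheduled steps, and reshuffling on every pass varies the order in which examples are seen. The hypothetical generator below does a full reshuffle per pass, which corresponds to tf.data's shuffle with a buffer_size at least as large as the dataset and reshuffle_each_iteration left on; a smaller buffer only shuffles locally.

```python
import random
from itertools import islice
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_repeat_batch(examples: Sequence[T], batch_size: int,
                         seed: int = 0) -> Iterator[List[T]]:
    """Yield batches forever, reshuffling the full example list on every pass."""
    if batch_size <= 0 or len(examples) == 0:
        raise ValueError("need a positive batch_size and at least one example")
    rng = random.Random(seed)
    stream: List[T] = []
    while True:
        epoch = list(examples)
        rng.shuffle(epoch)  # a fresh order on each pass
        stream.extend(epoch)
        # Batches may span pass boundaries, as with .repeat() before .batch().
        while len(stream) >= batch_size:
            yield stream[:batch_size]
            del stream[:batch_size]


batches = list(islice(shuffle_repeat_batch(range(10), batch_size=4), 5))
first_pass = [x for b in batches for x in b][:10]
print(sorted(first_pass) == list(range(10)))  # True
```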


Args
cluster_resolver: A tf.distribute.cluster_resolver.ClusterResolver object.
variable_partitioner: A tf.distribute.experimental.partitioners.Partitioner that specifies how to partition variables. If None, variables will not be partitioned.