tf.distribute.experimental.coordinator.ClusterCoordinator

An object to schedule and coordinate remote function execution.

This class is used to create fault-tolerant resources and dispatch functions to remote TensorFlow servers.

Currently, this class cannot be used standalone. It must be used together with a tf.distribute strategy that is designed to work with it; at present, the only such strategy is tf.distribute.experimental.ParameterServerStrategy.

The schedule/join APIs

The most important APIs provided by this class are the schedule/join pair. The schedule API is non-blocking: it queues a tf.function and immediately returns a RemoteValue. The queued functions are dispatched to remote workers in background threads, and their RemoteValues are filled asynchronously. Since schedule doesn't require worker assignment, the tf.function passed in can be executed on any available worker. If that worker becomes unavailable before the function completes, the function is migrated to another worker. Because of this, and because function execution is not atomic, a function may be executed more than once.
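As a rough mental model only, the schedule/join pair behaves like submitting work to a thread pool and later waiting for all of it. The sketch below uses the Python standard library, not TensorFlow: `ToyCoordinator`, its methods, and the use of `concurrent.futures.Future` in place of RemoteValue are hypothetical stand-ins. It leaves out remote dispatch, fault tolerance, and error propagation entirely.

```python
import concurrent.futures
import threading


class ToyCoordinator:
    """Hypothetical stand-in for the schedule/join pair (not the TF API)."""

    def __init__(self, num_workers):
        # Local threads stand in for remote TensorFlow workers.
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
        self._futures = []
        self._lock = threading.Lock()

    def schedule(self, fn, args=(), kwargs=None):
        # Non-blocking: queue `fn` and immediately return a future
        # (our stand-in for a RemoteValue).
        future = self._pool.submit(fn, *args, **(kwargs or {}))
        with self._lock:
            self._futures.append(future)
        return future

    def join(self):
        # Block until every function scheduled so far has finished.
        with self._lock:
            pending = list(self._futures)
        concurrent.futures.wait(pending)

    def shutdown(self):
        self._pool.shutdown(wait=True)


def square(x):
    return x * x


coordinator = ToyCoordinator(num_workers=2)
values = [coordinator.schedule(square, args=(i,)) for i in range(5)]
coordinator.join()
results = [v.result() for v in values]
coordinator.shutdown()
print(results)  # [0, 1, 4, 9, 16]
```

The point of the analogy is the shape of the calling code: many cheap, non-blocking schedule calls followed by a single blocking join.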

Handling Task Failure

When used with tf.distribute.experimental.ParameterServerStrategy, this class comes with built-in fault tolerance for worker failures. That is, when some workers cannot be reached from the coordinator for any reason, training continues with the remaining workers. When a failed worker recovers, it is used for function execution again once the datasets created by create_per_worker_dataset have been rebuilt on it.

When a parameter server fails, schedule, join, or done raises a tf.errors.UnavailableError. In this case, in addition to bringing back the failed parameter server, users should restart the coordinator so that it reconnects to workers and parameter servers, re-creates the variables, and loads checkpoints. If the coordinator fails, then once the user brings it back, the program automatically reconnects to workers and parameter servers and continues from a checkpoint.

It is thus essential that the user's program periodically saves a checkpoint file and restores it at startup. If a tf.keras.optimizers.Optimizer is checkpointed, then after restoring from a checkpoint, its iterations property roughly indicates the number of steps that have been made. This can be used to decide how many epochs and steps remain before training completes.
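Deciding where to resume is simple arithmetic once a step count has been restored. In a real program that count would come from something like `optimizer.iterations.numpy()` after restoring a `tf.train.Checkpoint`. The helper below, `resume_position` (a hypothetical name), assumes a fixed number of steps per epoch and treats the count as approximate, so it only clamps it to the planned total.

```python
def resume_position(iterations, steps_per_epoch, num_epochs):
    """Hypothetical helper: map a restored step count to a resume point.

    `iterations` plays the role of the restored optimizer.iterations value,
    which is only a rough step count, so it is clamped to the planned total.
    """
    total_steps = steps_per_epoch * num_epochs
    completed = min(iterations, total_steps)
    start_epoch = completed // steps_per_epoch
    step_in_epoch = completed % steps_per_epoch
    remaining_steps = total_steps - completed
    return start_epoch, step_in_epoch, remaining_steps


# 10 epochs of 100 steps; the restored checkpoint reports 437 steps.
print(resume_position(iterations=437, steps_per_epoch=100, num_epochs=10))
# (4, 37, 563)
```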

See tf.distribute.experimental.ParameterServerStrategy docstring for an example usage of this API.

This is currently under development, and both the API and the implementation are subject to change.

Args
strategy a supported tf.distribute.Strategy object. Currently, only tf.distribute.experimental.ParameterServerStrategy is supported.

Raises
ValueError if the strategy being used is not supported.

Attributes
strategy Returns the Strategy associated with the ClusterCoordinator.

Methods

create_per_worker_dataset

Create dataset on workers by calling dataset_fn on worker devices.

This creates the dataset generated by dataset_fn on each worker and returns an object that represents the collection of those individual datasets. Calling iter on such a collection returns a tf.distribute.experimental.coordinator.PerWorkerValues, which is a collection of iterators, each placed on its respective worker.

Calling next on a PerWorkerValues of iterators is unsupported. The iterator is meant to be passed as an argument to tf.distribute.experimental.coordinator.ClusterCoordinator.schedule. When the scheduled function is about to be executed by a worker, the function receives the individual iterator that corresponds to that worker. next can be called on an iterator inside a scheduled function when the iterator is an input of the function.

Currently, the schedule method assumes that all workers are the same, and thus that the datasets on different workers are the same, except that they may be shuffled differently if they contain a dataset.shuffle operation and no random seed is set. Because of this, we also recommend that datasets be repeated indefinitely and that a finite number of steps be scheduled, instead of relying on an OutOfRangeError from a dataset.

Example:

import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

# Only tf.functions can be scheduled.
@tf.function
def worker_fn(iterator):
  # `iterator` is the worker-local iterator substituted in by `schedule`.
  return next(iterator)

def per_worker_dataset_fn():
  # Executed on each worker; the lambda receives a tf.distribute.InputContext.
  return strategy.distribute_datasets_from_function(
      lambda input_context: tf.data.Dataset.from_tensor_slices([3] * 3))

per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3

Args
dataset_fn The dataset function that returns a dataset. This is to be executed on the workers.

Returns
An object that represents the collection of those individual datasets. Call iter on this object to obtain a tf.distribute.experimental.coordinator.PerWorkerValues of the iterators (which reside on the workers).
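The recommendation to repeat datasets indefinitely and schedule a finite number of steps can be illustrated without a cluster. In TensorFlow the repetition would come from calling `tf.data.Dataset.repeat()` with no count; here `itertools.cycle` plays that role. `make_worker_iterator` and `run_fixed_steps` are hypothetical helpers, and the round-robin worker choice is an assumption made for determinism; the real coordinator uses any available worker.

```python
import itertools


def make_worker_iterator(elements):
    # Analogue of dataset.repeat() with no count: the iterator never runs out.
    return itertools.cycle(elements)


def run_fixed_steps(worker_iterators, num_steps):
    """Hypothetical driver: dispatch exactly `num_steps` steps round-robin.

    Because every worker's data repeats, no step ever meets an exhausted
    iterator; the loop is bounded by `num_steps`, not by an end-of-data error.
    """
    consumed = []
    for step in range(num_steps):
        iterator = worker_iterators[step % len(worker_iterators)]
        consumed.append(next(iterator))
    return consumed


# Two "workers" share the same 3-element dataset; 8 steps exceed its size.
workers = [make_worker_iterator([1, 2, 3]) for _ in range(2)]
print(run_fixed_steps(workers, num_steps=8))  # [1, 1, 2, 2, 3, 3, 1, 1]
```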

done

Returns whether all the scheduled functions have finished execution.

If any previously scheduled function raises an error, done will fail by raising any one of those errors.

When done returns True or raises, it guarantees that there is no function that is still being executed.

Returns
Whether all the scheduled functions have finished execution.

Raises
Exception one of the exceptions caught by the coordinator by any previously scheduled function since the last time an error was thrown or since the beginning of the program.
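done lets the caller poll for completion instead of blocking in join. The standard-library analogue below, `toy_done` (hypothetical), mirrors the documented guarantee: it reports False while anything is still running, and only once nothing is running does it either raise one of the collected errors or return True. It does not model the cancellation of queued functions.

```python
import concurrent.futures
import time


def toy_done(futures):
    """Hypothetical analogue of ClusterCoordinator.done over plain futures."""
    # Report False while any function is still running.
    if not all(f.done() for f in futures):
        return False
    # Nothing is running now: surface one collected error, if any.
    for f in futures:
        error = f.exception()
        if error is not None:
            raise error
    return True


def slow_identity(x):
    time.sleep(0.05)
    return x


def always_fails():
    raise ValueError("boom")


with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(slow_identity, i) for i in range(4)]
    polls = 0
    while not toy_done(futures):
        polls += 1  # The caller is free to do other work between polls.
        time.sleep(0.01)
    results = [f.result() for f in futures]

    failing = [pool.submit(always_fails)]
    concurrent.futures.wait(failing)
    try:
        toy_done(failing)
        caught = None
    except ValueError as e:
        caught = str(e)

print(results, caught)  # [0, 1, 2, 3] boom
```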

fetch

Blocking call to fetch results from the remote values.

This is a wrapper around tf.distribute.experimental.coordinator.RemoteValue.fetch for a RemoteValue structure; it returns the execution results of the RemoteValues. If they are not ready yet, it blocks the caller until they are.

Example:

import tensorflow as tf

strategy = ...  # A ParameterServerStrategy; construction elided.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([1, 1, 1])

with strategy.scope():
  v = tf.Variable(initial_value=0)

@tf.function
def worker_fn(iterator):
  def replica_fn(x):
    v.assign_add(x)
    return v.read_value()
  return strategy.run(replica_fn, args=(next(iterator),))

distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1

Args
val The value to fetch the results from. If this is a structure of tf.distribute.experimental.coordinator.RemoteValue, fetch() is called on each individual tf.distribute.experimental.coordinator.RemoteValue to get the result.

Returns
If val is a tf.distribute.experimental.coordinator.RemoteValue or a structure of tf.distribute.experimental.coordinator.RemoteValues, returns the fetched values with the same structure: immediately if they are available, otherwise after blocking until they are. If val is of any other type, it is returned as-is.
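The structure-preserving behaviour of fetch can be sketched in plain Python. `ToyRemoteValue` and `toy_fetch` are hypothetical; the toy values are already resolved, so the blocking part is not modelled, and the sketch only recurses into dicts, lists, and tuples.

```python
class ToyRemoteValue:
    """Hypothetical stand-in for RemoteValue wrapping an already-computed result."""

    def __init__(self, value):
        self._value = value

    def fetch(self):
        return self._value


def toy_fetch(val):
    """Fetch every ToyRemoteValue in `val`, keeping the surrounding structure."""
    if isinstance(val, ToyRemoteValue):
        return val.fetch()
    if isinstance(val, dict):
        return {k: toy_fetch(v) for k, v in val.items()}
    if isinstance(val, (list, tuple)):
        return type(val)(toy_fetch(v) for v in val)
    # Any other type is returned as-is.
    return val


nested = {"loss": ToyRemoteValue(0.5), "metrics": [ToyRemoteValue(1), 7], "tag": "step"}
print(toy_fetch(nested))  # {'loss': 0.5, 'metrics': [1, 7], 'tag': 'step'}
```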

join

Blocks until all the scheduled functions have finished execution.

If any previously scheduled function raises an error, join fails by raising one of those errors and clears the errors collected so far. In that case, some of the previously scheduled functions may not have been executed. Users can call fetch on the tf.distribute.experimental.coordinator.RemoteValue returned by schedule to check whether each function executed, failed, or was cancelled. To reschedule functions that were cancelled, call schedule with them again.

When join returns or raises, it guarantees that there is no function that is still being executed.

Raises
Exception one of the exceptions caught by the coordinator by any previously scheduled function since the last time an error was thrown or since the beginning of the program.
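The error contract of join (raise one error, clear the collected errors, and leave the caller to inspect and reschedule) can be modelled synchronously. `ToyJoinCoordinator` is hypothetical and simplifies heavily: functions run in order inside `join`, each handle is a dict with a `status` field instead of a RemoteValue, and every function queued after the first failure is marked cancelled.

```python
class ToyJoinCoordinator:
    """Hypothetical, synchronous model of join's error contract (not the TF API)."""

    def __init__(self):
        self._queue = []
        self._errors = []

    def schedule(self, fn, *args):
        # A plain dict stands in for the RemoteValue handle.
        handle = {"fn": fn, "args": args, "status": "queued", "result": None}
        self._queue.append(handle)
        return handle

    def join(self):
        for handle in self._queue:
            if self._errors:
                # Simplification: everything queued after a failure is cancelled.
                handle["status"] = "cancelled"
                continue
            try:
                handle["result"] = handle["fn"](*handle["args"])
                handle["status"] = "succeeded"
            except Exception as e:  # Collect the error instead of stopping.
                handle["status"] = "failed"
                self._errors.append(e)
        self._queue = []
        if self._errors:
            error = self._errors[0]
            self._errors = []  # Clear the collected errors, then raise one.
            raise error


def step(x):
    if x == 2:
        raise RuntimeError("step 2 failed")
    return x * 10


coord = ToyJoinCoordinator()
handles = [coord.schedule(step, i) for i in range(5)]
try:
    coord.join()
    raised = None
except RuntimeError as e:
    raised = str(e)

print(raised)  # step 2 failed
print([h["status"] for h in handles])
# ['succeeded', 'succeeded', 'failed', 'cancelled', 'cancelled']

# Reschedule only the cancelled work and join again.
retry = [coord.schedule(h["fn"], *h["args"])
         for h in handles if h["status"] == "cancelled"]
coord.join()
print([h["result"] for h in retry])  # [30, 40]
```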

schedule

Schedules fn to be dispatched to a worker for asynchronous execution.

This method is non-blocking: it queues fn for later execution and immediately returns a tf.distribute.experimental.coordinator.RemoteValue object. Call fetch on that object to wait for the function to finish and retrieve its output from a remote worker. Alternatively, call tf.distribute.experimental.coordinator.ClusterCoordinator.join to wait for all scheduled functions to finish.

schedule guarantees that fn will be executed on a worker at least once; it may be executed more than once if its worker fails in the middle of execution. Because a worker can fail at any point while executing the function, the function may be partially executed, but tf.distribute.experimental.coordinator.ClusterCoordinator guarantees that in those events the function is eventually executed on an available worker.
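The at-least-once guarantee, and why partial execution matters, can be shown with a toy retry loop. Everything here is hypothetical: `WorkerUnavailable` simulates a worker dying mid-function, and `run_at_least_once` gives up once its worker list is exhausted, whereas the real coordinator keeps waiting until some worker is available.

```python
class WorkerUnavailable(Exception):
    """Hypothetical signal that a worker died while running a function."""


def run_at_least_once(fn, workers):
    """Toy retry loop: run `fn` on each worker in turn until one completes it."""
    for worker in workers:
        try:
            return fn(worker)
        except WorkerUnavailable:
            continue  # Migrate the function to the next worker.
    raise RuntimeError("no worker completed the function")


side_effects = []


def non_atomic_fn(worker):
    # The side effect happens before the point where the worker may fail.
    side_effects.append(worker)
    if worker == "worker-0":
        raise WorkerUnavailable(worker)
    return f"done on {worker}"


outcome = run_at_least_once(non_atomic_fn, ["worker-0", "worker-1"])
print(outcome)       # done on worker-1
print(side_effects)  # ['worker-0', 'worker-1']
```

The side effect is recorded twice: once by the partial run on the failed worker and once by the completed run, which is exactly the partial-execution case described above.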

If any previously scheduled function raises an error, schedule raises one of those errors and clears the errors collected so far. In that case, some of the previously scheduled functions may not have been executed. Users can call fetch on the returned tf.distribute.experimental.coordinator.RemoteValue to check whether the functions executed, failed, or were cancelled, and reschedule the corresponding functions if needed.

When schedule raises, it guarantees that there is no function that is still being executed.

At this time, neither worker assignment nor worker priority is supported for function execution.

args and kwargs are the arguments passed to fn when fn is executed on a worker. They can be tf.distribute.experimental.coordinator.PerWorkerValues, in which case the argument is replaced with the corresponding component on the target worker. Arguments that are not tf.distribute.experimental.coordinator.PerWorkerValues are passed to fn as-is. Currently, tf.distribute.experimental.coordinator.RemoteValue objects cannot be passed as args or kwargs.
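The PerWorkerValues substitution rule can be sketched as a function that rewrites arguments for a chosen worker. `ToyPerWorkerValues` and `resolve_args` are hypothetical names, the target worker index is passed in explicitly, and only top-level arguments are inspected.

```python
class ToyPerWorkerValues:
    """Hypothetical stand-in for PerWorkerValues: one component per worker."""

    def __init__(self, values):
        self.values = tuple(values)


def resolve_args(args, kwargs, worker_index):
    """Replace per-worker arguments with the target worker's component."""

    def pick(arg):
        if isinstance(arg, ToyPerWorkerValues):
            return arg.values[worker_index]
        return arg  # Anything else is passed through unchanged.

    return (
        tuple(pick(a) for a in args),
        {k: pick(v) for k, v in kwargs.items()},
    )


per_worker_iters = ToyPerWorkerValues(["iter-on-worker-0", "iter-on-worker-1"])
args, kwargs = resolve_args((per_worker_iters, 5), {"scale": 2.0}, worker_index=1)
print(args, kwargs)  # ('iter-on-worker-1', 5) {'scale': 2.0}
```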

Args
fn A tf.function; the function to be dispatched to a worker for asynchronous execution. Regular Python functions cannot be scheduled.
args Positional arguments for fn.
kwargs Keyword arguments for fn.

Returns
A tf.distribute.experimental.coordinator.RemoteValue object that represents the output of the function scheduled.

Raises
Exception one of the exceptions caught by the coordinator from any previously scheduled function, since the last time an error was thrown or since the beginning of the program.