tf.distribute.StrategyExtended

Additional APIs for algorithms that need to be distribution-aware.

Some common use cases of functions on this page:

  • Locality

tf.distribute.DistributedValues can have the same locality as a distributed variable; such a mirrored value resides on the same devices as the variable rather than on the compute devices. Values with this locality can be passed to tf.distribute.StrategyExtended.update to update the value of a variable. Use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. To convert a "PerReplica" value to a variable's locality, use tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.

  • How to update a distributed variable

A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:

  1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
  2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with a merge function, passing the updates and variables as its arguments.
  3. Call tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v) (for one variable, where t is the update and v is the variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
  4. For each variable, call tf.distribute.StrategyExtended.update(v, fn), where fn applies the reduced update to each component variable of v.

Steps 2 through 4 are done automatically by tf.keras.optimizers.Optimizer if you call its apply_gradients method in a replica context.
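The four steps above can be modeled in plain Python to make the data flow concrete. This is a sketch, not TensorFlow code: a mirrored variable is assumed to be a dict from a (hypothetical) device name to a value, the per-replica "gradients" are made-up numbers, and `replica_fn`, `merge_fn`, and `LEARNING_RATE` are illustrative names. In real TensorFlow, `merge_call` also passes the strategy as the first argument to the merge function; this model omits it.

```python
# Framework-free model of the four-step update pattern (NOT TensorFlow code).
# A mirrored variable is modeled as a dict {device: value}, one copy per device.
NUM_REPLICAS = 2
LEARNING_RATE = 0.25  # hypothetical step size for an SGD-style update


def replica_fn(replica_id, variable):
    # Step 1: each replica computes an (update, variable) pair. The "gradient"
    # is a made-up number that differs per replica.
    gradient = float(replica_id + 1)
    return gradient, variable


def merge_fn(per_replica_pairs):
    # Step 2: in cross-replica mode the merge function sees every replica's pair.
    variable = per_replica_pairs[0][1]
    # Step 3: sum the updates across replicas (the role of reduce_to with SUM).
    total = sum(update for update, _ in per_replica_pairs)
    # Step 4: apply the same reduced update to every copy (the role of update).
    for device in variable:
        variable[device] -= LEARNING_RATE * total
    return variable


mirrored_var = {"/gpu:0": 1.0, "/gpu:1": 1.0}
pairs = [replica_fn(i, mirrored_var) for i in range(NUM_REPLICAS)]
merge_fn(pairs)
print(mirrored_var)  # {'/gpu:0': 0.25, '/gpu:1': 0.25}
```

Because the sum happens once in cross-replica mode and the same result is applied to every copy, the copies stay identical, which is the property a mirrored variable needs.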

A higher-level way to update a distributed variable is to call assign on it, as you would on a regular tf.Variable. You can call the method in both replica context and cross-replica context. For a mirrored variable, calling assign in replica context requires you to specify the aggregation type in the variable constructor; in that case, the context switching and synchronization described in steps 2 through 4 are handled for you. If you call assign on a mirrored variable in cross-replica context, you can only assign a single value, or assign values from another mirrored variable or a mirrored tf.distribute.DistributedValues. For a SyncOnRead variable, you can simply call assign on it in replica context, and no aggregation happens under the hood. In cross-replica context, you can only assign a single value to a SyncOnRead variable.

Restoring from a checkpoint is one example of this cross-replica case. If the aggregation type of the variable is tf.VariableAggregation.SUM, the replica values are assumed to have been added before checkpointing, so on restore the value is divided by the number of replicas and then assigned to each replica. If the aggregation type is tf.VariableAggregation.MEAN, the value is assigned to each replica directly.
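The checkpoint-restore rule for SyncOnRead variables can be illustrated with a small framework-free function. This is a sketch of the documented arithmetic only, not TensorFlow's implementation; `restore_sync_on_read` is a hypothetical name, and the strings "SUM" and "MEAN" stand in for tf.VariableAggregation.SUM and tf.VariableAggregation.MEAN.

```python
# Framework-free model of restoring a SyncOnRead variable from a checkpoint
# (NOT TensorFlow code). Aggregation names stand in for tf.VariableAggregation.
def restore_sync_on_read(saved_value, aggregation, num_replicas):
    """Return the per-replica values assigned when a checkpoint is restored."""
    if aggregation == "SUM":
        # Replica values were added before checkpointing, so split the total.
        per_replica = saved_value / num_replicas
    elif aggregation == "MEAN":
        # The saved value is assigned to each replica unchanged.
        per_replica = saved_value
    else:
        raise ValueError(f"unsupported aggregation: {aggregation!r}")
    return [per_replica] * num_replicas


summed = restore_sync_on_read(8.0, "SUM", num_replicas=4)
averaged = restore_sync_on_read(8.0, "MEAN", num_replicas=4)
print(summed)    # [2.0, 2.0, 2.0, 2.0]
print(averaged)  # [8.0, 8.0, 8.0, 8.0]
# Reading back with the same aggregation recovers the saved value.
print(sum(summed), sum(averaged) / len(averaged))  # 8.0 8.0
```

The design point is that restoring followed by reading (which aggregates the replicas again) round-trips to the checkpointed value in both cases.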

Attributes
experimental_require_static_shapes Returns True if static shape is required; False otherwise.
parameter_devices Returns the tuple of all devices used to place variables.
worker_devices Returns the tuple of all devices used for compute replica execution.

Methods

batch_reduce_to

Combine multiple reduce_to calls into one for faster execution.

Args
reduce_op Reduction type, an instance of tf.distribute.ReduceOp enum.
value_destination_pairs A sequence of (value, destinations) pairs. See reduce_to() for a description.
experimental_hints A tf.distribute.experimental.CollectiveHints. Hints to perform collective operations.

Returns
A list of mirrored values, one per pair in value_destination_pairs.
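As a framework-free illustration of the return value, the sketch below models a SUM-only batch_reduce_to as one reduction per (value, destinations) pair, returned in input order. It is not TensorFlow code: per-replica values are assumed to be dicts from device name to number, destinations are lists of device names, and `reduce_sum_to` and `batch_reduce_sum_to` are hypothetical names. The speedup the real method offers comes from combining the communication, which this model cannot show.

```python
# Framework-free model of batch_reduce_to with a SUM reduction (NOT TensorFlow
# code). Per-replica values are dicts {device: value}; destinations are lists
# of device names.
def reduce_sum_to(per_replica, destinations):
    # One reduce_to(SUM, ...) call: sum the replica values, copy to each device.
    total = sum(per_replica.values())
    return {device: total for device in destinations}


def batch_reduce_sum_to(value_destination_pairs):
    # Same results as one reduce_sum_to per pair, in input order.
    return [reduce_sum_to(value, dest) for value, dest in value_destination_pairs]


pairs = [
    ({"/gpu:0": 1.0, "/gpu:1": 2.0}, ["/gpu:0", "/gpu:1"]),
    ({"/gpu:0": 10.0, "/gpu:1": 20.0}, ["/cpu:0"]),
]
results = batch_reduce_sum_to(pairs)
print(results)  # [{'/gpu:0': 3.0, '/gpu:1': 3.0}, {'/cpu:0': 30.0}]
```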

colocate_vars_with

Scope that controls which devices variables will be created on.

No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others by using a tf.compat.v1.colocate_with() scope).

This may only be used inside self.scope().

Example usage:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Initial values are illustrative.
  var1 = tf.Variable(1.0)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(3.0)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3;
    # for example, set v1 to v2 + v3
    return v1.assign(v2 + v3)

  # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))

Args
colocate_with_variable A variable created in this strategy's scope(). Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable.

Returns
A context manager.

reduce_to

Combine (via e.g. sum or mean) values across replicas.

Args
reduce_op Reduction type, an instance of tf.distribute.ReduceOp enum.
value A per-replica value with one value per replica.
destinations A mirrored variable, a per-replica tensor, or a device string. The return value will be copied to all destination devices (or all the devices where the destinations value resides). To perform an all-reduction, pass value to destinations.
experimental_hints A tf.distribute.experimental.CollectiveHints. Hints to perform collective operations.

Returns
A tensor or value mirrored to destinations.
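The semantics of reduce_op and destinations can be sketched without TensorFlow. In this model, which is an assumption-laden illustration rather than the library's implementation, a per-replica value is a dict from device name to number, `destinations` is either a device string or a dict whose keys are the devices a value lives on, and the strings "SUM" and "MEAN" stand in for tf.distribute.ReduceOp members. Passing the value itself as `destinations` reproduces the all-reduce case described above.

```python
# Framework-free model of reduce_to semantics (NOT TensorFlow code).
# A per-replica value is a dict {device: number}. `destinations` is either a
# device string or a dict whose keys are the devices a value lives on.
def reduce_to(reduce_op, value, destinations):
    replica_values = list(value.values())
    if reduce_op == "SUM":
        reduced = sum(replica_values)
    elif reduce_op == "MEAN":
        reduced = sum(replica_values) / len(replica_values)
    else:
        raise ValueError(f"unsupported reduce_op: {reduce_op!r}")
    devices = [destinations] if isinstance(destinations, str) else list(destinations)
    # The reduced result is copied to every destination device.
    return {device: reduced for device in devices}


per_replica = {"/gpu:0": 1.0, "/gpu:1": 3.0}
to_host = reduce_to("SUM", per_replica, "/cpu:0")
# Passing the value itself as destinations gives an all-reduce.
all_reduced = reduce_to("MEAN", per_replica, per_replica)
print(to_host)      # {'/cpu:0': 4.0}
print(all_reduced)  # {'/gpu:0': 2.0, '/gpu:1': 2.0}
```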

update

Run fn to update var using inputs mirrored to the same devices.

tf.distribute.StrategyExtended.update takes a distributed variable var to be updated, an update function fn, and args and kwargs for fn. It applies fn to each component variable of var and passes corresponding values from args and kwargs. Neither args nor kwargs may contain per-replica values. If they contain mirrored values, they will be unwrapped before calling fn. For example, fn can be assign_add and args can be a mirrored DistributedValues where each component contains the value to be added to this mirrored variable var. Calling update will call assign_add on each component variable of var with the corresponding tensor value on that device.

Example usage:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) # With 2 devices
with strategy.scope():
  v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
def update_fn(v):
  return v.assign(1.0)
result = strategy.extended.update(v, update_fn)
# result is
# Mirrored:{
#  0: tf.Tensor(1.0, shape=(), dtype=float32),
#  1: tf.Tensor(1.0, shape=(), dtype=float32)
# }

If var is mirrored across multiple devices, then this method implements logic as follows:

# Pseudocode: `merged` stands for regrouping the per-device results into a
# single mirrored value.
results = {}
for device, v in var:
  with tf.device(device):
    # args and kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)

Otherwise, this method returns fn(var, *args, **kwargs) colocated with var.
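The per-device loop described above can be made runnable as a toy model. This sketch is not TensorFlow code: `Var`, `Mirrored`, `unwrap`, and `update` are hypothetical stand-ins, device placement is not modeled, and the model omits the `group` option, the control dependencies on the merged result, and the check that rejects per-replica arguments. It reproduces the assign_add example from the description: a mirrored delta is unwrapped to the matching component on each device.

```python
# Framework-free model of StrategyExtended.update (NOT TensorFlow code).
class Var:
    """Hypothetical single-device variable component."""

    def __init__(self, value):
        self.value = value

    def assign_add(self, delta):
        self.value += delta
        return self.value


class Mirrored:
    """Hypothetical mirrored container: one component per device."""

    def __init__(self, components):
        self.components = dict(components)


def unwrap(x, device):
    # Mirrored arguments are replaced by their component on `device`.
    return x.components[device] if isinstance(x, Mirrored) else x


def update(var, fn, args=(), kwargs=None):
    kwargs = kwargs or {}
    if not isinstance(var, Mirrored):
        # Not mirrored: a single call (device colocation is not modeled).
        return fn(var, *args, **kwargs)
    results = {}
    for device, component in var.components.items():
        device_args = [unwrap(a, device) for a in args]
        device_kwargs = {k: unwrap(v, device) for k, v in kwargs.items()}
        results[device] = fn(component, *device_args, **device_kwargs)
    # Regroup the per-device results into one mirrored value.
    return Mirrored(results)


var = Mirrored({"/gpu:0": Var(5.0), "/gpu:1": Var(5.0)})
delta = Mirrored({"/gpu:0": 1.0, "/gpu:1": 1.0})
result = update(var, lambda v, d: v.assign_add(d), args=(delta,))
print(result.components)  # {'/gpu:0': 6.0, '/gpu:1': 6.0}
```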

Args
var Variable, possibly mirrored to multiple devices, to operate on.
fn Function to call. Should take the variable as the first argument.
args Tuple or list. Additional positional arguments to pass to fn().
kwargs Dict with keyword arguments to pass to fn().
group Boolean. Defaults to True. If False, the return value will be unwrapped.

Returns
By default, the merged return value of fn across all replicas. The merged result has dependencies to make sure that if it is evaluated at all, the side effects (updates) will happen on every replica. If instead "group=False" is specified, this function will return a nest of lists where each list has an element per replica, and the caller is responsible for ensuring all elements are executed.

value_container

Returns the container that this per-replica value belongs to.

Args
value A value returned by run() or a variable created in scope().

Returns
A container that value belongs to. If value does not belong to any container (including the case of container having been destroyed), returns the value itself. value in experimental_local_results(value_container(value)) will always be true.
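The documented contract of value_container can be illustrated with a toy model. This is not TensorFlow's implementation: `Component`, `Container`, `value_container`, and `local_results` (a stand-in for experimental_local_results) are hypothetical, and the weak reference is only one way to mimic the documented fallback of returning the value itself once its container has been destroyed.

```python
import weakref


class Component:
    """Hypothetical per-device piece of a distributed value."""

    def __init__(self, value):
        self.value = value
        self._container_ref = None  # weak reference to the owning container


class Container:
    """Hypothetical distributed variable that owns several components."""

    def __init__(self, components):
        self.components = tuple(components)
        for component in self.components:
            component._container_ref = weakref.ref(self)


def value_container(value):
    # Return the owning container; fall back to the value itself when it has
    # no owner or the owner has already been garbage-collected.
    ref = getattr(value, "_container_ref", None)
    owner = ref() if ref is not None else None
    return owner if owner is not None else value


def local_results(value):
    # A container's components, or a 1-tuple holding a plain value.
    if isinstance(value, Container):
        return value.components
    return (value,)


a, b = Component(1.0), Component(2.0)
mirrored = Container([a, b])
lonely = Component(3.0)
print(value_container(a) is mirrored)     # True
print(value_container(lonely) is lonely)  # True
# The documented invariant holds in both cases.
print(all(v in local_results(value_container(v)) for v in (a, b, lonely)))  # True
```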

variable_created_in_scope

Tests whether v was created while this strategy scope was active.

Variables created inside the strategy scope are "owned" by it:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.)
strategy.extended.variable_created_in_scope(v)  # True

Variables created outside the strategy are not owned by it:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
v = tf.Variable(1.)
strategy.extended.variable_created_in_scope(v)  # False

Args
v A tf.Variable instance.

Returns
True if v was created inside the scope, False if not.