tf.distribute.StrategyExtended

Additional APIs for algorithms that need to be distribution-aware.

Some common use cases of functions on this page:

  • Locality

tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.
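As a rough sketch of the locality ideas above, the following assumes a single-CPU tf.distribute.MirroredStrategy so it runs without accelerators. The variable names (v, accumulator, per_replica) are hypothetical and chosen for illustration only.

```python
import tensorflow as tf

# Assumption: one CPU device, so there is exactly one replica.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    v = tf.Variable(1.0)
    # Give `accumulator` the same locality as `v`.
    with strategy.extended.colocate_vars_with(v):
        accumulator = tf.Variable(0.0)


def replica_fn():
    # Each replica produces a value on its compute device.
    return tf.constant(3.0)


per_replica = strategy.run(replica_fn)

# Convert the per-replica value to the locality of `v` by summing it
# onto the devices that hold `v`.
at_v = strategy.extended.reduce_to(
    tf.distribute.ReduceOp.SUM, per_replica, destinations=v)


def assign_value(var, value):
    # Called once per component of the distributed variable.
    return var.assign(value)


# A value with the variable's locality can be passed to `update`.
strategy.extended.update(accumulator, assign_value, args=(at_v,))
print(accumulator.read_value().numpy())
```

With more than one replica, the reduced value would be the sum of the per-replica values, so the result scales with strategy.num_replicas_in_sync.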

  • How to update a distributed variable

A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:

  1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
  2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
  3. Call tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
  4. Call tf.distribute.StrategyExtended.update(v, fn, args=...) for each variable to update its value, where fn applies the summed update to the variable.

Steps 2 through 4 are done automatically by tf.keras.optimizers.Optimizer if you call its apply_gradients method in a replica context.
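The four steps above can be sketched by hand as follows. This is a minimal illustration, not the optimizer's actual implementation: it assumes a single-CPU tf.distribute.MirroredStrategy, uses a constant in place of a real gradient, and passes tf.distribute.ReduceOp.SUM as the reduce_op argument of reduce_to. LEARNING_RATE, apply_update, merge_fn, and replica_fn are hypothetical names.

```python
import tensorflow as tf

# Assumption: one CPU device, so the sketch runs with a single replica.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    v = tf.Variable(1.0)  # a mirrored variable

LEARNING_RATE = 0.1  # hypothetical value for illustration


def apply_update(var, reduced_update):
    # Runs once per component of the mirrored variable.
    return var.assign_sub(LEARNING_RATE * reduced_update)


def merge_fn(strategy, update, var):
    # Step 3: in cross-replica context, sum the per-replica updates
    # onto the devices that hold `var`.
    reduced = strategy.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, update, destinations=var)
    # Step 4: apply the summed update to the variable.
    strategy.extended.update(var, apply_update, args=(reduced,))


def replica_fn():
    # Step 1: compute an (update, variable) pair. A constant stands in
    # for a gradient here.
    update = tf.constant(2.0)
    # Step 2: switch to cross-replica mode.
    tf.distribute.get_replica_context().merge_call(
        merge_fn, args=(update, v))


strategy.run(replica_fn)
print(v.read_value().numpy())
```

Each replica contributes 2.0, so the variable decreases by LEARNING_RATE * 2.0 * strategy.num_replicas_in_sync.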

A higher-level way to update a distributed variable is to call assign on it, as you would on a regular tf.Variable. You can call the method in both replica context and cross-replica context.

For a mirrored variable, calling assign in replica context requires you to specify the aggregation type in the variable constructor. In that case, the context switching and sync described in steps 2 through 4 are handled for you. If you call assign on a mirrored variable in cross-replica context, you can only assign a single value, or assign values from another mirrored variable or a mirrored tf.distribute.DistributedValues.

For a SyncOnRead variable, in replica context, you can simply call assign on it and no aggregation happens under the hood. In cross-replica context, you can only assign a single value to a SyncOnRead variable. One example case is restoring from a checkpoint: if the aggregation type of the variable is tf.VariableAggregation.SUM, it is assumed that replica values were added before checkpointing, so at the time of restoring, the value is divided by the number of replicas and then assigned to each replica; if the aggregation type is tf.VariableAggregation.MEAN, the value is assigned to each replica directly.
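The assign-based approach might look like the sketch below, again assuming a single-CPU tf.distribute.MirroredStrategy. The variable names total and count are hypothetical. The mirrored variable declares its aggregation in the constructor, and the SyncOnRead variable is created with synchronization=tf.VariableSynchronization.ON_READ.

```python
import tensorflow as tf

# Assumption: one CPU device, so there is exactly one replica.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    # Mirrored variable: the aggregation type is required for
    # assignment in replica context.
    total = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)
    # SyncOnRead variable: each replica keeps its own copy, and the
    # copies are aggregated only when read in cross-replica context.
    count = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)


def replica_fn():
    # Replica context: contributions to `total` are summed across
    # replicas before being applied.
    total.assign_add(1.0)
    # Replica context: `count` is updated locally, with no aggregation.
    count.assign_add(2.0)


strategy.run(replica_fn)

# Reading in cross-replica context sums the SyncOnRead copies.
print(total.read_value().numpy(), count.read_value().numpy())

# Cross-replica context: a single value can be assigned to a mirrored
# variable.
total.assign(5.0)
```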

Attributes

experimental_require_static_shapes: Returns True if static shape is required; False otherwise.
parameter_devices: Returns the tuple of all devices used to place variables.
worker_devices: Returns the tuple of all devices used for compute replica execution.
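These attributes can be inspected through strategy.extended. The snippet below is a small sketch assuming a single-CPU tf.distribute.MirroredStrategy; the exact device strings printed depend on the machine.

```python
import tensorflow as tf

# Assumption: one CPU device, chosen so the sketch runs anywhere.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
extended = strategy.extended

# Devices used for compute replica execution.
print(extended.worker_devices)
# Devices used to place variables.
print(extended.parameter_devices)
# Whether this strategy requires static shapes.
print(extended.experimental_require_static_shapes)
```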

Methods

batch_reduce_to

Combine multiple reduce_to calls into one for faster execution.

Similar to reduce_to, but accepts a list of (value, destinations) pairs.
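A hedged sketch of batch_reduce_to follows, assuming a single-CPU tf.distribute.MirroredStrategy and calling the method from cross-replica context on values returned by strategy.run. The names v1, v2, u1, and u2 are hypothetical, and constants stand in for real per-replica values such as gradients.

```python
import tensorflow as tf

# Assumption: one CPU device, so there is exactly one replica.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    v1 = tf.Variable(0.0)
    v2 = tf.Variable(0.0)


def replica_fn():
    # One value per variable, produced on each replica.
    return tf.constant(1.0), tf.constant(2.0)


u1, u2 = strategy.run(replica_fn)

# Reduce both values in a single call. Each pair names a value and the
# destination (here, a variable) whose devices should hold the result.
reduced = strategy.extended.batch_reduce_to(
    tf.distribute.ReduceOp.SUM, [(u1, v1), (u2, v2)])

# Read back the first local copy of each reduced value.
results = [
    float(strategy.experimental_local_results(r)[0]) for r in reduced
]
print(results)
```

The returned list has one reduced value per input pair, in the same order as the pairs.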