Additional APIs for algorithms that need to be distribution-aware.
tf.distribute.StrategyExtended(container_strategy)
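In TensorFlow 2 code, a StrategyExtended object is usually obtained from the extended property of an existing strategy rather than constructed directly. The sketch below assumes tf.distribute.MirroredStrategy over whatever local devices are visible (a single CPU replica if no GPU is present). The variable name extended is illustrative only.

```python
import tensorflow as tf

# Assumption: MirroredStrategy over the visible local devices
# (one CPU replica if no GPU is present).
strategy = tf.distribute.MirroredStrategy()

# The StrategyExtended instance is reached through the strategy's
# `extended` property instead of calling the constructor yourself.
extended = strategy.extended

print(type(extended).__name__)
print(extended.worker_devices)     # devices that run the replica computation
print(extended.parameter_devices)  # devices that hold the variables
```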
Some common use cases of functions on this page:
tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using
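A minimal sketch of these locality helpers, assuming tf.distribute.MirroredStrategy over the visible local devices. colocate_vars_with (called inside strategy.scope()) places a new variable on the same devices as an existing one. StrategyExtended.reduce_to, called with the variable as destinations, moves a per-replica result onto that variable's devices. The helper replica_fn and its per-replica values (replica id plus one) are hypothetical and were chosen only so the sum is easy to check.

```python
import tensorflow as tf

# Assumption: MirroredStrategy over the visible local devices.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    v = tf.Variable(1.0)  # a mirrored variable
    # colocate_vars_with must be entered inside strategy.scope().
    with strategy.extended.colocate_vars_with(v):
        w = tf.Variable(2.0)  # created on the same devices as v

def replica_fn():
    # Hypothetical per-replica value: replica id + 1.
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0

# Lives on the compute devices; a PerReplica value when there are several replicas.
per_replica = strategy.run(replica_fn)

# Cross-replica context: sum the per-replica values onto v's devices.
reduced = strategy.extended.reduce_to(
    tf.distribute.ReduceOp.SUM, per_replica, destinations=v)

v_devices = [x.device for x in strategy.experimental_local_results(v)]
w_devices = [x.device for x in strategy.experimental_local_results(w)]
reduced_value = float(strategy.experimental_local_results(reduced)[0])
print(v_devices == w_devices, reduced_value)
```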
- How to update a distributed variable
A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. In the function passed to merge_call(), call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.
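The steps above can be sketched as follows. This is a minimal illustration, not a reference implementation. It assumes tf.distribute.MirroredStrategy over the visible local devices. A constant stands in for the gradient, the 0.1 step size is arbitrary, and the names merge_fn, replica_fn and train_step are hypothetical. The sketch passes tf.distribute.ReduceOp.SUM to reduce_to, the ReduceOp counterpart of VariableAggregation.SUM.

```python
import tensorflow as tf

# Assumption: MirroredStrategy over the visible local devices.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    v = tf.Variable(1.0)  # mirrored variable to be updated

def merge_fn(distribution, grad, var):
    # Runs in cross-replica context; `distribution` is the strategy itself.
    # Step 3: sum the per-replica updates onto the variable's devices.
    reduced = distribution.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, grad, destinations=var)
    # Step 4: update() calls the lambda once per device, passing that
    # device's copy of the variable and the matching piece of `reduced`.
    return distribution.extended.update(
        var, lambda var_copy, g: var_copy.assign_sub(0.1 * g), args=(reduced,))

@tf.function
def train_step():
    def replica_fn():
        # Step 1: compute an (update, variable) pair; a constant stands in
        # for the gradient of a loss.
        grad = tf.constant(1.0)
        # Step 2: switch to cross-replica context.
        tf.distribute.get_replica_context().merge_call(merge_fn, args=(grad, v))
    strategy.run(replica_fn)

train_step()
# Each copy should now hold 1.0 - 0.1 * num_replicas_in_sync.
value = float(strategy.experimental_local_results(v)[0].numpy())
print(value)
```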
In fact, a higher-level way to update a distributed variable is to call assign on the variable, as you would on a regular variable. You can call the method in both replica context and cross-replica context.
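A minimal sketch of calling assign in both contexts. It assumes tf.distribute.MirroredStrategy over the visible local devices and a mirrored variable created with aggregation=tf.VariableAggregation.MEAN, so that assign is permitted in replica context. The per-replica values (each replica's id) and the names replica_fn and replica_assign are hypothetical.

```python
import tensorflow as tf

# Assumption: MirroredStrategy over the visible local devices.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # An explicit aggregation lets assign() be called in replica context.
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

# Cross-replica context: a single value is written to every copy.
v.assign(5.0)
after_cross_replica = float(strategy.experimental_local_results(v)[0].numpy())

@tf.function
def replica_assign():
    def replica_fn():
        ctx = tf.distribute.get_replica_context()
        # Each replica proposes a different value; with MEAN aggregation the
        # proposals are averaged and the result is written to every copy.
        v.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32))
    strategy.run(replica_fn)

replica_assign()
after_replica = float(strategy.experimental_local_results(v)[0].numpy())
print(after_cross_replica, after_replica)
```

With n replicas, the replica-context assign should leave every copy at the mean of 0, 1, ..., n - 1.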
For a mirrored variable, calling assign in replica context requires you to specify the aggregation type in the variable constructor. In that case, the context switching and sync described in steps 2 through 4 are handled for you. If you call assign on a mirrored variable in cross-replica context, you can only assign a single value or assign values from another mirrored variable or a mirrored tf.distribute.DistributedValues. For a SyncOnRead variable, in replica context, you can simply call assign on it and no aggregation happens under the hood. In cross-replica context, you can only assign a single value to a SyncOnRead variable. One example case is restoring from a checkpoint: if the