Additional APIs for algorithms that need to be distribution-aware.
tf.distribute.StrategyExtended(container_strategy)
Some common use cases of functions on this page:
- Locality
tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.
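As a minimal sketch of converting a "PerReplica" value to a variable's locality, the snippet below assumes a tf.distribute.MirroredStrategy over whatever local devices are available (a single CPU works) and a hypothetical replica_fn whose per-replica value is the replica index plus one. It sums those values with reduce_to, passing the variable v as destinations so the result lands on v's devices.

```python
import tensorflow as tf

# Assumption: MirroredStrategy over the local devices (falls back to one CPU).
strategy = tf.distribute.MirroredStrategy()
n = strategy.num_replicas_in_sync

with strategy.scope():
  v = tf.Variable(0.0)  # a mirrored variable; its devices define the locality

def replica_fn():
  # Hypothetical per-replica value: replica index + 1.
  ctx = tf.distribute.get_replica_context()
  return tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0

per_replica = strategy.run(replica_fn)

with strategy.scope():
  # Sum across replicas; the result is placed on the devices that hold `v`.
  summed = strategy.extended.reduce_to(
      tf.distribute.ReduceOp.SUM, per_replica, destinations=v)

# One local copy per device of `v`; every copy holds the same total.
local_copies = strategy.experimental_local_results(summed)
print([float(t) for t in local_copies])
```

With n replicas, every local copy should hold 1 + 2 + ... + n.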
- How to update a distributed variable
A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(tf.distribute.ReduceOp.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v, fn) for each variable to update its value.
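The four steps can be sketched as follows. This is a minimal illustration rather than what any optimizer actually does: it assumes a MirroredStrategy over the local devices, uses a constant as a stand-in for a gradient, and the helper names step_fn and apply_updates are hypothetical. The merge function receives the strategy as its first argument and runs once, in cross-replica context.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: local devices

with strategy.scope():
  v = tf.Variable(10.0)

def apply_updates(strategy_in_merge, updates_and_vars):
  # Steps 3 and 4: this runs once, in cross-replica context.
  for update, var in updates_and_vars:
    # Sum the per-replica updates onto the variable's devices.
    total = strategy_in_merge.extended.reduce_to(
        tf.distribute.ReduceOp.SUM, update, var)
    # Apply the summed value to every component of the variable.
    strategy_in_merge.extended.update(
        var, lambda component, t: component.assign_sub(t), args=(total,))

def step_fn():
  # Step 1: compute (update, variable) pairs; a constant stands in for a gradient.
  updates_and_vars = [(tf.constant(1.0), v)]
  # Step 2: switch to cross-replica context.
  tf.distribute.get_replica_context().merge_call(
      apply_updates, args=(updates_and_vars,))

strategy.run(step_fn)
print(float(v.numpy()))
```

Each replica contributes 1.0, so after one step v should equal 10 minus the number of replicas.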
In fact, a higher-level way to update a distributed variable is to call assign on the variable as you would on a regular variable. You can call the method in both replica context and cross-replica context.
For a mirrored variable, calling assign in replica context requires you to specify the aggregation type in the variable constructor. In that case, the context switching and sync described in steps 2 through 4 are handled for you. If you call assign on a mirrored variable in cross-replica context, you can only assign a single value or assign values from another mirrored variable or a mirrored tf.distribute.DistributedValues.
For a SyncOnRead variable, in replica context, you can simply call assign on it and no aggregation happens under the hood. In cross-replica context, you can only assign a single value to a SyncOnRead variable. One example case is restoring from a checkpoint: if the aggregation type of the variable is tf.VariableAggregation.SUM, it is assumed that replica values were added before checkpointing, so at the time of restoring, the value is divided by the number of replicas and then assigned to each replica; if the aggregation type is tf.VariableAggregation.MEAN, the value is assigned to each replica directly.
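A minimal sketch of both kinds of assign, assuming a MirroredStrategy over the local devices; the variable names and values are illustrative. The mirrored variable uses MEAN aggregation so that assign is allowed in replica context. The SyncOnRead variable uses SUM, so a cross-replica assign is divided by the number of replicas and then summed again when it is read.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: local devices
n = strategy.num_replicas_in_sync

with strategy.scope():
  # Mirrored variable: an aggregation type enables replica-context assign.
  mirrored = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)
  # SyncOnRead variable: each replica keeps its own copy; SUM is applied on read.
  on_read = tf.Variable(
      0.0,
      trainable=False,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

def step_fn():
  # Replica context: the mirrored assign is aggregated (MEAN of identical
  # values), while the SyncOnRead assign only touches this replica's copy.
  mirrored.assign(5.0)
  on_read.assign(1.0)

strategy.run(step_fn)
mirrored_after_run = float(mirrored.numpy())
on_read_after_run = float(on_read.numpy())  # reading sums the n local copies

with strategy.scope():
  # Cross-replica context: assign a single value to each variable.
  mirrored.assign(7.0)
  # With SUM aggregation, 4.0 is divided by n before being stored per replica,
  # so reading it back (which sums) returns 4.0 again.
  on_read.assign(4.0)

print(mirrored_after_run, on_read_after_run,
      float(mirrored.numpy()), float(on_read.numpy()))
```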
Returns the tuple of all devices used to place variables.
Returns the tuple of all devices used for compute replica execution.
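batch_reduce_to(reduce_op, value_destination_pairs, experimental_hints=None)
Combine multiple reduce_to calls into one for faster execution.
- reduce_op: Reduction type, such as sum or mean.
- value_destination_pairs: A sequence of (value, destinations) pairs. See reduce_to for a description.
Returns: A list of mirrored values, one per pair in value_destination_pairs.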
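A minimal sketch of batch_reduce_to, assuming a MirroredStrategy over the local devices and a hypothetical replica_fn that returns one stand-in "gradient" per variable. Both sums are computed in a single call, and each result is placed on the devices of the variable it is paired with.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: local devices
n = strategy.num_replicas_in_sync

with strategy.scope():
  a = tf.Variable(0.0)
  b = tf.Variable(0.0)

def replica_fn():
  # Hypothetical per-replica "gradients" for the two variables.
  return tf.constant(1.0), tf.constant(2.0)

grad_a, grad_b = strategy.run(replica_fn)

with strategy.scope():
  # One batched call instead of two separate reduce_to calls.
  reduced = strategy.extended.batch_reduce_to(
      tf.distribute.ReduceOp.SUM, [(grad_a, a), (grad_b, b)])

# Read one local copy of each mirrored result.
sum_a = float(strategy.experimental_local_results(reduced[0])[0])
sum_b = float(strategy.experimental_local_results(reduced[1])[0])
print(sum_a, sum_b)
```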
colocate_vars_with(colocate_with_variable)
Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others work by using a tf.compat.v1.colocate_with() scope).
This may only be used inside strategy.scope(). For example:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # any tf.distribute strategy works here
with strategy.scope():
  var1 = tf.Variable(1.0)  # illustrative initial values
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(3.0)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    return v1.assign_add(v2 + v3)

  # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))
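- colocate_with_variable: A variable created in this strategy's scope(). Variables created inside the returned context manager are placed on the same set of devices as colocate_with_variable.
Returns: A context manager.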
reduce_to(reduce_op, value, destinations, experimental_hints=None)
Combine (via e.g. sum or mean) values across replicas.
- reduce_op: Reduction type, such as sum or mean.
- value: A per-replica value with one value per replica.
- destinations: A mirrored variable, a per-replica tensor, or a device string. The return value will be copied to all destination devices (or all the devices where the destinations value resides).
Returns: A tensor or value mirrored to destinations.
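A minimal sketch of reduce_to with a device string as destinations, assuming a MirroredStrategy over the local devices and a hypothetical replica_fn in which replica i contributes i + 1. Averaging those values gives (n + 1) / 2 for n replicas, and the result is placed on the host CPU.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumption: local devices
n = strategy.num_replicas_in_sync

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Hypothetical per-replica value: replica i contributes i + 1.
  return tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0

per_replica = strategy.run(replica_fn)

with strategy.scope():
  # Average across replicas and place the result on the host CPU.
  mean = strategy.extended.reduce_to(
      tf.distribute.ReduceOp.MEAN, per_replica, destinations="/cpu:0")

mean_value = float(strategy.experimental_local_results(mean)[0])
print(mean_value)
```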
update(var, fn, args=(), kwargs=None, group=True)
Run fn to update var using inputs mirrored to the same devices.
tf.distribute.StrategyExtended.update takes a distributed variable var to be updated, an update function fn, and args and kwargs for fn. It applies fn to each component variable of var and passes the corresponding values from args and kwargs. Neither args nor kwargs may contain per-replica values. If they contain mirrored values, they will be unwrapped before calling fn. For example, fn can be assign_add and args can be a mirrored DistributedValues where each component contains the value to be added to this mirrored variable var. Calling update will call
assign_add on each componen