Additional APIs for algorithms that need to be distribution-aware.
tf.compat.v1.distribute.StrategyExtended( container_strategy )
Some common use cases of functions on this page:
- Locality
tf.distribute.DistributedValues can have the same locality as a
distributed variable, which leads to a mirrored value residing on the same
devices as the variable (as opposed to the compute devices). Such values may
be passed to a call to
tf.distribute.StrategyExtended.update to update the
value of a variable. You may use
tf.distribute.StrategyExtended.colocate_vars_with to give a variable the
same locality as another variable. You may convert a "PerReplica" value to a
variable's locality by using
- How to update a distributed variable
A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:
1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.
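The four-step update pattern above can be traced in a plain-Python model. This is a sketch under stated assumptions, not the TensorFlow API: a mirrored variable is represented by a dict mapping made-up device names to copies, and replica_fn, batch_reduce_sum and apply_update are hypothetical stand-ins for the per-replica function, batch_reduce_to with a SUM reduction, and update. The update rule (subtracting the summed value) is also an assumption chosen for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a mirrored variable: one copy per (made-up) device.
MirroredVar = Dict[str, float]
Pair = Tuple[float, MirroredVar]


def replica_fn(replica_id: int, variables: List[MirroredVar]) -> List[Pair]:
    """Step 1: each replica computes a list of (update, variable) pairs."""
    # Hypothetical per-replica "gradient"; real code would derive it from a loss.
    return [(0.5 * (replica_id + 1), var) for var in variables]


def batch_reduce_sum(per_replica: List[List[Pair]]) -> List[Pair]:
    """Step 3: like batch_reduce_to with SUM, reduce every variable's updates in one call."""
    num_vars = len(per_replica[0])
    return [
        (sum(pairs[i][0] for pairs in per_replica), per_replica[0][i][1])
        for i in range(num_vars)
    ]


def apply_update(var: MirroredVar, total: float) -> None:
    """Step 4: like update, subtract the same reduced value from every copy."""
    # Gradient-descent-style step with learning rate 1 (assumed for illustration).
    for device in var:
        var[device] -= total


def train_step(variables: List[MirroredVar], num_replicas: int) -> None:
    # Step 1 runs once per replica.
    per_replica = [replica_fn(r, variables) for r in range(num_replicas)]
    # Step 2 (merge_call analogue): from here on, code sees all replicas' pairs at once.
    for total, var in batch_reduce_sum(per_replica):
        apply_update(var, total)


w = {"/gpu:0": 10.0, "/gpu:1": 10.0}
b = {"/gpu:0": 1.0, "/gpu:1": 1.0}
train_step([w, b], num_replicas=2)
print(w, b)  # {'/gpu:0': 8.5, '/gpu:1': 8.5} {'/gpu:0': -0.5, '/gpu:1': -0.5}
```

Because every copy of a variable receives the same reduced value, the copies stay identical, which is the property a mirrored layout relies on.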
In fact, a higher-level way to update a distributed variable is to call
assign on the variable as you would on a regular variable.
You can call the method in both replica context and cross-replica context.
For a mirrored variable, calling
assign in replica context requires you
to specify the
aggregation type in the variable constructor. In that case,
the context switching and sync described in steps 2 through 4 are handled for
you. If you call
assign on a mirrored variable in cross-replica context,
you can only assign a single value or assign values from another mirrored
variable or a mirrored
tf.distribute.DistributedValues. For a SyncOnRead
variable, in replica context, you can simply call
assign on it and no
aggregation happens under the hood. In cross-replica context, you can only
assign a single value to a SyncOnRead variable. One example case is restoring
from a checkpoint: if the
aggregation type of the variable is
tf.VariableAggregation.SUM, it is assumed that replica values were added
before checkpointing, so at the time of restoring, the value is divided by
the number of replicas and then assigned to each replica; if the aggregation
type is tf.VariableAggregation.MEAN, the value is assigned to each replica
directly.
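The checkpoint-restore rule for SyncOnRead variables comes down to simple arithmetic. The sketch below is a plain-Python model, not how TensorFlow implements restoring: restore_sync_on_read is a hypothetical helper, and its Aggregation enum is a made-up stand-in for the two tf.VariableAggregation cases described above.

```python
from enum import Enum
from typing import List


class Aggregation(Enum):
    # Hypothetical mirror of the SUM and MEAN cases of tf.VariableAggregation.
    SUM = "sum"
    MEAN = "mean"


def restore_sync_on_read(saved_value: float, aggregation: Aggregation, num_replicas: int) -> List[float]:
    """Return the per-replica values after assigning a checkpointed value."""
    if aggregation is Aggregation.SUM:
        # The saved value is assumed to be a sum of replica values, so split it evenly.
        per_replica = saved_value / num_replicas
    elif aggregation is Aggregation.MEAN:
        # The saved value is already a mean, so every replica receives it unchanged.
        per_replica = saved_value
    else:
        raise ValueError(f"unsupported aggregation: {aggregation}")
    return [per_replica] * num_replicas


print(restore_sync_on_read(12.0, Aggregation.SUM, 4))   # [3.0, 3.0, 3.0, 3.0]
print(restore_sync_on_read(12.0, Aggregation.MEAN, 4))  # [12.0, 12.0, 12.0, 12.0]
```

Dividing under SUM is what makes a later read consistent: summing the four restored replica values gives back the checkpointed 12.0.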
||Whether the strategy uses between-graph replication or not. This is expected to return a constant value that will not be changed throughout its life cycle.|
||Whether initialization is needed.|
||Returns the tuple of all devices used to place variables.|
||Whether checkpointing is needed.|
||Whether saving summaries is needed.|
||Returns the tuple of all devices used for compute replica execution.|
batch_reduce_to( reduce_op, value_destination_pairs, experimental_hints=None )
Combine multiple reduce_to calls into one for faster execution.
Reduction type, an instance of
A sequence of (value, destinations) pairs. See
A list of mirrored values, one per pair in
broadcast_to( tensor, destinations )
Mirror a tensor on one device to all worker devices.
||A Tensor value to broadcast.|
A mirrored variable or device string specifying the
destination devices to copy the tensor to.
A value mirrored to the destination devices.
call_for_each_replica( fn, args=(), kwargs=None )
Run fn once per replica.
fn may call tf.distribute.get_replica_context() to access methods such as
replica_id_in_sync_group and merge_call().
merge_call() is used to communicate between the replicas and re-enter the
cross-replica context. All replicas pause their execution when they encounter
a merge_call() call. After that, merge_fn is executed once, and its results
are unwrapped and given back to each replica call. Execution then resumes
until fn completes or encounters another merge_call() call. Example:
import tensorflow as tf

# Assumed for illustration: any tf.distribute strategy instance.
distribution = tf.distribute.MirroredStrategy()

# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.distribute.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.run(fn, args=(3,))
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a tuple with one entry per local replica:
  print(distribution.experimental_local_results(merged_results))
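The pause-and-resume behavior of merge_call() can be modeled with a thread barrier. This is a hypothetical sketch, not TensorFlow's implementation: ToyReplicaGroup runs one thread per replica, each call to its merge_call blocks on a threading.Barrier, and the barrier action runs merge_fn once with every replica's value before all replicas resume with the same result. It mirrors the example above with three replicas.

```python
import threading
from typing import Any, Callable, List, Optional


class ToyReplicaGroup:
    """Hypothetical model of merge_call() semantics; not TensorFlow code."""

    def __init__(self, num_replicas: int) -> None:
        self.num_replicas = num_replicas
        self._values: List[Any] = [None] * num_replicas
        self._merge_fn: Optional[Callable[[List[Any]], Any]] = None
        self._result: Any = None
        # The barrier action runs once per round, after the last replica
        # arrives and before any replica returns from wait().
        self._barrier = threading.Barrier(num_replicas, action=self._run_merge)

    def _run_merge(self) -> None:
        # "Cross-replica" step: merge_fn sees one value from every replica.
        self._result = self._merge_fn(list(self._values))

    def merge_call(self, replica_id: int, merge_fn: Callable[[List[Any]], Any], value: Any) -> Any:
        self._values[replica_id] = value
        self._merge_fn = merge_fn
        self._barrier.wait()  # every replica pauses here until all have arrived
        return self._result   # the same merged result goes back to each replica

    def run(self, fn: Callable[..., Any], *args: Any) -> List[Any]:
        results: List[Any] = [None] * self.num_replicas

        def worker(replica_id: int) -> None:
            results[replica_id] = fn(self, replica_id, *args)

        threads = [threading.Thread(target=worker, args=(r,)) for r in range(self.num_replicas)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return results


def merge_fn(values: List[int]) -> int:
    # Runs once per merge_call round: sum the values across replicas.
    return sum(values)


def fn(group: ToyReplicaGroup, replica_id: int, three: int) -> int:
    v = three + replica_id
    s = group.merge_call(replica_id, merge_fn, v)
    return s + v


print(ToyReplicaGroup(3).run(fn, 3))  # [15, 16, 17]
```

With three replicas, v is 3, 4 and 5, merge_fn returns 12 to every replica, and each replica returns 12 + v.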
||Function to run (will be run once per replica).|
Tuple or list with positional arguments for
Dict with keyword arguments for
Merged return value of
colocate_vars_with( colocate_with_variable )
Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others by using a tf.compat.v1.colocate_with() scope).
This may only be used inside the strategy's scope. For example:
import tensorflow as tf

# Assumed for illustration: any tf.distribute strategy instance.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  var1 = tf.Variable(1.0)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(3.0)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    return v1.assign_add(v2 + v3)

  # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))
A variable created in this strategy's scope.
|A context manager.|
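Colocation matters most when a strategy would otherwise spread variables across devices. The plain-Python model below assumes a hypothetical ToyStrategy that places each new variable on one parameter device in round-robin order; inside its colocate_vars_with context manager, new variables reuse the target variable's devices instead. The classes and device names are made up for illustration and are not the TensorFlow API.

```python
import contextlib
from dataclasses import dataclass
from typing import Iterator, Optional, Tuple


@dataclass
class ToyVariable:
    value: float
    devices: Tuple[str, ...]


class ToyStrategy:
    """Hypothetical strategy: places each new variable on one parameter device, round-robin."""

    def __init__(self, parameter_devices: Tuple[str, ...]) -> None:
        self.parameter_devices = parameter_devices
        self._next = 0
        self._colocate_with: Optional[ToyVariable] = None

    def create_variable(self, value: float) -> ToyVariable:
        if self._colocate_with is not None:
            # Inside colocate_vars_with: reuse the target variable's devices.
            devices = self._colocate_with.devices
        else:
            # Default placement: next parameter device in round-robin order.
            devices = (self.parameter_devices[self._next % len(self.parameter_devices)],)
            self._next += 1
        return ToyVariable(value, devices)

    @contextlib.contextmanager
    def colocate_vars_with(self, colocate_with_variable: ToyVariable) -> Iterator[None]:
        previous = self._colocate_with
        self._colocate_with = colocate_with_variable
        try:
            yield
        finally:
            # Restore the outer setting so nested scopes behave correctly.
            self._colocate_with = previous


strategy = ToyStrategy(("/job:ps/task:0", "/job:ps/task:1"))
var1 = strategy.create_variable(1.0)
with strategy.colocate_vars_with(var1):
    var2 = strategy.create_variable(2.0)
    var3 = strategy.create_variable(3.0)
var4 = strategy.create_variable(4.0)
print(var1.devices, var2.devices, var3.devices, var4.devices)
# ('/job:ps/task:0',) ('/job:ps/task:0',) ('/job:ps/task:0',) ('/job:ps/task:1',)
```

Without the colocation scope, var2 would have landed on task 1 in this model, and an update that reads var2 while writing var1 would then cross devices.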
experimental_make_numpy_dataset( numpy_input, session=Non