tf.distribute.experimental.ValueContext
A class wrapping information needed by a distribute function.
tf.distribute.experimental.ValueContext(
replica_id_in_sync_group=0, num_replicas_in_sync=1
)
This is a context class that is passed to the value_fn in strategy.experimental_distribute_values_from_function and contains information about the compute replicas. The num_replicas_in_sync and replica_id_in_sync_group attributes can be used to customize the value on each replica.
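Both attributes are integers, so a value_fn can use them to give each replica a different slice of some shared data. The sketch below is illustrative and not part of the API. GLOBAL_ITEMS and shard_value_fn are hypothetical names. The contiguous, near-equal split, where earlier replicas take any remainder, is only one possible policy. The sketch builds one ValueContext per replica by hand to preview what each replica would receive.

```python
import tensorflow as tf

# Hypothetical global data to be split across replicas.
GLOBAL_ITEMS = list(range(10))

def shard_value_fn(context):
  """Returns this replica's contiguous shard of GLOBAL_ITEMS (illustrative)."""
  n = context.num_replicas_in_sync
  i = context.replica_id_in_sync_group
  # Earlier replicas get one extra item when the split is uneven.
  base, extra = divmod(len(GLOBAL_ITEMS), n)
  start = i * base + min(i, extra)
  stop = start + base + (1 if i < extra else 0)
  return GLOBAL_ITEMS[start:stop]

# Build one context per replica, as a strategy would pass them in.
shards = [
    shard_value_fn(tf.distribute.experimental.ValueContext(
        replica_id_in_sync_group=i, num_replicas_in_sync=3))
    for i in range(3)
]
print(shards)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

Concatenating the shards in replica order recovers GLOBAL_ITEMS. Each replica therefore sees a disjoint piece of the data.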
Example usage:
- Directly constructed.
import tensorflow as tf

def value_fn(context):
  # Fraction of the sync group that precedes this replica.
  return context.replica_id_in_sync_group / context.num_replicas_in_sync

context = tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=2, num_replicas_in_sync=4)
per_replica_value = value_fn(context)
print(per_replica_value)  # 0.5
- Passed in by experimental_distribute_values_from_function.
import tensorflow as tf

# Assumes two visible GPUs; each becomes one replica.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(value_context):
  # Every replica sees the same num_replicas_in_sync (2 here).
  return value_context.num_replicas_in_sync

distributed_values = (
    strategy.experimental_distribute_values_from_function(
        value_fn))
local_result = strategy.experimental_local_results(distributed_values)
print(local_result)  # (2, 2)
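The example above needs two physical GPUs. The sketch below is meant to run on a CPU-only machine. It assumes TensorFlow 2.x with tf.config.set_logical_device_configuration available. It splits the first physical CPU into two logical devices so that MirroredStrategy has two replicas. The value_fn here is illustrative: it returns replica_id_in_sync_group * 10 to show that, unlike num_replicas_in_sync, the replica ID differs on each replica.

```python
import tensorflow as tf

# Assumption: split the first physical CPU into two logical devices.
# This must run before TensorFlow initializes its devices.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(),
          tf.config.LogicalDeviceConfiguration()])

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

def value_fn(value_context):
  # The replica ID differs per replica, so each gets a distinct value.
  return value_context.replica_id_in_sync_group * 10

distributed_values = strategy.experimental_distribute_values_from_function(
    value_fn)
local_result = strategy.experimental_local_results(distributed_values)
print(local_result)  # (0, 10)
```

Because logical device configuration cannot change after initialization, place this at the very start of a program.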
Args:
  replica_id_in_sync_group: The current replica ID; should be an int in [0, num_replicas_in_sync).
  num_replicas_in_sync: The number of replicas that are in sync.
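The Args above document a constraint on replica_id_in_sync_group, but this page does not say the constructor enforces it. The sketch below uses make_value_context, a hypothetical helper that is not part of TensorFlow, to check the documented range before building a ValueContext. Requiring num_replicas_in_sync to be at least 1 is an added assumption: it keeps the interval [0, num_replicas_in_sync) non-empty.

```python
import tensorflow as tf

def make_value_context(replica_id_in_sync_group, num_replicas_in_sync):
  """Hypothetical helper: checks the documented Args, then builds a ValueContext."""
  # Assumption: at least one replica, so the ID range is non-empty.
  if not isinstance(num_replicas_in_sync, int) or num_replicas_in_sync < 1:
    raise ValueError("num_replicas_in_sync must be a positive int")
  # Documented range: an int in [0, num_replicas_in_sync).
  if not isinstance(replica_id_in_sync_group, int) or not (
      0 <= replica_id_in_sync_group < num_replicas_in_sync):
    raise ValueError(
        f"replica_id_in_sync_group must be an int in [0, {num_replicas_in_sync})")
  return tf.distribute.experimental.ValueContext(
      replica_id_in_sync_group=replica_id_in_sync_group,
      num_replicas_in_sync=num_replicas_in_sync)

ctx = make_value_context(1, 2)
print(ctx.replica_id_in_sync_group, ctx.num_replicas_in_sync)  # 1 2
```

A call such as make_value_context(2, 2) raises ValueError, because 2 is outside [0, 2).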
Attributes:
  num_replicas_in_sync: Returns the number of compute replicas in sync.
  replica_id_in_sync_group: Returns the replica ID.
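The sketch below reads the two attributes in two ways. First, a default-constructed context reports the single-replica defaults from the constructor signature (0 and 1). Second, a value_fn uses the replica ID to derive a distinct per-replica integer seed. BASE_SEED and seed_value_fn are hypothetical names, and offsetting by replica ID is just one simple way to make per-replica seeds differ.

```python
import tensorflow as tf

# With no arguments, the constructor defaults describe a single replica.
default_ctx = tf.distribute.experimental.ValueContext()
print(default_ctx.replica_id_in_sync_group)  # 0
print(default_ctx.num_replicas_in_sync)      # 1

BASE_SEED = 1234  # hypothetical base seed

def seed_value_fn(context):
  # Offset the seed by replica ID so each replica gets a different seed.
  return BASE_SEED + context.replica_id_in_sync_group

# Positional arguments follow the signature order:
# (replica_id_in_sync_group, num_replicas_in_sync).
seeds = [
    seed_value_fn(tf.distribute.experimental.ValueContext(i, 4))
    for i in range(4)
]
print(seeds)  # [1234, 1235, 1236, 1237]
```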
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2021-02-18 UTC.