A class wrapping information needed by a distribute function.
```python
tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=0, num_replicas_in_sync=1
)
```
This is a context class that is passed to the value_fn in
strategy.experimental_distribute_values_from_function. It contains
information about the compute replicas: the num_replicas_in_sync and
replica_id_in_sync_group attributes can be used to customize the value on each replica.
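As a sketch of such per-replica customization, the value_fn below gives each replica a strided shard of a small global list, using both attributes. StandInContext, GLOBAL_DATA, and shard_value_fn are hypothetical names introduced here; StandInContext only mirrors the two attributes of ValueContext so the sketch runs without TensorFlow. In real code the context object would be supplied by the strategy.

```python
from dataclasses import dataclass


# Hypothetical stand-in exposing the same two attributes as
# tf.distribute.experimental.ValueContext, so this sketch runs without TensorFlow.
@dataclass(frozen=True)
class StandInContext:
    replica_id_in_sync_group: int = 0
    num_replicas_in_sync: int = 1


GLOBAL_DATA = list(range(10))  # hypothetical global data to split across replicas


def shard_value_fn(context):
    """Return every n-th element of GLOBAL_DATA, starting at this replica's ID."""
    n = context.num_replicas_in_sync
    rid = context.replica_id_in_sync_group
    return GLOBAL_DATA[rid::n]


# One context per replica, as a strategy with two replicas would provide.
shards = [shard_value_fn(StandInContext(i, 2)) for i in range(2)]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Strided slicing is one simple choice: it needs no remainder handling, so every element lands on exactly one replica even when the data length is not a multiple of num_replicas_in_sync.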
Example usage:

Directly constructed:

```python
import tensorflow as tf

def value_fn(context):
  return context.replica_id_in_sync_group / context.num_replicas_in_sync

context = tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=2, num_replicas_in_sync=4)
per_replica_value = value_fn(context)
print(per_replica_value)  # 0.5
```

Passed in by experimental_distribute_values_from_function (this example assumes two visible GPUs, "GPU:0" and "GPU:1"):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(value_context):
  return value_context.num_replicas_in_sync

distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
local_result = strategy.experimental_local_results(distributed_values)
print(local_result)  # (2, 2)
```
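To make the calling pattern concrete without GPUs, the following is a rough pure-Python simulation of the idea: value_fn is called once per replica, each time with that replica's own context, and the results are collected in replica order. This is a hypothetical illustration, not TensorFlow's implementation; StandInContext and simulate_distribute_values are names introduced here.

```python
from dataclasses import dataclass


# Hypothetical stand-in mirroring ValueContext's two attributes.
@dataclass(frozen=True)
class StandInContext:
    replica_id_in_sync_group: int = 0
    num_replicas_in_sync: int = 1


def simulate_distribute_values(value_fn, num_replicas):
    """Hypothetical simulation: call value_fn once per replica with that
    replica's context and return the per-replica results as a tuple."""
    return tuple(
        value_fn(StandInContext(replica_id, num_replicas))
        for replica_id in range(num_replicas)
    )


def value_fn(value_context):
    return value_context.num_replicas_in_sync


print(simulate_distribute_values(value_fn, 2))  # (2, 2)
print(simulate_distribute_values(
    lambda c: c.replica_id_in_sync_group / c.num_replicas_in_sync, 4))
# (0.0, 0.25, 0.5, 0.75)
```

The first call mirrors the MirroredStrategy example above; the second shows that each replica sees a different replica_id_in_sync_group but the same num_replicas_in_sync.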
| Args | |
|---|---|
| replica_id_in_sync_group | The current replica ID; should be an int in [0, num_replicas_in_sync). |
| num_replicas_in_sync | The number of replicas that are in sync. |
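The range constraint on replica_id_in_sync_group can be checked before constructing a context by hand. The helper below is a hypothetical sketch of that precondition check; it is not assumed that ValueContext itself performs this validation.

```python
def check_replica_args(replica_id_in_sync_group, num_replicas_in_sync):
    """Hypothetical helper: raise if the arguments violate the documented
    precondition 0 <= replica_id_in_sync_group < num_replicas_in_sync."""
    for name, value in (("replica_id_in_sync_group", replica_id_in_sync_group),
                        ("num_replicas_in_sync", num_replicas_in_sync)):
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if num_replicas_in_sync < 1:
        raise ValueError(
            f"num_replicas_in_sync must be >= 1, got {num_replicas_in_sync}")
    if not 0 <= replica_id_in_sync_group < num_replicas_in_sync:
        raise ValueError(
            f"replica_id_in_sync_group must be in [0, {num_replicas_in_sync}), "
            f"got {replica_id_in_sync_group}")


check_replica_args(2, 4)    # OK: 2 is in [0, 4)
# check_replica_args(4, 4)  # would raise ValueError: 4 is not in [0, 4)
```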
| Attributes | |
|---|---|
| num_replicas_in_sync | Returns the number of compute replicas in sync. |
| replica_id_in_sync_group | Returns the replica ID. |