A class wrapping information needed by a distribute function.
tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=0, num_replicas_in_sync=1
)
This is a context class that is passed to the value_fn in strategy.experimental_distribute_values_from_function and contains information about the compute replicas. The num_replicas_in_sync and replica_id_in_sync_group attributes can be used to customize the value on each replica.
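To make the calling convention concrete, here is a pure-Python sketch, not TensorFlow's implementation. It assumes the strategy calls value_fn once per replica with a context carrying that replica's ID and the total replica count, and that the results are collected in replica order, similar to what strategy.experimental_local_results reports. ValueContextSketch and distribute_values_sketch are hypothetical names; the defaults mirror the constructor signature above.

```python
from typing import Any, Callable, NamedTuple, Tuple


class ValueContextSketch(NamedTuple):
    """Hypothetical stand-in for tf.distribute.experimental.ValueContext."""
    replica_id_in_sync_group: int = 0
    num_replicas_in_sync: int = 1


def distribute_values_sketch(
    value_fn: Callable[[ValueContextSketch], Any], num_replicas: int
) -> Tuple[Any, ...]:
    """Calls value_fn once per replica and returns the results in replica order.

    Hypothetical helper: it only mimics the calling convention and does not
    place any values on devices.
    """
    return tuple(
        value_fn(ValueContextSketch(replica_id, num_replicas))
        for replica_id in range(num_replicas)
    )


def value_fn(context: ValueContextSketch) -> float:
    # Each replica computes its own fraction of the replica group.
    return context.replica_id_in_sync_group / context.num_replicas_in_sync


print(distribute_values_sketch(value_fn, num_replicas=4))
# (0.0, 0.25, 0.5, 0.75)
```

Because the same value_fn receives a different replica_id_in_sync_group on each call, one function can produce a distinct value per replica.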
Example usage:
Directly constructed:
import tensorflow as tf

def value_fn(context):
  return context.replica_id_in_sync_group / context.num_replicas_in_sync

context = tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=2, num_replicas_in_sync=4)
per_replica_value = value_fn(context)
print(per_replica_value)  # 0.5
Passed in by experimental_distribute_values_from_function:
# Assumes two visible GPUs, "GPU:0" and "GPU:1".
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(value_context):
  return value_context.num_replicas_in_sync

distributed_values = (
    strategy.experimental_distribute_values_from_function(
        value_fn))
local_result = strategy.experimental_local_results(distributed_values)
print(local_result)  # (2, 2)
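A common way to customize the value per replica is to give each replica its own contiguous slice of a global batch. The sketch below assumes the global batch size divides evenly by num_replicas_in_sync; shard_for_replica is a hypothetical helper, and types.SimpleNamespace stands in for a ValueContext. Since the helper reads only the two documented attributes, it could in principle be called from inside a value_fn with a real context.

```python
from types import SimpleNamespace
from typing import List, Sequence


def shard_for_replica(context, global_batch: Sequence[int]) -> List[int]:
    """Returns this replica's contiguous slice of global_batch.

    Hypothetical helper: `context` is any object exposing
    replica_id_in_sync_group and num_replicas_in_sync.
    """
    num_replicas = context.num_replicas_in_sync
    if len(global_batch) % num_replicas != 0:
        raise ValueError("global batch size must divide evenly across replicas")
    per_replica = len(global_batch) // num_replicas
    start = context.replica_id_in_sync_group * per_replica
    return list(global_batch[start:start + per_replica])


global_batch = list(range(8))
for replica_id in range(4):
    # SimpleNamespace stands in for a ValueContext in this sketch.
    ctx = SimpleNamespace(replica_id_in_sync_group=replica_id,
                          num_replicas_in_sync=4)
    print(replica_id, shard_for_replica(ctx, global_batch))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```

Raising on an uneven split is a design choice for the sketch; a real pipeline might instead pad the batch or let the last replica take the remainder.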