Base class for representing distributed values.
A subclass instance of tf.distribute.DistributedValues is created when you
create variables within a distribution strategy, iterate over a
tf.distribute.DistributedDataset, or call tf.distribute.Strategy.run.
This base class should never be instantiated directly.
tf.distribute.DistributedValues contains one value per replica. Depending on
the subclass, those values may be synced on update, synced on demand, or
never synced.
Two representative types of tf.distribute.DistributedValues are
tf.types.experimental.PerReplica and tf.types.experimental.Mirrored
values.
PerReplica values exist on the worker devices, with a different value for
each replica. They are produced by iterating through a distributed dataset
returned by tf.distribute.Strategy.experimental_distribute_dataset (Example
1, below) or tf.distribute.Strategy.distribute_datasets_from_function
(sketched after this paragraph). They are also the typical result returned by
tf.distribute.Strategy.run (Example 2).
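The distribute_datasets_from_function path is not covered by the numbered
examples below, so here is a minimal sketch, assuming the same two-GPU setup
used in those examples. The body of dataset_fn (per-replica batch size plus
sharding by input pipeline) is one illustrative pattern, not the only valid
one:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def dataset_fn(input_context):
  # Each input pipeline builds batches of the per-replica size derived
  # from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=2)
  dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(batch_size)
  # Shard the data across input pipelines (a no-op on a single worker).
  return dataset.shard(input_context.num_input_pipelines,
                       input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
distributed_values = next(iter(dist_dataset))  # a PerReplica value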
Mirrored values are like PerReplica values, except that the value is known to
be the same on every replica. Mirrored values are kept synchronized by the
distribution strategy in use, while PerReplica values are left
unsynchronized. Mirrored values typically represent model weights. We can
safely read a Mirrored value in a cross-replica context by using the value
on any replica, while PerReplica values should not be read or manipulated in
a cross-replica context.
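None of the numbered examples below produces a Mirrored value, so here is a
minimal sketch, again assuming two GPUs: a variable created under
strategy.scope() is mirrored across replicas, kept in sync by the strategy,
and safe to read in a cross-replica context.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  # Variables created inside the scope are mirrored: one synchronized
  # copy per replica, typically used for model weights.
  weights = tf.Variable(1.0)

# Reading in a cross-replica context is safe, because every replica
# holds the same value.
print(weights.read_value())  # tf.Tensor(1.0, shape=(), dtype=float32)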
tf.distribute.DistributedValues can be reduced via strategy.reduce to
obtain a single value across replicas (Example 4), used as input into
tf.distribute.Strategy.run (Example 3), or collected to inspect the
per-replica values using tf.distribute.Strategy.experimental_local_results
(Example 5).
Example usages:
- Created from a tf.distribute.DistributedDataset:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> distributed_values
PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
  1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>
}
- Returned by run:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def run():
...   ctx = tf.distribute.get_replica_context()
...   return ctx.replica_id_in_sync_group
>>> distributed_values = strategy.run(run)
>>> distributed_values
PerReplica:{
  0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
  1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
- As input into run:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> @tf.function
... def run(input):
...   return input + 1.0
>>> updated_value = strategy.run(run, args=(distributed_values,))
>>> updated_value
PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>,
  1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([7.], dtype=float32)>
}
- As input into reduce:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
...                                 distributed_values,
...                                 axis=0)
>>> reduced_value
<tf.Tensor: shape=(), dtype=float32, numpy=11.0>
- How to inspect per-replica values locally:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> per_replica_values = strategy.experimental_local_results(distributed_values)
>>> per_replica_values
(<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)