Represents a dataset distributed among devices and machines.
A tf.distribute.DistributedDataset could be thought of as a "distributed"
dataset. When you use the tf.distribute API to scale training to multiple
devices or machines, you also need to distribute the input data, which leads
to a tf.distribute.DistributedDataset instance, instead of a
tf.data.Dataset instance in the non-distributed case. In TF 2.x,
tf.distribute.DistributedDataset objects are Python iterables.
There are two APIs to create a tf.distribute.DistributedDataset object:
tf.distribute.Strategy.experimental_distribute_dataset(dataset) and
tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn).
When to use which? When you have a tf.data.Dataset instance, and the
regular batch splitting (i.e. re-batch the input tf.data.Dataset instance
with a new batch size that is equal to the global batch size divided by the
number of replicas in sync) and autosharding (i.e. the
tf.data.experimental.AutoShardPolicy options) work for you, use the former
API. Otherwise, if you are not using a canonical tf.data.Dataset instance,
or you would like to customize the batch splitting or sharding, you can wrap
this logic in a dataset_fn and use the latter API. Both APIs handle
prefetching to the device for the user. For more details and examples, follow the
links to the APIs.
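To make the distinction concrete, here is a minimal sketch (not from the original example set; it assumes two local GPUs and an illustrative integer dataset) that builds an input pipeline with both APIs. In the second variant, the dataset_fn receives a tf.distribute.InputContext, which lets you pick the per-replica batch size and shard the data yourself:

```
import tensorflow as tf

global_batch_size = 8
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

# Option 1: hand tf.distribute a regular tf.data.Dataset and let it
# split the global batch and autoshard across workers.
dataset = tf.data.Dataset.range(64).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Option 2: build the per-replica pipeline yourself inside a dataset_fn.
def dataset_fn(input_context):
  # Derive the per-replica batch size from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.range(64)
  # Shard manually instead of relying on autosharding.
  d = d.shard(input_context.num_input_pipelines,
              input_context.input_pipeline_id)
  return d.batch(batch_size)

dist_dataset_from_fn = strategy.distribute_datasets_from_function(dataset_fn)
```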
There are two main usages of a DistributedDataset object:
- Iterate over it to generate the input for a single device or multiple devices, which is a tf.distribute.DistributedValues instance. To do this, you can:

  - use a Pythonic for-loop construct:
```
global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(4).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(input):
  features, labels = input
  return labels - 0.3 * features

for x in dist_dataset:
  # train_step trains the model using the dataset elements
  loss = strategy.run(train_step, args=(x,))
  print("Loss is", loss)
```

```
Loss is PerReplica:{
  0: tf.Tensor(
[[0.7]
 [0.7]], shape=(2, 1), dtype=float32),
  1: tf.Tensor(
[[0.7]
 [0.7]], shape=(2, 1), dtype=float32)
}
```

Placing the loop inside a tf.function gives a performance boost. However, break and return are currently not supported if the loop is placed inside a tf.function. We also don't support placing the loop inside a tf.function when using tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.experimental.TPUStrategy with multiple workers.

  - use __iter__ to create an explicit iterator, which is of type tf.distribute.DistributedIterator:
```
global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
train_dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(50).batch(global_batch_size)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

@tf.function
def distributed_train_step(dataset_inputs):
  def train_step(input):
    loss = tf.constant(0.1)
    return loss
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

EPOCHS = 2
STEPS = 3
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  dist_dataset_iterator = iter(train_dist_dataset)
  for _ in range(STEPS):
    total_loss += distributed_train_step(next(dist_dataset_iterator))
    num_batches += 1
  average_train_loss = total_loss / num_batches
  template = ("Epoch {}, Loss: {:.4f}")
  print(template.format(epoch + 1, average_train_loss))
```

```
Epoch 1, Loss: 0.2000
Epoch 2, Loss: 0.2000
```

To achieve a performance improvement, you can also wrap the strategy.run call with a tf.range inside a tf.function. This runs multiple steps in a tf.function; AutoGraph will convert it to a tf.while_loop on the worker. However, it is less flexible compared with running a single step inside tf.function. For example, you cannot run things eagerly or arbitrary Python code within the steps. (A minimal sketch of this pattern appears after this list.)
- Inspect the tf.TypeSpec of the data generated by DistributedDataset.

tf.distribute.DistributedDataset generates tf.distribute.DistributedValues as input to the devices. If you pass the input to a tf.function and would like to specify the shape and type of each Tensor argument to the function, you can pass a tf.TypeSpec object to the input_signature argument of the tf.function. To get the tf.TypeSpec of the input, you can use the element_spec property of the tf.distribute.DistributedDataset or tf.distribute.DistributedIterator object.

For example:

```
global_batch_size = 4
epochs = 1
steps_per_epoch = 1
mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(per_replica_inputs):
  def step_fn(inputs):
    return tf.square(inputs)
  return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))

for _ in range(epochs):
  iterator = iter(dist_dataset)
  for _ in range(steps_per_epoch):
    output = train_step(next(iterator))
    print(output)
```

```
PerReplica:{
  0: tf.Tensor(
[[4.]
 [4.]], shape=(2, 1), dtype=float32),
  1: tf.Tensor(
[[4.]
 [4.]], shape=(2, 1), dtype=float32)
}
```
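As referenced in the iterator example above, here is a minimal sketch of running multiple steps per tf.function call by looping over tf.range. It assumes the strategy and train_dist_dataset objects from the earlier example and uses a trivial, placeholder step_fn; treat it as an illustration of the pattern rather than a complete training loop:

```
def step_fn(inputs):
  # Placeholder per-replica computation; a real step would compute a loss
  # from the inputs and apply gradients.
  return tf.constant(0.1)

@tf.function
def train_multiple_steps(iterator, steps):
  total_loss = tf.constant(0.0)
  # AutoGraph converts this tf.range loop into a tf.while_loop on the worker.
  for _ in tf.range(steps):
    per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
    total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
  return total_loss

iterator = iter(train_dist_dataset)
print(train_multiple_steps(iterator, tf.constant(3)))
```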
Visit the tutorial on distributed input for more examples and caveats.
| Attributes | |
|---|---|
| element_spec | The type specification of an element of this tf.distribute.DistributedDataset. |
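For example, a minimal sketch (assuming two local GPUs and an illustrative dataset) of inspecting this property:

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors([1.0, 2.0]).repeat(8).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# element_spec is a (possibly nested) structure of tf.TypeSpec objects
# describing one per-replica element; it can be passed to the
# input_signature argument of a tf.function, as shown above.
print(dist_dataset.element_spec)
```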
Methods
__iter__
__iter__()
Creates an iterator for the tf.distribute.DistributedDataset.
The returned iterator implements the Python Iterator protocol.
Example usage:
```
global_batch_size = 4
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
print(next(distributed_iterator))
```

```
PerReplica:{
  0: tf.Tensor([1 2], shape=(2,), dtype=int32),
  1: tf.Tensor([3 4], shape=(2,), dtype=int32)
}
```
| Returns | |
|---|---|
| A tf.distribute.DistributedIterator instance for the given tf.distribute.DistributedDataset object to enumerate over the distributed data. |
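Beyond next(), the returned tf.distribute.DistributedIterator also exposes get_next() and an element_spec property; a minimal sketch (assuming two local GPUs):

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(4)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

# get_next() also retrieves the next per-replica batch.
batch = distributed_iterator.get_next()
# element_spec describes the type and shape of the values the iterator yields.
print(distributed_iterator.element_spec)
```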