An iterator over tf.distribute.DistributedDataset.
tf.distribute.DistributedIterator is the primary mechanism for enumerating
elements of a tf.distribute.DistributedDataset. It supports the Python
Iterator protocol, which means it can be iterated over using a for-loop or by
fetching individual elements explicitly via get_next().
You can create a tf.distribute.DistributedIterator by calling iter on
a tf.distribute.DistributedDataset, or implicitly by writing a Python
for-loop over a tf.distribute.DistributedDataset.
Visit the tutorial on distributed input for more examples and caveats.
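For example, a minimal sketch of both creation patterns (the dataset and
variable names here are illustrative, assuming a default MirroredStrategy on
the local machine):
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Explicit creation: iter() returns a tf.distribute.DistributedIterator.
iterator = iter(dist_dataset)
first_batch = iterator.get_next()

# Implicit creation: a Python for-loop over the DistributedDataset builds an
# iterator under the hood and yields one (per-replica) batch per step.
for batch in dist_dataset:
  strategy.run(lambda x: x, args=(batch,))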
| Attributes | |
|---|---|
| element_spec | The type specification of an element of tf.distribute.DistributedIterator. |
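element_spec is typically passed as the input_signature of a tf.function so
the function is not retraced on every step. A minimal sketch, assuming a
default MirroredStrategy; process_inputs is an illustrative name:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(2)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

# element_spec describes one (possibly per-replica) element of the iterator;
# using it as the input_signature avoids retracing for each new batch.
@tf.function(input_signature=[distributed_iterator.element_spec])
def process_inputs(inputs):
  return strategy.run(lambda x: x * 2, args=(inputs,))

print(process_inputs(distributed_iterator.get_next()))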
Methods
get_next
get_next()
Returns the next input from the iterator for all replicas.
Example use:
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)

@tf.function
def one_step(input):
  return input

step_num = 5
for _ in range(step_num):
  strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))

strategy.experimental_local_results(dist_dataset_iterator.get_next())
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)
The above example corresponds to the case where you have only one device. If
you have two devices, for example:
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
then the final line will print out:
(<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
| Returns |
|---|
| A single tf.Tensor or a tf.distribute.DistributedValues which contains the next input for all replicas. |
| Raises | |
|---|---|
| tf.errors.OutOfRangeError | If the end of the iterator has been reached. |
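Because get_next raises tf.errors.OutOfRangeError once the data is exhausted,
a manual loop typically guards the call. A minimal sketch, assuming eager
execution (get_next_as_optional below is the graph-friendly alternative):
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(4).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

# Drive the iterator by hand; get_next() raises tf.errors.OutOfRangeError
# once every batch has been consumed.
while True:
  try:
    batch = iterator.get_next()
  except tf.errors.OutOfRangeError:
    break
  strategy.run(lambda x: x, args=(batch,))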
get_next_as_optional
get_next_as_optional()
Returns a tf.experimental.Optional that contains the next value for all replicas.
If the tf.distribute.DistributedIterator has reached the end of the
sequence, the returned tf.experimental.Optional will have no value.
Example usage:
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 2
steps_per_loop = 2
dataset = tf.data.Dataset.range(10).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

def step_fn(x):
  return x

@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))

train_fn(distributed_iterator)
# ([0 1],)
# ([2 3],)
| Returns |
|---|
| A tf.experimental.Optional object representing the next value from the tf.distribute.DistributedIterator (if it has one) or no value. |
__iter__
__iter__()
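Because tf.distribute.DistributedIterator supports the Python Iterator
protocol, __iter__ lets the iterator be consumed directly in a for-loop. A
minimal sketch, assuming eager execution and a default MirroredStrategy:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(6).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

# The for-loop calls iterator.__iter__() (which returns the iterator itself)
# and then __next__() once per batch until the data runs out.
for batch in iterator:
  print(strategy.run(lambda x: x, args=(batch,)))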