tf.distribute.DistributedIterator
An iterator over tf.distribute.DistributedDataset.

tf.distribute.DistributedIterator is the primary mechanism for enumerating elements of a tf.distribute.DistributedDataset. It supports the Python Iterator protocol, which means it can be iterated over using a for-loop or by fetching individual elements explicitly via get_next().

You can create a tf.distribute.DistributedIterator by calling iter on a tf.distribute.DistributedDataset or by creating a Python for-loop over a tf.distribute.DistributedDataset.

Visit the tutorial on distributed input (https://www.tensorflow.org/tutorials/distribute/input) for more examples and caveats.
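As a quick illustration of both styles, here is a minimal sketch, assuming a default (single-device) MirroredStrategy and a toy dataset:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Style 1: explicit iterator, elements fetched via get_next().
iterator = iter(dist_dataset)
first_batch = iterator.get_next()

# Style 2: a Python for-loop over the distributed dataset.
for batch in dist_dataset:
  pass  # each `batch` holds the per-replica values for one step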
Attributes

element_spec
The type specification of an element of tf.distribute.DistributedIterator.

global_batch_size = 16
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_iterator.element_spec
(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))

The above example corresponds to the case where you have only one device. If you have two devices, for example:

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])

then the final line will print out:

(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
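A common use of element_spec is to supply an input_signature for a tf.function that consumes the iterator's elements, so the function is traced only once. Below is a minimal sketch, reusing the single-device setup above; distributed_step and replica_fn are illustrative names, not part of the API:

distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function(input_signature=[distributed_iterator.element_spec])
def distributed_step(inputs):
  def replica_fn(batch):
    # `batch` matches the per-replica element: a (features, labels) pair.
    features, labels = batch
    return features
  return strategy.run(replica_fn, args=(inputs,))

distributed_step(distributed_iterator.get_next())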
Methods
get_next
get_next()
Returns the next input from the iterator for all replicas.
Example use:
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)

@tf.function
def one_step(input):
  return input

step_num = 5
for _ in range(step_num):
  strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
strategy.experimental_local_results(dist_dataset_iterator.get_next())
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)
The above example corresponds to the case where you have only one device. If you have two devices, for example:

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])

then the final line will print out:
(<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
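Returns
A single tf.Tensor or a tf.distribute.DistributedValues which contains the next input for all replicas.

Raises
tf.errors.OutOfRangeError: If the end of the iterator has been reached.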
get_next_as_optional
get_next_as_optional()
Returns a tf.experimental.Optional that contains the next value for all replicas.

If the tf.distribute.DistributedIterator has reached the end of the sequence, the returned tf.experimental.Optional will have no value.

Example usage:
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 2
steps_per_loop = 2
dataset = tf.data.Dataset.range(10).batch(global_batch_size)
distributed_iterator = iter(
    strategy.experimental_distribute_dataset(dataset))

def step_fn(x):
  return x

@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))

train_fn(distributed_iterator)
# ([0 1],)
# ([2 3],)
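Returns
A tf.experimental.Optional object representing the next value from the tf.distribute.DistributedIterator (if it has one) or no value.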
__iter__
__iter__()
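This is the hook for the Python iterator protocol mentioned above: it lets the DistributedIterator itself appear in a for-loop, which stops when the underlying data is exhausted. A minimal sketch, assuming a default single-device MirroredStrategy:

strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(4).batch(2))

# iter() produces a DistributedIterator; looping over it calls
# __iter__ and then fetches one element per step until the end.
distributed_iterator = iter(dist_dataset)
for batch in distributed_iterator:
  tf.print(batch)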