Creates a new, uninitialized Iterator based on the given handle.
This method allows you to define a "feedable" iterator, where you can choose
between concrete iterators by feeding a value in a tf.Session.run call.
In that case, string_handle would be a tf.compat.v1.placeholder, and you
would feed it with the value of tf.data.Iterator.string_handle in each step.
For example, if you had two iterators that marked the current position in
a training dataset and a test dataset, you could choose which one to use in
each step by feeding the corresponding iterator's string handle.
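A minimal sketch of this feedable-iterator pattern, assuming TensorFlow 2.x with graph-mode code built through tf.compat.v1 inside an explicit tf.Graph. The two Dataset.range pipelines are illustrative stand-ins for real training and test data, and all variable names are hypothetical.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative stand-ins for real training and test pipelines.
    train_ds = tf.data.Dataset.range(100)
    test_ds = tf.data.Dataset.range(100, 110)
    train_iterator = tf.compat.v1.data.make_one_shot_iterator(train_ds)
    test_iterator = tf.compat.v1.data.make_one_shot_iterator(test_ds)
    train_handle_t = train_iterator.string_handle()
    test_handle_t = test_iterator.string_handle()

    # The placeholder receives a serialized iterator handle at run time.
    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle,
        tf.compat.v1.data.get_output_types(train_ds),
        tf.compat.v1.data.get_output_shapes(train_ds))
    next_element = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    # Evaluate each concrete iterator's handle once, then feed it per step.
    train_handle = sess.run(train_handle_t)
    test_handle = sess.run(test_handle_t)
    first_train = sess.run(next_element, feed_dict={handle: train_handle})
    first_test = sess.run(next_element, feed_dict={handle: test_handle})
    second_train = sess.run(next_element, feed_dict={handle: train_handle})
```

Because each concrete iterator keeps its own position, the third run resumes the training stream at 1 instead of restarting it after the test step.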
Args
string_handle
A scalar tf.Tensor of type tf.string that evaluates to
a handle produced by the Iterator.string_handle() method.
output_types
A (nested) structure of tf.DType objects corresponding to
each component of an element of this dataset.
output_shapes
(Optional.) A (nested) structure of tf.TensorShape
objects corresponding to each component of an element of this dataset.
If omitted, each component will have an unconstrained shape.
output_classes
(Optional.) A (nested) structure of Python type objects
corresponding to each component of an element of this iterator. If
omitted, each component is assumed to be of type tf.Tensor.
Creates a new, uninitialized Iterator with the given structure.
This iterator-constructing method can be used to create an iterator that
is reusable with many different datasets.
The returned iterator is not bound to a particular dataset, and it has
no initializer. To initialize the iterator, run the operation returned by
Iterator.make_initializer(dataset).
For example:
```python
iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

# tf.equal works whether or not Tensor `==` is overloaded (it is not under
# TF1 behavior, where `x % 2 == 0` would compare object identity).
dataset_evens = dataset_range.filter(lambda x: tf.equal(x % 2, 0))
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
        try:
            pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
            break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
        try:
            pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
            break
```
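A self-contained variant of the example above, assuming TensorFlow 2.x with an explicit tf.Graph and tf.compat.v1.Session. Instead of the undefined model_fn and num_epochs, it reads each dataset to exhaustion through the hypothetical helper drain, which shows one from_structure iterator being re-pointed at different datasets by running their initializers.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # One reinitializable iterator whose elements are scalar int64 tensors.
    iterator = tf.compat.v1.data.Iterator.from_structure(
        tf.int64, tf.TensorShape([]))
    dataset_range = tf.data.Dataset.range(10)
    dataset_evens = dataset_range.filter(lambda x: tf.equal(x % 2, 0))
    range_initializer = iterator.make_initializer(dataset_range)
    evens_initializer = iterator.make_initializer(dataset_evens)
    next_element = iterator.get_next()


def drain(sess, initializer):
    """Hypothetical helper: bind the iterator to a dataset, then read it to the end."""
    sess.run(initializer)
    values = []
    while True:
        try:
            values.append(int(sess.run(next_element)))
        except tf.errors.OutOfRangeError:
            return values


with tf.compat.v1.Session(graph=graph) as sess:
    range_values = drain(sess, range_initializer)
    evens_values = drain(sess, evens_initializer)
    # Re-running an initializer restarts iteration from the beginning.
    range_again = drain(sess, range_initializer)
```

The same get_next() tensor is reused throughout; only the initializer that runs decides which dataset it reads from.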
Args
output_types
A (nested) structure of tf.DType objects corresponding to
each component of an element of this dataset.
output_shapes
(Optional.) A (nested) structure of tf.TensorShape
objects corresponding to each component of an element of this dataset.
If omitted, each component will have an unconstrained shape.
shared_name
(Optional.) If non-empty, this iterator will be shared under
the given name across multiple sessions that share the same devices
(e.g. when using a remote server).
output_classes
(Optional.) A (nested) structure of Python type objects
corresponding to each component of an element of this iterator. If
omitted, each component is assumed to be of type tf.Tensor.
Returns
An Iterator.
Raises
TypeError
If the structures of output_shapes and output_types are
not the same.
In graph mode, you should typically call Iterator.get_next() once and use
its result as the input to another computation. A typical loop then calls
tf.Session.run on the result of that computation, and terminates when the
Iterator.get_next() operation raises tf.errors.OutOfRangeError. The
following skeleton shows how to use this method when building a training
loop:
```python
dataset = ...  # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)

with tf.compat.v1.Session() as sess:
    # Initialize model variables and the iterator before the first step.
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(iterator.initializer)
    try:
        while True:
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        pass
```
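A runnable sketch of the same loop, assuming TensorFlow 2.x with an explicit tf.Graph. The ten-element dataset, the single weight w, and the learning rate are illustrative choices standing in for model_function and the optimizer; note that both the variable initializer and iterator.initializer run before the first training step.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative data: ten (x, 2x) pairs in batches of four (4 + 4 + 2).
    xs = tf.range(10, dtype=tf.float32)
    dataset = tf.data.Dataset.from_tensor_slices((xs, 2.0 * xs)).batch(4)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    x, y = iterator.get_next()

    # A one-weight linear model standing in for model_function.
    w = tf.Variable(0.0, name="w")
    loss = tf.reduce_mean(tf.square(w * x - y))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(loss)
    init_vars = tf.compat.v1.global_variables_initializer()

steps = 0
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_vars)
    sess.run(iterator.initializer)  # required before the first get_next()
    try:
        while True:
            sess.run(train_op)
            steps += 1
    except tf.errors.OutOfRangeError:
        pass  # the dataset is exhausted after one pass
    final_w = float(sess.run(w))
```

The loop runs exactly one training step per batch and exits cleanly when get_next() signals the end of the data.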