Load images with tf.data

This tutorial provides a simple example of how to load an image dataset using tf.data.

The dataset used in this example is distributed as directories of images, with one class of image per directory.

Setup

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
tf.enable_eager_execution()
tf.__version__
'1.14.0'
AUTOTUNE = tf.data.experimental.AUTOTUNE

Download and inspect the dataset

Retrieve the images

Before you start any training, you'll need a set of images to teach the network about the new classes you want to recognize. This tutorial uses an archive of Creative Commons licensed flower photos.

import pathlib
data_root_orig = tf.keras.utils.get_file('flower_photos',
                                         'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 4s 0us/step
/home/kbuilder/.keras/datasets/flower_photos

After downloading 218MB, you should now have a copy of the flower photos available:

for item in data_root.iterdir():
  print(item)
/home/kbuilder/.keras/datasets/flower_photos/LICENSE.txt
/home/kbuilder/.keras/datasets/flower_photos/daisy
/home/kbuilder/.keras/datasets/flower_photos/roses
/home/kbuilder/.keras/datasets/flower_photos/sunflowers
/home/kbuilder/.keras/datasets/flower_photos/dandelion
/home/kbuilder/.keras/datasets/flower_photos/tulips
import random
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
image_count
3670
all_image_paths[:10]
['/home/kbuilder/.keras/datasets/flower_photos/tulips/3485767306_6db7bdf536.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/2521827947_9d237779bb_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/daisy/3661613900_b15ca1d35d_m.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/8695372372_302135aeb2.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/5020805135_1219d7523d.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/2619000556_6634478e64_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/8673416166_620fc18e2f_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/daisy/695778683_890c46ebac.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/100930342_92e8746431_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/18215579866_94b1732f24.jpg']

Inspect the images

Now let's have a quick look at a couple of the images, so you know what you're dealing with:

import os
attributions = (data_root/"LICENSE.txt").open(encoding='utf-8').readlines()[4:]
attributions = [line.split(' CC-BY') for line in attributions]
attributions = dict(attributions)
import IPython.display as display

def caption_image(image_path):
    image_rel = pathlib.Path(image_path).relative_to(data_root)
    return "Image (CC BY 2.0) " + ' - '.join(attributions[str(image_rel)].split(' - ')[:-1])
for n in range(3):
  image_path = random.choice(all_image_paths)
  display.display(display.Image(image_path))
  print(caption_image(image_path))
  print()

Image (CC BY 2.0)  by Image Catalog

Image (CC BY 2.0)  by Umberto Rotundo

Image (CC BY 2.0)  by John Tann
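The attribution lookup above depends on the layout of LICENSE.txt. A minimal sketch of the same parsing steps, assuming each line has the form `<relative path> CC-BY by <author> - <url>`. The sample lines, authors, and URLs below are hypothetical and are not taken from the dataset:

```python
# Hypothetical sample lines in the assumed LICENSE.txt format.
sample_lines = [
    "roses/example_1.jpg CC-BY by Jane Doe - https://example.com/photo/1\n",
    "daisy/example_2.jpg CC-BY by A. N. Other - https://example.com/photo/2\n",
]

# Splitting on ' CC-BY' yields [relative_path, ' by <author> - <url>\n'] pairs,
# which dict() turns into a path -> attribution lookup table.
sample_attributions = dict(line.split(' CC-BY') for line in sample_lines)


def sample_caption(image_rel):
    """Build a caption from everything before the final ' - <url>' part."""
    author_part = ' - '.join(sample_attributions[image_rel].split(' - ')[:-1])
    return "Image (CC BY 2.0) " + author_part


print(sample_caption("roses/example_1.jpg"))
```

Dropping the last ' - ' separated field removes the URL while leaving intact any author name that itself contains ' - '.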

Determine the label for each image

List the available labels:

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
label_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

Assign an index to each label:

label_to_index = dict((name, index) for index,name in enumerate(label_names))
label_to_index
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

Create a list of every file and its label index:

all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

print("First 10 label indices: ", all_image_labels[:10])
First 10 label indices:  [4, 1, 0, 4, 3, 3, 4, 0, 4, 1]
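The directory-to-label mapping can be tried without downloading the dataset. A minimal sketch, assuming a hypothetical temporary tree with two class folders; the helper names `build_label_index` and `label_paths` are illustrative and not part of the tutorial:

```python
import pathlib
import tempfile


def build_label_index(root):
    """Map each class sub-directory name to an integer index (sorted by name)."""
    names = sorted(p.name for p in root.iterdir() if p.is_dir())
    return {name: index for index, name in enumerate(names)}


def label_paths(root, label_to_index):
    """Return sorted (path, label_index) pairs for files one level below root."""
    pairs = []
    for path in sorted(root.glob('*/*')):
        if path.is_file():
            pairs.append((str(path), label_to_index[path.parent.name]))
    return pairs


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    # Hypothetical layout: one folder per class, plus a stray top-level file.
    for cls, files in {'roses': ['a.jpg'], 'daisy': ['b.jpg', 'c.jpg']}.items():
        (root / cls).mkdir()
        for name in files:
            (root / cls / name).touch()
    (root / 'LICENSE.txt').touch()

    demo_index = build_label_index(root)
    demo_pairs = label_paths(root, demo_index)
    demo_labels = [label for _, label in demo_pairs]

print(demo_index)
print(demo_labels)
```

The top-level LICENSE.txt is ignored because only directories become labels and only files one level down are listed.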

Load and format the images

TensorFlow includes all the tools you need to load and process images:

img_path = all_image_paths[0]
img_path
'/home/kbuilder/.keras/datasets/flower_photos/tulips/3485767306_6db7bdf536.jpg'

Here is the raw data:

img_raw = tf.io.read_file(img_path)
print(repr(img_raw)[:100]+"...")
<tf.Tensor: id=1, shape=(), dtype=string, numpy=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\...

Decode it into an image tensor:

img_tensor = tf.image.decode_image(img_raw)

print(img_tensor.shape)
print(img_tensor.dtype)
(375, 500, 3)
<dtype: 'uint8'>

Resize it for your model:

img_final = tf.image.resize(img_tensor, [192, 192])
img_final = img_final/255.0
print(img_final.shape)
print(img_final.numpy().min())
print(img_final.numpy().max())
(192, 192, 3)
0.0
1.0

Wrap these steps up in simple functions for later:

def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [192, 192])
  image /= 255.0  # normalize to [0,1] range

  return image
def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)
import matplotlib.pyplot as plt

img_path = all_image_paths[0]
label = all_image_labels[0]

plt.imshow(load_and_preprocess_image(img_path))
plt.grid(False)
plt.xlabel(caption_image(img_path))
plt.title(label_names[label].title())
print()

Build a tf.data.Dataset

A dataset of images

The easiest way to build a tf.data.Dataset is using the from_tensor_slices method.

Slicing the array of strings results in a dataset of strings:

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

The output_shapes and output_types fields describe the content of each item in the dataset. In this case it is a set of scalar binary strings:

print('shape: ', repr(path_ds.output_shapes))
print('type: ', path_ds.output_types)
print()
print(path_ds)
WARNING: Logging before flag parsing goes to stderr.
W0708 23:48:10.650363 140414140143360 deprecation.py:323] From <ipython-input-22-2a9400bc986b>:1: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.
W0708 23:48:10.651903 140414140143360 deprecation.py:323] From <ipython-input-22-2a9400bc986b>:2: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(dataset)`.

shape:  TensorShape([])
type:  <dtype: 'string'>

<DatasetV1Adapter shapes: (), types: tf.string>

Now create a new dataset that loads and formats images on the fly by mapping load_and_preprocess_image over the dataset of paths.

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
import matplotlib.pyplot as plt

plt.figure(figsize=(8,8))
for n,image in enumerate(image_ds.take(4)):
  plt.subplot(2,2,n+1)
  plt.imshow(image)
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  plt.xlabel(caption_image(all_image_paths[n]))
plt.show()

A dataset of (image, label) pairs

Using the same from_tensor_slices method, you can build a dataset of labels:

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
for label in label_ds.take(10):
  print(label_names[label.numpy()])
tulips
dandelion
daisy
tulips
sunflowers
sunflowers
tulips
daisy
tulips
dandelion

Since the datasets are in the same order you can just zip them together to get a dataset of (image, label) pairs.

image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

The new dataset's shapes and types are tuples of shapes and types as well, describing each field:

print(image_label_ds)
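<DatasetV1Adapter shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int64)>

When you have arrays like all_image_labels and all_image_paths, an alternative to tf.data.Dataset.zip is to slice the pair of arrays:

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))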

# The tuples are unpacked into the positional arguments of the mapped function
def load_and_preprocess_from_path_label(path, label):
  return load_and_preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
<DatasetV1Adapter shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>

Basic methods for training

To train a model with this dataset you will want the data:

  • To be well shuffled.
  • To be batched.
  • To repeat forever.
  • To have batches available as soon as possible.

These features can be added easily using the tf.data API.

BATCH_SIZE = 32

# Setting a shuffle buffer size as large as the dataset ensures that the data is
# completely shuffled.
ds = image_label_ds.shuffle(buffer_size=image_count)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
# `prefetch` lets the dataset fetch batches in the background while the model is training.
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int32)>

There are a few things to note here:

  1. The order is important.

    • A .shuffle after a .repeat would shuffle items across epoch boundaries (some items will be seen twice before others are seen at all).
    • A .shuffle after a .batch would shuffle the order of the batches, but not shuffle the items across batches.
  2. Use a buffer_size the same size as the dataset for a full shuffle. Up to the dataset size, large values provide better randomization, but use more memory.

  3. The shuffle buffer is filled before any elements are pulled from it. So a large buffer_size may cause a delay when your Dataset is starting.

  4. The shuffled dataset doesn't report the end of a dataset until the shuffle-buffer is completely empty. The Dataset is restarted by .repeat, causing another wait for the shuffle-buffer to be filled.

This last point, as well as the order of .shuffle and .repeat, can be addressed by using the tf.data.Dataset.apply method with the fused tf.data.experimental.shuffle_and_repeat function:

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
W0708 23:48:11.303015 140414140143360 deprecation.py:323] From <ipython-input-31-4dc713bd4d84>:2: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.

<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int32)>
  • For more on ordering the operations, see Repeat and Shuffle in the Input Pipeline Performance guide.
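The ordering notes above can be explored with a small plain-Python model of a shuffle buffer. This is a sketch of the general behavior described in the notes (fill the buffer, emit a random buffered item for each new input, drain at the end); it is an assumption for illustration and not TensorFlow's implementation. The helper names are hypothetical:

```python
import random


def shuffle_buffer(items, buffer_size, rng):
    """Yield items the way a fixed-size shuffle buffer would emit them."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer before emitting anything
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]      # emit a random buffered item...
        buffer[index] = item     # ...and put the new item in its slot
    rng.shuffle(buffer)          # input exhausted: drain what is left
    yield from buffer


def batch(items, batch_size):
    """Group items into consecutive lists of at most batch_size elements."""
    items = list(items)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


rng = random.Random(0)
data = list(range(12))

# Buffer as large as the dataset: any permutation is possible.
full_shuffle = list(shuffle_buffer(data, len(data), rng))

# Small buffer: an item can only move a limited distance toward the front.
small_buffer = list(shuffle_buffer(data, 4, rng))

# Shuffle after batch: batches are reordered, their contents are not.
shuffled_batches = list(shuffle_buffer(batch(data, 4), 3, rng))

# Shuffle after repeat: two epochs share one buffer, so they can interleave.
across_epochs = list(shuffle_buffer(data + data, 2 * len(data), rng))

print(small_buffer)
print(shuffled_batches)
```

In this model, the item emitted at position k always comes from the first k + buffer_size inputs, which is why a small buffer gives a weaker shuffle than one sized to the whole dataset.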

Pipe the dataset to a model

Fetch a copy of MobileNet v2 from tf.keras.applications.

This will be used for a simple transfer learning example.

Set the MobileNet weights to be non-trainable:

mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)
mobile_net.trainable=False
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_192_no_top.h5
9412608/9406464 [==============================] - 2s 0us/step

This model expects its input to be normalized to the [-1,1] range:

help(keras_applications.mobilenet_v2.preprocess_input)
...
This function applies the "Inception" preprocessing which converts
the RGB values from [0, 255] to [-1, 1]
...

So before passing data to the MobileNet model, you need to convert the input from a range of [0,1] to [-1,1].

def change_range(image,label):
  return 2*image-1, label

keras_ds = ds.map(change_range)
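The range conversion can be checked on individual values. A minimal sketch in plain Python, showing that scaling to [0,1] and then applying 2*x-1 matches mapping [0, 255] to [-1, 1] in one step; the function names are illustrative only:

```python
import math


def to_unit_range(pixel):
    """Scale a pixel value from [0, 255] to [0, 1], as preprocess_image does."""
    return pixel / 255.0


def change_range_value(value):
    """Map a value from [0, 1] to [-1, 1], mirroring change_range."""
    return 2 * value - 1


def direct_scale(pixel):
    """Map [0, 255] straight to [-1, 1] in one step."""
    return pixel / 127.5 - 1


for pixel in (0, 64, 127.5, 255):
    two_step = change_range_value(to_unit_range(pixel))
    # Both routes agree up to floating-point rounding.
    assert math.isclose(two_step, direct_scale(pixel), abs_tol=1e-12)
    print(pixel, two_step)
```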

The MobileNet returns a 6x6 spatial grid of features for each image.

Pass it a batch of images to see:

# The dataset may take a few seconds to start, as it fills its shuffle buffer.
image_batch, label_batch = next(iter(keras_ds))
feature_map_batch = mobile_net(image_batch)
print(feature_map_batch.shape)
(32, 6, 6, 1280)

Because of this output shape, build a model wrapped around MobileNet using tf.keras.layers.GlobalAveragePooling2D to average over the space dimensions before the output tf.keras.layers.Dense layer:

model = tf.keras.Sequential([
  mobile_net,
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(len(label_names))])

Now it produces outputs of the expected shape:

logit_batch = model(image_batch).numpy()

print("min logit:", logit_batch.min())
print("max logit:", logit_batch.max())
print()

print("Shape:", logit_batch.shape)
min logit: -3.1316526
max logit: 2.843516

Shape: (32, 5)

Compile the model to describe the training procedure:

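# The Dense layer has no activation, so the model outputs logits;
# from_logits=True tells the loss to apply the softmax itself.
model.compile(optimizer=tf.train.AdamOptimizer(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])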
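The final Dense layer has no activation, so the model emits raw logits. A minimal sketch, in plain Python rather than TensorFlow, of how sparse categorical cross-entropy is computed from logits for a single example; the function name is hypothetical:

```python
import math


def sparse_crossentropy_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[label]


# Five equal logits: the prediction is uniform over the 5 classes.
uniform_loss = sparse_crossentropy_from_logits([0.0] * 5, label=2)

# A large logit on the correct class: the loss approaches zero.
confident_loss = sparse_crossentropy_from_logits(
    [0.0, 0.0, 10.0, 0.0, 0.0], label=2)

print(uniform_loss, confident_loss)
```

For a uniform prediction over the five flower classes the loss is ln(5), about 1.61, which is a useful reference point when reading the loss reported during the first training steps.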

There are 2 trainable variables: the Dense weights and bias:

len(model.trainable_variables)
2
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_192 (Model) (None, 6, 6, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

Train the model.

Normally you would specify the real number of steps per epoch, but for demonstration purposes only run 3 steps.

steps_per_epoch=tf.ceil(len(all_image_paths)/BATCH_SIZE).numpy()
steps_per_epoch
115.0
model.fit(ds, epochs=1, steps_per_epoch=3)
W0708 23:48:40.032137 140414140143360 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

3/3 [==============================] - 10s 3s/step - loss: 8.1482 - acc: 0.2188

<tensorflow.python.keras.callbacks.History at 0x7fb3ac7ece80>
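The step count above comes back as a float (115.0). A sketch of the same calculation in plain Python that yields an integer, using the image count and batch size reported earlier in this tutorial:

```python
import math

image_count = 3670  # number of images found earlier
BATCH_SIZE = 32

# Round up so that a partial final batch still counts as a step.
steps_per_epoch = math.ceil(image_count / BATCH_SIZE)

# Images left over after the last full batch of a single pass.
full_batches, leftover = divmod(image_count, BATCH_SIZE)

print(steps_per_epoch, full_batches, leftover)
```

An integer step count can be passed to range() or used in integer arithmetic without conversion.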

Performance

The simple pipeline used above reads each file individually, on each epoch. This is fine for local training on CPU but may not be sufficient for GPU training, and is totally inappropriate for any sort of distributed training.

To investigate, first build a simple function to check the performance of our datasets:

import time

def timeit(ds, batches=2*steps_per_epoch+1):
  overall_start = time.time()
  # Fetch a single batch to prime the pipeline (fill the shuffle buffer),
  # before starting the timer
  it = iter(ds.take(batches+1))
  next(it)

  start = time.time()
  for i,(images,labels) in enumerate(it):
    if i%10 == 0:
      print('.',end='')
  print()
  end = time.time()

  duration = end-start
  print("{} batches: {} s".format(batches, duration))
  print("{:0.5f} Images/s".format(BATCH_SIZE*batches/duration))
  print("Total time: {}s".format(end-overall_start))

The performance of the current dataset is:

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 16.362396478652954 s
451.76756 Images/s
Total time: 24.477432012557983s

Cache

Use tf.data.Dataset.cache to easily cache calculations across epochs. This is especially performant if the data fits in memory.

Here the images are cached after being pre-processed (decoded and resized):

ds = image_label_ds.cache()
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 0.6933557987213135 s
10661.19302 Images/s
Total time: 8.485734939575195s

One disadvantage to using an in-memory cache is that the cache must be rebuilt on each run, giving the same startup delay each time the dataset is started:

timeit(ds)
........................
231.0 batches: 0.9749493598937988 s
7581.93226 Images/s
Total time: 9.48345398902893s

If the data doesn't fit in memory, use a cache file:

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(1)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 3.659311294555664 s
2020.05225 Images/s
Total time: 13.964630126953125s

The cache file also has the advantage that it can be used to quickly restart the dataset without rebuilding the cache. Note how much faster it is the second time:

timeit(ds)
........................
231.0 batches: 2.9473583698272705 s
2508.00855 Images/s
Total time: 4.52703332901001s
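The effect of caching across epochs can be modeled in plain Python. This is a simplified sketch, not how tf.data implements its cache: the class name and the stand-in preprocessing function are hypothetical, and the cache here is only kept once a full pass completes.

```python
class CachedPipeline:
    """Replay the results of an expensive per-item function after one full pass."""

    def __init__(self, items, fn):
        self._items = items
        self._fn = fn
        self._cache = None  # filled only once a pass runs to completion

    def __iter__(self):
        if self._cache is not None:
            yield from self._cache  # later epochs: no recomputation
            return
        collected = []
        for item in self._items:
            result = self._fn(item)
            collected.append(result)
            yield result
        self._cache = collected


calls = []


def fake_preprocess(path):
    """Stand-in for decode + resize; records each call."""
    calls.append(path)
    return path.upper()


pipeline = CachedPipeline(['a.jpg', 'b.jpg', 'c.jpg'], fake_preprocess)
first_epoch = list(pipeline)
second_epoch = list(pipeline)

print(len(calls))
```

The preprocessing function runs once per item during the first epoch only; the second epoch is served entirely from the stored results.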

TFRecord File

Raw image data

TFRecord files are a simple format for storing a sequence of binary blobs. By packing multiple examples into the same file, TensorFlow is able to read multiple examples at once, which is especially important for performance when using a remote storage service such as GCS.
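The idea of packing many binary blobs into one file can be sketched with a simple length-prefixed layout. The layout below is an assumption made for illustration; the real TFRecord format differs (for example, it also carries checksums), so these helpers cannot read or write actual TFRecord files:

```python
import io
import struct


def write_records(stream, blobs):
    """Write each blob as an 8-byte little-endian length followed by the bytes."""
    for blob in blobs:
        stream.write(struct.pack('<Q', len(blob)))
        stream.write(blob)


def read_records(stream):
    """Yield blobs back in order until the stream is exhausted."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack('<Q', header)
        yield stream.read(length)


# Hypothetical blobs standing in for encoded image files.
blobs = [b'\xff\xd8first-image', b'', b'second-image']

stream = io.BytesIO()
write_records(stream, blobs)
stream.seek(0)
restored = list(read_records(stream))

print(len(restored))
```

Because each record states its own length, a reader can stream through one large file sequentially instead of opening one small file per example.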

First, build a TFRecord file from the raw image data:

image_ds = tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(image_ds)

Next build a dataset that reads from the TFRecord file and decodes/reformats the images using the preprocess_image function you defined earlier.

image_ds = tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)

Zip that with the labels dataset you defined earlier, to get the expected (image,label) pairs.

ds = tf.data.Dataset.zip((image_ds, label_ds))
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 15.514223337173462 s
476.46600 Images/s
Total time: 22.746591806411743s

This is slower than the cache version because you have not cached the preprocessing.

Serialized Tensors

To save some preprocessing to the TFRecord file, first make a dataset of the processed images, as before:

paths_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = paths_ds.map(load_and_preprocess_image)
image_ds
<DatasetV1Adapter shapes: (192, 192, 3), types: tf.float32>

Now instead of a dataset of .jpeg strings, this is a dataset of tensors.

To serialize this to a TFRecord file you first convert the dataset of tensors to a dataset of strings.

ds = image_ds.map(tf.io.serialize_tensor)
ds
<DatasetV1Adapter shapes: (), types: tf.string>
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(ds)

With the preprocessing cached, data can be loaded from the TFRecord file quite efficiently. Just remember to de-serialize the tensor before trying to use it.

ds = tf.data.TFRecordDataset('images.tfrec')

def parse(x):
  result = tf.io.parse_tensor(x, out_type=tf.float32)
  result = tf.reshape(result, [192, 192, 3])
  return result

ds = ds.map(parse, num_parallel_calls=AUTOTUNE)
ds
<DatasetV1Adapter shapes: (192, 192, 3), types: tf.float32>
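The serialize, parse, reshape round trip can be illustrated with the standard library. This toy byte format is an assumption for illustration and is not the encoding that tf.serialize_tensor produces; in this stand-in the shape is not stored at all, so the reader has to supply it, loosely mirroring the explicit tf.reshape in parse:

```python
from array import array


def serialize_image(rows):
    """Flatten a 2-D list of floats into raw float32 bytes (shape not stored)."""
    flat = array('f', (value for row in rows for value in row))
    return flat.tobytes()


def parse_image(blob, width):
    """Decode float32 bytes and restore rows of the given width."""
    flat = array('f')
    flat.frombytes(blob)
    values = flat.tolist()
    return [values[i:i + width] for i in range(0, len(values), width)]


# Values chosen to be exactly representable as float32.
image = [[0.0, 0.25, 0.5], [0.75, 1.0, 0.125]]

blob = serialize_image(image)
restored = parse_image(blob, width=3)

print(len(blob), restored)
```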

Now, add the labels and apply the same standard operations as before:

ds = tf.data.Dataset.zip((ds, label_ds))
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<DatasetV1Adapter shapes: ((?, 192, 192, 3), (?,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 2.910050630569458 s
2540.16199 Images/s
Total time: 9.230100393295288s
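The timings reported in this section can be put side by side. A small sketch that recomputes throughput from the durations printed by timeit above (first run of each variant) and expresses each as a multiple of the uncached pipeline; the dictionary keys are descriptive labels chosen here:

```python
BATCH_SIZE = 32
BATCHES = 231  # 2 * steps_per_epoch + 1, the default used by timeit

# Seconds reported by timeit for the first run of each variant.
reported_seconds = {
    'no cache': 16.362396478652954,
    'in-memory cache': 0.6933557987213135,
    'cache file': 3.659311294555664,
    'TFRecord of raw JPEGs': 15.514223337173462,
    'TFRecord of serialized tensors': 2.910050630569458,
}

throughput = {
    name: BATCH_SIZE * BATCHES / seconds
    for name, seconds in reported_seconds.items()
}
baseline = throughput['no cache']

for name, images_per_second in throughput.items():
    print('{:32s} {:10.1f} images/s  {:5.1f}x'.format(
        name, images_per_second, images_per_second / baseline))
```

These figures come from a single run on one machine, so they indicate relative ordering rather than precise speedups.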