TensorFlow Datasets

TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks.

It handles downloading and preparing the data deterministically and constructing a tf.data.Dataset (or np.array).

View on TensorFlow.org View source on GitHub

Installation

TFDS exists in two packages:

  • pip install tensorflow-datasets: The stable version, released every few months.
  • pip install tfds-nightly: Released every day, contains the last versions of the datasets.

This colab uses tfds-nightly:

pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Find available datasets

All dataset builders are subclass of tfds.core.DatasetBuilder. To get the list of available builders, use tfds.list_builders() or look at our catalog.

tfds.list_builders()
['abstract_reasoning',
 'accentdb',
 'aeslc',
 'aflw2k3d',
 'ag_news_subset',
 'ai2_arc',
 'ai2_arc_with_ir',
 'amazon_us_reviews',
 'anli',
 'arc',
 'bair_robot_pushing_small',
 'bccd',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'blimp',
 'bool_q',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cfq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'clic',
 'clinc_oos',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coco_captions',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'common_voice',
 'coqa',
 'cos_e',
 'cosmos_qa',
 'covid19sum',
 'crema_d',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'dementiabank',
 'diabetic_retinopathy_detection',
 'div2k',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'e2e_cleaned',
 'emnist',
 'eraser_multi_rc',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'forest_fires',
 'fuss',
 'gap',
 'geirhos_conflict_stimuli',
 'genomics_ood',
 'german_credit_numeric',
 'gigaword',
 'glue',
 'goemotions',
 'gpt3',
 'groove',
 'gtzan',
 'gtzan_music_speech',
 'hellaswag',
 'higgs',
 'horses_or_humans',
 'howell',
 'i_naturalist2017',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet2012_real',
 'imagenet2012_subset',
 'imagenet_a',
 'imagenet_r',
 'imagenet_resized',
 'imagenet_v2',
 'imagenette',
 'imagewang',
 'imdb_reviews',
 'irc_disentanglement',
 'iris',
 'kitti',
 'kmnist',
 'lambada',
 'lfw',
 'librispeech',
 'librispeech_lm',
 'libritts',
 'ljspeech',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mctaco',
 'mlqa',
 'mnist',
 'mnist_corrupted',
 'movie_lens',
 'movie_rationales',
 'movielens',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'natural_questions',
 'natural_questions_open',
 'newsroom',
 'nsynth',
 'nyu_depth_v2',
 'omniglot',
 'open_images_challenge2019_detection',
 'open_images_v4',
 'openbookqa',
 'opinion_abstracts',
 'opinosis',
 'opus',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'paws_wiki',
 'paws_x_wiki',
 'pet_finder',
 'pg19',
 'piqa',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'qa4mre',
 'qasc',
 'quac',
 'quickdraw_bitmap',
 'race',
 'radon',
 'reddit',
 'reddit_disentanglement',
 'reddit_tifu',
 'resisc45',
 'robonet',
 'rock_paper_scissors',
 'rock_you',
 'salient_span_wikipedia',
 'samsum',
 'savee',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'sentiment140',
 'shapes3d',
 'siscore',
 'smallnorb',
 'snli',
 'so2sat',
 'speech_commands',
 'spoken_digit',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'stl10',
 'story_cloze',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tedlium',
 'tf_flowers',
 'the300w_lp',
 'tiny_shakespeare',
 'titanic',
 'trec',
 'trivia_qa',
 'tydi_qa',
 'uc_merced',
 'ucf101',
 'vctk',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'voxceleb',
 'voxforge',
 'waymo_open_dataset',
 'web_nlg',
 'web_questions',
 'wider_face',
 'wiki40b',
 'wiki_bio',
 'wikihow',
 'wikipedia',
 'wikipedia_toxicity_subtypes',
 'wine_quality',
 'winogrande',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'wordnet',
 'wsc273',
 'xnli',
 'xquad',
 'xsum',
 'xtreme_pawsx',
 'xtreme_xnli',
 'yelp_polarity_reviews',
 'yes_no']

Load a dataset

tfds.load

The easiest way of loading a dataset is tfds.load. It will:

  1. Download the data and save it as tfrecord files.
  2. Load the tfrecord and create the tf.data.Dataset.
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

Some common arguments:

  • split=: Which split to read (e.g. 'train', ['train', 'test'], 'train[80%:]',...). See our split API guide.
  • shuffle_files=: Control whether to shuffle the files between each epoch (TFDS store big datasets in multiple smaller files).
  • data_dir=: Location where the dataset is saved ( defaults to ~/tensorflow_datasets/)
  • with_info=True: Returns the tfds.core.DatasetInfo containing dataset metadata
  • download=False: Disable download

tfds.builder

tfds.load is a thin wrapper around tfds.core.DatasetBuilder. You can get the same output using the tfds.core.DatasetBuilder API:

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

tfds build CLI

If you want to generate a specific dataset, you can use the tfds command line. For example:

tfds build mnist

See the doc for available flags.

Iterate over a dataset

As dict

By default, the tf.data.Dataset object contains a dict of tf.Tensors:

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

As tuple (as_supervised=True)

By using as_supervised=True, you can get a tuple (features, label) instead for supervised datasets.

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

As numpy (tfds.as_numpy)

Uses tfds.as_numpy to convert:

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4

As batched tf.Tensor (batch_size=-1)

By using batch_size=-1, you can load the full dataset in a single batch.

tfds.load will return a dict (tuple with as_supervised=True) of tf.Tensor (np.array with tfds.as_numpy).

Be careful that your dataset can fit in memory, and that all examples have the same shape.

image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test', 
    batch_size=-1, 
    as_supervised=True,
))

print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)

Build end-to-end pipeline

To go further, you can look:

Visualization

tfds.as_dataframe

tf.data.Dataset objects can be converted to pandas.DataFrame with tfds.as_dataframe to be visualized on Colab.

  • Add the tfds.core.DatasetInfo as second argument of tfds.as_dataframe to visualize images, audio, texts, videos,...
  • Use ds.take(x) to only display the first x examples. pandas.DataFrame will load the full dataset in-memory, and can be very expensive to display.
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

tfds.show_examples

For image with tfds.show_examples (only image datasets supported now):

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

png

Access the dataset metadata

All builders include a tfds.core.DatasetInfo object containing the dataset metadata.

It can be accessed through:

ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info

The dataset info contains additional informations about the dataset (version, citation, homepage, description,...).

print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)


Features metadata (label names, image shape,...)

Access the tfds.features.FeatureDict:

info.features
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

Number of classes, label names:

print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
7
7

Shapes, dtypes:

print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()}
{'image': tf.uint8, 'label': tf.int64}
(28, 28, 1)
<dtype: 'uint8'>

Split metadata (e.g. split names, number of examples,...)

Access the tfds.core.SplitDict:

print(info.splits)
{'test': <tfds.core.SplitInfo num_examples=10000>, 'train': <tfds.core.SplitInfo num_examples=60000>}

Available splits:

print(list(info.splits.keys()))
['test', 'train']

Get info on individual split:

print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000
['mnist-train.tfrecord-00000-of-00001']
1

It also works with the subsplit API:

print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000
[FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]

Troubleshooting

Manual download (if download fails)

If download fails for some reason (e.g. offline,...). You can always manually download the data yourself and place it in the manual_dir (defaults to ~/tensorflow_datasets/download/manual/.

To find out which urls to download, look into:

Fixing NonMatchingChecksumError

TFDS ensure determinism by validating the checksums of downloaded urls. If NonMatchingChecksumError is raised, might indicate:

  • The website may be down (e.g. 503 status code). Please check the url.
  • For Google Drive URLs, try again later as Drive sometimes rejects downloads when too many people access the same URL. See bug
  • The original datasets files may have been updated. In this case the TFDS dataset builder should be updated. Please open a new Github issue or PR:
    • Register the new checksums with tfds build --register_checksums
    • Eventually update the dataset generation code.
    • Update the dataset VERSION
    • Update the dataset RELEASE_NOTES: What caused the checksums to change ? Did some examples changed ?
    • Make sure the dataset can still be built.
    • Send us a PR

Citation

If you're using tensorflow-datasets for a paper, please include the following citation, in addition to any citation specific to the used datasets (which can be found in the dataset catalog).

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}