Image captioning with visual attention

Given an image like the example below, your goal is to generate a caption such as "a surfer riding on a wave".

A man surfing, from Wikimedia

The model architecture used here is inspired by Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, but has been updated to use a 2-layer Transformer decoder. To get the most out of this tutorial you should have some experience with text generation, seq2seq models and attention, or Transformers.

The model architecture built in this tutorial is shown below. Features are extracted from the image and passed to the cross-attention layers of the Transformer decoder.

The model architecture

The Transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and cross-attention to attend to the image.

By inspecting the attention weights of the cross-attention layers, you can see which parts of the image the model is looking at as it generates each word.
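
To make that concrete, below is a minimal sketch of a cross-attention block built on tf.keras.layers.MultiHeadAttention (the class and attribute names here are illustrative, not the tutorial's final model). Passing return_attention_scores=True is what exposes the per-word attention maps for inspection:

class CrossAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y):
    # x: the token sequence being generated; y: the image features.
    attn_output, attention_scores = self.mha(
        query=x, value=y, return_attention_scores=True)
    # Cache the scores so they can be plotted over the image later.
    self.last_attention_scores = attention_scores
    x = self.add([x, attn_output])
    return self.layernorm(x)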

Prediction

This notebook is an end-to-end example. When you run it, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images.

Setup

apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
pip uninstall -y tensorflow estimator keras
pip install -U tensorflow_text tensorflow tensorflow_datasets
pip install einops

This tutorial uses lots of imports, mostly for loading the dataset.
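
The import cell itself is omitted above; for reference, a minimal set that covers the code used in this section is:

import collections
import concurrent.futures
import hashlib
import itertools
import pathlib
import re
import string

import einops
import requests
import tqdm

import tensorflow as tf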

[Optional] Data handling

This section downloads a captions dataset and prepares it for training. It tokenizes the input text and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.

Choose a dataset

This tutorial is set up to give you a choice of datasets: either Flickr8k or a small slice of the Conceptual Captions dataset. These two are downloaded and converted from scratch, but it wouldn't be hard to convert the tutorial to use the caption datasets available in TensorFlow Datasets: Coco Captions and the full Conceptual Captions.
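
As a rough, untested sketch of what that conversion could look like for Coco Captions (feature names follow the TFDS catalog): note that TFDS yields decoded images rather than file paths, so the load_image step used later in this tutorial would be skipped or replaced with a plain resize.

import tensorflow_datasets as tfds

def coco_captions(split):
  ds = tfds.load('coco_captions', split=split)
  # Keep the same (image, captions) element structure as the loaders below.
  return ds.map(lambda ex: (ex['image'], ex['captions']['text']))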

Flickr8k

def flickr8k(path='flickr8k'):
  path = pathlib.Path(path)

  # Download and extract the images and the caption text if they're missing.
  if len(list(path.rglob('*'))) < 16197:
    tf.keras.utils.get_file(
        origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',
        cache_dir='.',
        cache_subdir=path,
        extract=True)
    tf.keras.utils.get_file(
        origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',
        cache_dir='.',
        cache_subdir=path,
        extract=True)

  # Each line of the token file is '<image name>#<caption number>\t<caption>'.
  captions = (path/"Flickr8k.token.txt").read_text().splitlines()
  captions = (line.split('\t') for line in captions)
  captions = ((fname.split('#')[0], caption) for (fname, caption) in captions)

  cap_dict = collections.defaultdict(list)
  for fname, cap in captions:
    cap_dict[fname].append(cap)

  train_files = (path/'Flickr_8k.trainImages.txt').read_text().splitlines()
  train_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in train_files]

  test_files = (path/'Flickr_8k.testImages.txt').read_text().splitlines()
  test_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in test_files]

  # Each dataset element is an (image_path, list_of_captions) pair.
  train_ds = tf.data.experimental.from_list(train_captions)
  test_ds = tf.data.experimental.from_list(test_captions)

  return train_ds, test_ds

Conceptual Captions

def conceptual_captions(*, data_dir="conceptual_captions", num_train, num_val):
  def iter_index(index_path):
    with open(index_path) as f:
      for line in f:
        caption, url = line.strip().split('\t')
        yield caption, url

  def download_image_urls(data_dir, urls):
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=100)
    def save_image(url):
      # Name the file after the hash of the URL, for a stable, unique name.
      url_hash = hashlib.sha1(url.encode())
      file_path = data_dir/f'{url_hash.hexdigest()}.jpeg'
      if file_path.exists():
        # Only download each file once.
        return file_path

      try:
        result = requests.get(url, timeout=5)
      except Exception:
        file_path = None
      else:
        file_path.write_bytes(result.content)
      return file_path

    result = []
    out_paths = ex.map(save_image, urls)
    for file_path in tqdm.tqdm(out_paths, total=len(urls)):
      result.append(file_path)

    return result

  def ds_from_index_file(index_path, data_dir, count):
    data_dir.mkdir(exist_ok=True)
    index = list(itertools.islice(iter_index(index_path), count))
    captions = [caption for caption, url in index]
    urls = [url for caption, url in index]

    paths = download_image_urls(data_dir, urls)

    new_captions = []
    new_paths = []
    for cap, path in zip(captions, paths):
      if path is None:
        # Download failed, so skip this pair.
        continue
      new_captions.append(cap)
      new_paths.append(path)

    new_paths = [str(p) for p in new_paths]

    ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))
    ds = ds.map(lambda path,cap: (path, cap[tf.newaxis])) # 1 caption per image
    return ds

  data_dir = pathlib.Path(data_dir)
  train_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',
    cache_subdir=data_dir,
    cache_dir='.')

  val_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',
    cache_subdir=data_dir,
    cache_dir='.')

  train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)
  test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)

  return train_raw, test_raw

Download the dataset

Flickr8k is a good choice because it contains 5 captions per image: more data for a smaller download.

choose = 'flickr8k'

if choose == 'flickr8k':
  train_raw, test_raw = flickr8k()
else:
  train_raw, test_raw = conceptual_captions(num_train=10000, num_val=5000)

The loaders for both datasets above return tf.data.Datasets containing (image_path, captions) pairs. The Flickr8k dataset contains 5 captions per image, while Conceptual Captions has 1:

train_raw.element_spec
for ex_path, ex_captions in train_raw.take(1):
  print(ex_path)
  print(ex_captions)

Image feature extractor

You will use an image model (pretrained on imagenet) to extract features from each image. The model was trained as an image classifier, but setting include_top=False returns the model without the final classification layer, so you can use the last layer of feature maps:

IMAGE_SHAPE=(224, 224, 3)
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable=False

Here's a function to load an image and resize it for the model:

def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

The model returns a feature map for each image in the input batch:

test_img_batch = load_image(ex_path)[tf.newaxis, :]

print(test_img_batch.shape)
print(mobilenet(test_img_batch).shape)

Setup the text tokenizer/vectorizer

You will transform the text captions into integer sequences using the TextVectorization layer, with the following steps:

  • Use adapt to iterate over all captions, split the captions into words, and compute a vocabulary of the top words.
  • Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.
  • Create word-to-index and index-to-word mappings to display the results.

def standardize(s):
  s = tf.strings.lower(s)
  s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
  s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
  return s
# Use the top 5000 words for a vocabulary.
vocabulary_size = 5000
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize,
    ragged=True)
# Learn the vocabulary from the caption data.
tokenizer.adapt(train_raw.map(lambda fp,txt: txt).unbatch().batch(1024))
tokenizer.get_vocabulary()[:10]
t = tokenizer([['a cat in a hat'], ['a robot dog']])
t
# Create mappings for words to indices and indices to words.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary())
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)
w = index_to_word(t)
w.to_list()
tf.strings.reduce_join(w, separator=' ', axis=-1).numpy()

Prepare the datasets

The train_raw and test_raw datasets contain 1:many (image, captions) pairs.

This function will replicate the image so there is a 1:1 ratio of images to captions:

def match_shapes(images, captions):
  caption_shape = einops.parse_shape(captions, 'b c')
  captions = einops.rearrange(captions, 'b c -> (b c)')
  images = einops.repeat(
      images, 'b ... -> (b c) ...',
      c = caption_shape['c'])
  return images, captions

for ex_paths, ex_captions in train_raw.batch(32).take(1):
  break

print('image paths:', ex_paths.shape)
print('captions:', ex_captions.shape)
print()

ex_paths, ex_captions = match_shapes(images=ex_paths, captions=ex_captions)

print('image_paths:', ex_paths.shape)
print('captions:', ex_captions.shape)

To be compatible with keras training, the dataset should contain (inputs, labels) pairs. For text generation the tokens are both the inputs and the labels, shifted by one step. This function converts an (images, texts) pair to an ((images, input_tokens), label_tokens) pair:

def prepare_txt(imgs, txts):
  tokens = tokenizer(txts)

  input_tokens = tokens[..., :-1]
  label_tokens = tokens[..., 1:]
  return (imgs, input_tokens), label_tokens
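
As a quick illustration (using a made-up caption, not one from the dataset), the labels are the input tokens shifted left by one position, so at each step the model is trained to predict the next word:

(imgs, in_tok), out_tok = prepare_txt(
    tf.zeros([1, 224, 224, 3]),      # A placeholder image batch; it passes through unchanged.
    tf.constant(['a cat in a hat']))
print(in_tok.to_list())   # Token ids for: [START] a cat in a hat
print(out_tok.to_list())  # Token ids for: a cat in a hat [END]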

This function adds the operations to a dataset. The steps are:

  1. Load the images (and ignore images that fail to load).
  2. Replicate the images to match the number of captions.
  3. Shuffle and rebatch the (image, caption) pairs.
  4. Tokenize the text, shift the tokens and add label_tokens.
  5. Convert the text from a RaggedTensor representation to a padded dense Tensor representation.

def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):
  # Load the images and make batches.
  ds = (ds
        .shuffle(10000)
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  def to_tensor(inputs, labels):
    (images, in_tok), out_tok = inputs, labels
    return (images, in_tok.to_tensor()), out_tok.to_tensor()

  return (ds
          .map(match_shapes, tf.data.AUTOTUNE)
          .unbatch()
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .map(prepare_txt, tf.data.AUTOTUNE)
          .map(to_tensor, tf.data.AUTOTUNE)
          )

You could install the feature extractor in the model and train on these datasets like this:

train_ds = prepare_dataset(train_raw, tokenizer)
train_ds.element_spec
test_ds = prepare_dataset(test_raw, tokenizer)
test_ds.element_spec

[Optional] Cache the image features

Since the image feature extractor is not changing, and this tutorial is not using image augmentation, the image features can be cached. The same goes for the text tokenization. The time it takes to set up the cache is earned back during each epoch of training and validation. The code below defines two functions, save_dataset and load_dataset:

def save_dataset(ds, save_path, image_model, tokenizer, shards=10, batch_size=32):
  # Load the images and make batches.
  ds = (ds
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  # Run the feature extractor on each batch
  # Don't do this in a .map, because tf.data runs on the CPU. 
  def gen():
    for (images, captions) in tqdm.tqdm(ds): 
      feature_maps = image_model(images)

      feature_maps, captions = match_shapes(feature_maps, captions)
      yield feature_maps, captions

  # Wrap the generator in a new tf.data.Dataset.
  new_ds = tf.data.Dataset.from_generator(
      gen,
      output_signature=(
          tf.TensorSpec(shape=image_model.output_shape),
          tf.TensorSpec(shape=(None,), dtype=tf.string)))

  # Apply the tokenization 
  new_ds = (new_ds
            .map(prepare_txt, tf.data.AUTOTUNE)
            .unbatch()
            .shuffle(1000))

  # Save the dataset into shard files.
  def shard_func(i, item):
    return i % shards
  new_ds.enumerate().save(save_path, shard_func=shard_func)

def load_dataset(save_path, batch_size=32, shuffle=1000, cycle_length=2):
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(1000)
    return datasets.interleave(lambda x: x, cycle_length=cycle_length)

  ds = tf.data.Dataset.load(save_path, reader_func=custom_reader_func)

  def drop_index(i, x):
    return x

  ds = (ds
        .map(drop_index, tf.data.AUTOTUNE)
        .shuffle(shuffle)
        .padded_batch(batch_size)
        .prefetch(tf.data.AUTOTUNE))
  return ds

save_dataset(train_raw, 'train_cache', mobilenet, tokenizer)
save_dataset(test_raw, 'test_cache', mobilenet, tokenizer)
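
Once the shards are saved, the cached datasets can be read back with the load_dataset function defined above, for example:

train_ds = load_dataset('train_cache')
test_ds = load_dataset('test_cache')
train_ds.element_spec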