ビジュアルアテンションを用いた画像キャプショニング

<style> td { text-align: center; } th { text-align: center; } </style>

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

以下の例のような画像が与えられた場合、目標は「波に乗っているサーファー」などのキャプションを生成することです。

サーフィンしている男性(出典: wikimedia

ここで使用されているモデルアーキテクチャは、「Show, Attend and Tell: Neural Image Caption Generation with Visual Attention」からアイデアを得たものですが、2 レイヤー Transformer デコーダを使用するように更新されています。このチュートリアルを最大限に活用するには、テキスト生成seq2seq モデルとアテンション、または transformer の使用経験があるとよいでしょう。

以下は、このチュートリアルに組み込まれるモデルアーキテクチャです。特徴量は画像から抽出され、Transformer デコーダのクロスアテンションレイヤーに渡されています。

モデルアーキテクチャ

Transformer デコーダは主に、アテンションレイヤーから構築されます。セルフアテンションを使用して、生成されるシーケンスを処理し、クロスアテンションを使用して、画像に注意を向けます。

クロスアテンションレイヤーのアテンションの重みを検査することで、モデルが単語を生成する過程で、モデルが画像のどの部分を見ているかを知ることができます。

予測

このノートブックは、エンドツーエンドの例を示します。このノートブックを実行すると、データセットのダウンロード、画像特徴量の抽出とキャッシュ処理、そしてデコーダモデルのトレーニングが行われます。その後で、モデルを使用して、新しい画像のキャプションが生成されるようになります。

セットアップ

apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?
pip uninstall -y tensorflow estimator keras
pip install -U tensorflow_text tensorflow tensorflow_datasets
pip install einops

このチュートリアルでは多数の import を使用します。ほとんどがデータセットの読み込み目的です。

2024-01-11 19:14:21.407842: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 19:14:21.407891: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 19:14:21.409458: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

[オプション] データ処理

このセクションでは、captions データセットをダウンロードして、トレーニングの準備を行います。入力テキストをトークン化し、事前トレーニング済みの特徴量抽出器モデルを通じてすべての画像の実行結果をキャッシュします。このセクションの内容を完全に理解することは重要ではありません。

データセットを選択する

このチュートリアルは、データセットの選択肢を提供するようにセットアップされています。Flickr8k、または Conceptual Captions データセットの小さなスライスのいずれかを選択します。これらはゼロからダウンロードして変換されますが、TensorFlow Datasets で提供されている Coco Captions と完全な Conceptual Captions というキャプションデータセットを使用するようにチュートリアルを変更することは困難ではありません。

Flickr8k

def flickr8k(path='flickr8k'):
  path = pathlib.Path(path)

  if len(list(path.rglob('*'))) < 16197:
    tf.keras.utils.get_file(
        origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',
        cache_dir='.',
        cache_subdir=path,
        extract=True)
    tf.keras.utils.get_file(
        origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',
        cache_dir='.',
        cache_subdir=path,
        extract=True)

  captions = (path/"Flickr8k.token.txt").read_text().splitlines()
  captions = (line.split('\t') for line in captions)
  captions = ((fname.split('#')[0], caption) for (fname, caption) in captions)

  cap_dict = collections.defaultdict(list)
  for fname, cap in captions:
    cap_dict[fname].append(cap)

  train_files = (path/'Flickr_8k.trainImages.txt').read_text().splitlines()
  train_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in train_files]

  test_files = (path/'Flickr_8k.testImages.txt').read_text().splitlines()
  test_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in test_files]

  train_ds = tf.data.experimental.from_list(train_captions)
  test_ds = tf.data.experimental.from_list(test_captions)

  return train_ds, test_ds

Conceptual Captions

def conceptual_captions(*, data_dir="conceptual_captions", num_train, num_val):
  def iter_index(index_path):
    with open(index_path) as f:
      for line in f:
        caption, url = line.strip().split('\t')
        yield caption, url

  def download_image_urls(data_dir, urls):
    ex = concurrent.futures.ThreadPoolExecutor(max_workers=100)
    def save_image(url):
      hash = hashlib.sha1(url.encode())
      # Name the files after the hash of the URL.
      file_path = data_dir/f'{hash.hexdigest()}.jpeg'
      if file_path.exists():
        # Only download each file once.
        return file_path

      try:
        result = requests.get(url, timeout=5)
      except Exception:
        file_path = None
      else:
        file_path.write_bytes(result.content)
      return file_path

    result = []
    out_paths = ex.map(save_image, urls)
    for file_path in tqdm.tqdm(out_paths, total=len(urls)):
      result.append(file_path)

    return result

  def ds_from_index_file(index_path, data_dir, count):
    data_dir.mkdir(exist_ok=True)
    index = list(itertools.islice(iter_index(index_path), count))
    captions = [caption for caption, url in index]
    urls = [url for caption, url in index]

    paths = download_image_urls(data_dir, urls)

    new_captions = []
    new_paths = []
    for cap, path in zip(captions, paths):
      if path is None:
        # Download failed, so skip this pair.
        continue
      new_captions.append(cap)
      new_paths.append(path)

    new_paths = [str(p) for p in new_paths]

    ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))
    ds = ds.map(lambda path,cap: (path, cap[tf.newaxis])) # 1 caption per image
    return ds

  data_dir = pathlib.Path(data_dir)
  train_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',
    cache_subdir=data_dir,
    cache_dir='.')

  val_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',
    cache_subdir=data_dir,
    cache_dir='.')

  train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)
  test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)

  return train_raw, test_raw

データセットをダウンロードする

Flickr8k には 1 つの画像につき 5 つのキャプションが含まれているため、より小さなダウンロードサイズでより多くのデータを得られる良い選択肢と言えます。

choose = 'flickr8k'

if choose == 'flickr8k':
  train_raw, test_raw = flickr8k()
else:
  train_raw, test_raw = conceptual_captions(num_train=10000, num_val=5000)
Downloading data from https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
1115419746/1115419746 [==============================] - 4s 0us/step
Downloading data from https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
2340801/2340801 [==============================] - 0s 0us/step

上記のいずれのデータセットのローダーも、(image_path, captions) ペアを含む tf.data.Dataset を返します。Flickr8k データセットには 1 つの画像につき 5 つのキャプションが含まれているのに対し、Conceptual Captions のキャプションは 1 つです。

train_raw.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None),
 TensorSpec(shape=(5,), dtype=tf.string, name=None))
for ex_path, ex_captions in train_raw.take(1):
  print(ex_path)
  print(ex_captions)
tf.Tensor(b'flickr8k/Flicker8k_Dataset/2513260012_03d33305cf.jpg', shape=(), dtype=string)
tf.Tensor(
[b'A black dog is running after a white dog in the snow .'
 b'Black dog chasing brown dog through snow'
 b'Two dogs chase each other across the snowy ground .'
 b'Two dogs play together in the snow .'
 b'Two dogs running through a low lying body of water .'], shape=(5,), dtype=string)

画像特徴量抽出器

画像モデル(imagenet で事前トレーニング済み)を使用して各画像から特徴量を抽出します。このモデルは画像分類器としてトレーニングされていますが、設定 include_top=False は最終的な分類レイヤーなしのモデルを返すため、特徴量マップの最後のレイヤーを使用できます。

IMAGE_SHAPE=(224, 224, 3)
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable=False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/weights_mobilenet_v3_small_224_1.0_float_no_top_v2.h5
4334752/4334752 [==============================] - 0s 0us/step

以下は、画像を読み込んでモデルに合わせてサイズを変更する関数です。

def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

モデルは、入力バッチで各画像の特徴量マップを返します。

test_img_batch = load_image(ex_path)[tf.newaxis, :]

print(test_img_batch.shape)
print(mobilenet(test_img_batch).shape)
(1, 224, 224, 3)
(1, 7, 7, 576)

テキストトークナイザ/ベクタナイザをセットアップする

次の手順で、TextVectorization レイヤーを使用して、テキストキャプションを整数シーケンスに変換します。

  • adapt を使用して、すべてのキャプションをイテレートし、キャプションを単語に分割して、上位の単語の語彙を計算します。
  • 各単語を語彙のインデックスにマッピングして、すべてのキャプションをトークン化します。すべての出力シーケンスは、長さ 50 までパディングされます。
  • 結果を表示するために、単語からインデックスおよびインデックスから単語へのマッピングを作成します。
def standardize(s):
  s = tf.strings.lower(s)
  s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
  s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
  return s
# Use the top 5000 words for a vocabulary.
vocabulary_size = 5000
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize,
    ragged=True)
# Learn the vocabulary from the caption data.
tokenizer.adapt(train_raw.map(lambda fp,txt: txt).unbatch().batch(1024))
tokenizer.get_vocabulary()[:10]
['', '[UNK]', 'a', '[START]', '[END]', 'in', 'the', 'on', 'is', 'and']
t = tokenizer([['a cat in a hat'], ['a robot dog']])
t
<tf.RaggedTensor [[3, 2, 655, 5, 2, 97, 4], [3, 2, 1937, 10, 4]]>
# Create mappings for words to indices and indices to words.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary())
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)
w = index_to_word(t)
w.to_list()
[[b'[START]', b'a', b'cat', b'in', b'a', b'hat', b'[END]'],
 [b'[START]', b'a', b'robot', b'dog', b'[END]']]
tf.strings.reduce_join(w, separator=' ', axis=-1).numpy()
array([b'[START] a cat in a hat [END]', b'[START] a robot dog [END]'],
      dtype=object)

データセットを準備する

train_rawtest_raw データセットには、一対多の (image, captions) ペアが含まれます。

この関数は、画像とキャプションが 1:1 になるように、画像を複製します。

def match_shapes(images, captions):
  caption_shape = einops.parse_shape(captions, 'b c')
  captions = einops.rearrange(captions, 'b c -> (b c)')
  images = einops.repeat(
      images, 'b ... -> (b c) ...',
      c = caption_shape['c'])
  return images, captions
for ex_paths, ex_captions in train_raw.batch(32).take(1):
  break

print('image paths:', ex_paths.shape)
print('captions:', ex_captions.shape)
print()

ex_paths, ex_captions = match_shapes(images=ex_paths, captions=ex_captions)

print('image_paths:', ex_paths.shape)
print('captions:', ex_captions.shape)
image paths: (32,)
captions: (32, 5)

image_paths: (160,)
captions: (160,)

Keras トレーニングと互換性を持たせるため、データセットには (inputs, labels) ペアを含める必要があります。テキスト生成では、トークンは、入力とラベルのいずれでもあり、1 ステップずつシフトされます。この関数は、(images, texts) ペアを ((images, input_tokens), label_tokens) ペアに変換します。

def prepare_txt(imgs, txts):
  tokens = tokenizer(txts)

  input_tokens = tokens[..., :-1]
  label_tokens = tokens[..., 1:]
  return (imgs, input_tokens), label_tokens

この関数は、以下の手順でデータセットに演算を追加します。

  1. 画像を読み込みます(読み込みに失敗する画像は無視されます)。
  2. キャプションの数に合わせて画像を複製します。
  3. image, caption ペアをシャッフルして再バッチ化します。
  4. テキストをトークン化し、トークンをシフトして label_tokens を追加します。
  5. RaggedTensor 表現のテキストをパディング付きの高密度 Tensor 表現に変換します。
def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):
  # Load the images and make batches.
  ds = (ds
        .shuffle(10000)
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  def to_tensor(inputs, labels):
    (images, in_tok), out_tok = inputs, labels
    return (images, in_tok.to_tensor()), out_tok.to_tensor()

  return (ds
          .map(match_shapes, tf.data.AUTOTUNE)
          .unbatch()
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .map(prepare_txt, tf.data.AUTOTUNE)
          .map(to_tensor, tf.data.AUTOTUNE)
          )

以下のようにして、特徴量抽出器をモデルにインストールし、データセットでトレーニングすることが可能です。

train_ds = prepare_dataset(train_raw, tokenizer)
train_ds.element_spec
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_116100/1004139779.py:6: ignore_errors (from tensorflow.python.data.experimental.ops.error_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.ignore_errors` instead.
((TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),
  TensorSpec(shape=(None, None), dtype=tf.int64, name=None)),
 TensorSpec(shape=(None, None), dtype=tf.int64, name=None))
test_ds = prepare_dataset(test_raw, tokenizer)
test_ds.element_spec
((TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),
  TensorSpec(shape=(None, None), dtype=tf.int64, name=None)),
 TensorSpec(shape=(None, None), dtype=tf.int64, name=None))

[オプション] 画像特徴量をキャッシュする

画像特徴量抽出器には変化がなく、このチュートリアルでは画像拡張を使用していないため、画像特徴量をキャッシュすることができます。テキストトークン化についても同様です。キャッシュをセットアップするのに時間がかかりますが、その時間は、トレーニングと検証中に、各エポックで取り返すことができます。以下のコードでは、save_datasetload_dataset の 2 つの関数を定義しています。

def save_dataset(ds, save_path, image_model, tokenizer, shards=10, batch_size=32):
  # Load the images and make batches.
  ds = (ds
        .map(lambda path, caption: (load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  # Run the feature extractor on each batch
  # Don't do this in a .map, because tf.data runs on the CPU. 
  def gen():
    for (images, captions) in tqdm.tqdm(ds): 
      feature_maps = image_model(images)

      feature_maps, captions = match_shapes(feature_maps, captions)
      yield feature_maps, captions

  # Wrap the generator in a new tf.data.Dataset.
  new_ds = tf.data.Dataset.from_generator(
      gen,
      output_signature=(
          tf.TensorSpec(shape=image_model.output_shape),
          tf.TensorSpec(shape=(None,), dtype=tf.string)))

  # Apply the tokenization 
  new_ds = (new_ds
            .map(prepare_txt, tf.data.AUTOTUNE)
            .unbatch()
            .shuffle(1000))

  # Save the dataset into shard files.
  def shard_func(i, item):
    return i % shards
  new_ds.enumerate().save(save_path, shard_func=shard_func)

def load_dataset(save_path, batch_size=32, shuffle=1000, cycle_length=2):
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(1000)
    return datasets.interleave(lambda x: x, cycle_length=cycle_length)

  ds = tf.data.Dataset.load(save_path, reader_func=custom_reader_func)

  def drop_index(i, x):
    return x

  ds = (ds
        .map(drop_index, tf.data.AUTOTUNE)
        .shuffle(shuffle)
        .padded_batch(batch_size)
        .prefetch(tf.data.AUTOTUNE))
  return ds
save_dataset(train_raw, 'train_cache', mobilenet, tokenizer)
save_dataset(test_raw, 'test_cache', mobilenet, tokenizer)
188it [00:23,  8.02it/s]
32it [00:04,  7.95it/s]