転移学習で花を分類する

TensorFlow.org で表示 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

きれいな花を見つけて、何の花かと思ったことはありませんか?そう思ったのはあなただけではありません。写真から花の種類を特定する方法を構築しましょう!

画像の分類については、畳み込みニューラルネットワークという特定の種類のディープニューラルネットワークが特に強力であることが実証されました。ただし、最新の畳み込みニューラルネットワークには数百ものパラメータがあります。ゼロからトレーニングするには大量のラベル付きのトレーニングデータと多大な計算能力(数百時間以上の GPU)が必要ですが、ここでは 3000 枚のラベル付き写真しかなく、あまり時間をかけられないため、多少頭を使う必要があります。

ここではトレーニング済みのネットワーク(約 100 万個の一般的な画像でトレーニング)を用いた転移学習というテクニックを使用して、特徴量を抽出し、花の画像を分類するというタスク向けにさらに新しいレイヤーをトレーニングします。

セットアップ

import collections
import io
import math
import os
import random
from six.moves import urllib

from IPython.display import clear_output, Image, display, HTML

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.metrics as sk_metrics
import time
2024-01-11 17:56:09.794501: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:56:09.794548: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:56:09.796039: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/compat/v2_compat.py:108: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term

Flowers データセット

flowers データセットは、可能性のある 5 つのクラスラベルを伴う花の画像で構成されています。

機械学習モデルをトレーニングする際、データを training と test データセットに分割します。training データを使ってモデルをトレーニングしてから、モデルが分析したことのないデータ(test セット)でのモデルのパフォーマンスを評価します。

training と test サンプルをダウンロードし(時間がかかる場合があります)、training と test のセットに分割しましょう。

次の 2 つのセルを実行します。

FLOWERS_DIR = './flower_photos'
TRAIN_FRACTION = 0.8
RANDOM_SEED = 2018


def download_images():
  """If the images aren't already downloaded, save them to FLOWERS_DIR."""
  if not os.path.exists(FLOWERS_DIR):
    DOWNLOAD_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
    print('Downloading flower images from %s...' % DOWNLOAD_URL)
    urllib.request.urlretrieve(DOWNLOAD_URL, 'flower_photos.tgz')
    !tar xfz flower_photos.tgz
  print('Flower photos are located in %s' % FLOWERS_DIR)


def make_train_and_test_sets():
  """Split the data into train and test sets and get the label classes."""
  train_examples, test_examples = [], []
  shuffler = random.Random(RANDOM_SEED)
  is_root = True
  for (dirname, subdirs, filenames) in tf.gfile.Walk(FLOWERS_DIR):
    # The root directory gives us the classes
    if is_root:
      subdirs = sorted(subdirs)
      classes = collections.OrderedDict(enumerate(subdirs))
      label_to_class = dict([(x, i) for i, x in enumerate(subdirs)])
      is_root = False
    # The sub directories give us the image files for training.
    else:
      filenames.sort()
      shuffler.shuffle(filenames)
      full_filenames = [os.path.join(dirname, f) for f in filenames]
      label = dirname.split('/')[-1]
      label_class = label_to_class[label]
      # An example is the image file and it's label class.
      examples = list(zip(full_filenames, [label_class] * len(filenames)))
      num_train = int(len(filenames) * TRAIN_FRACTION)
      train_examples.extend(examples[:num_train])
      test_examples.extend(examples[num_train:])

  shuffler.shuffle(train_examples)
  shuffler.shuffle(test_examples)
  return train_examples, test_examples, classes
# Download the images and split the images into train and test sets.
download_images()
TRAIN_EXAMPLES, TEST_EXAMPLES, CLASSES = make_train_and_test_sets()
NUM_CLASSES = len(CLASSES)

print('\nThe dataset has %d label classes: %s' % (NUM_CLASSES, CLASSES.values()))
print('There are %d training images' % len(TRAIN_EXAMPLES))
print('there are %d test images' % len(TEST_EXAMPLES))
Downloading flower images from http://download.tensorflow.org/example_images/flower_photos.tgz...
Flower photos are located in ./flower_photos

The dataset has 5 label classes: odict_values(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])
There are 2934 training images
there are 736 test images

データを確認する

flowers データセットは、ラベル付きの花の画像のサンプルで構成されています。各サンプルには、JPEG 形式の花の画像と、花の種類を示すクラスラベルが含まれます。いくつかの画像をラベルとともに表示してみましょう。

Show some labeled images

png

モデルを構築する

TF-Hub の画像特徴量ベクトルモジュールを読み込み、それに線形分類器をスタックし、トレーニングと検証の演算を追加することにします。次のセルは、モデルとそのトレーニングを説明する TF グラフを構築しますが、トレーニングを実行しません(この後のステップで行います)。

LEARNING_RATE = 0.01

tf.reset_default_graph()

# Load a pre-trained TF-Hub module for extracting features from images. We've
# chosen this particular module for speed, but many other choices are available.
image_module = hub.Module('https://tfhub.dev/google/imagenet/mobilenet_v2_035_128/feature_vector/2')

# Preprocessing images into tensors with size expected by the image module.
encoded_images = tf.placeholder(tf.string, shape=[None])
image_size = hub.get_expected_image_size(image_module)


def decode_and_resize_image(encoded):
  decoded = tf.image.decode_jpeg(encoded, channels=3)
  decoded = tf.image.convert_image_dtype(decoded, tf.float32)
  return tf.image.resize_images(decoded, image_size)


batch_images = tf.map_fn(decode_and_resize_image, encoded_images, dtype=tf.float32)

# The image module can be applied as a function to extract feature vectors for a
# batch of images.
features = image_module(batch_images)


def create_model(features):
  """Build a model for classification from extracted features."""
  # Currently, the model is just a single linear layer. You can try to add
  # another layer, but be careful... two linear layers (when activation=None)
  # are equivalent to a single linear layer. You can create a nonlinear layer
  # like this:
  # layer = tf.layers.dense(inputs=..., units=..., activation=tf.nn.relu)
  layer = tf.layers.dense(inputs=features, units=NUM_CLASSES, activation=None)
  return layer


# For each class (kind of flower), the model outputs some real number as a score
# how much the input resembles this class. This vector of numbers is often
# called the "logits".
logits = create_model(features)
labels = tf.placeholder(tf.float32, [None, NUM_CLASSES])

# Mathematically, a good way to measure how much the predicted probabilities
# diverge from the truth is the "cross-entropy" between the two probability
# distributions. For numerical stability, this is best done directly from the
# logits, not the probabilities extracted from them.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
cross_entropy_mean = tf.reduce_mean(cross_entropy)

# Let's add an optimizer so we can train the network.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
train_op = optimizer.minimize(loss=cross_entropy_mean)

# The "softmax" function transforms the logits vector into a vector of
# probabilities: non-negative numbers that sum up to one, and the i-th number
# says how likely the input comes from class i.
probabilities = tf.nn.softmax(logits)

# We choose the highest one as the predicted class.
prediction = tf.argmax(probabilities, 1)
correct_prediction = tf.equal(prediction, tf.argmax(labels, 1))

# The accuracy will allow us to eval on our test set. 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19573/2879154528.py:20: calling map_fn (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Use fn_output_signature instead
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19573/2879154528.py:20: calling map_fn (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Use fn_output_signature instead
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
/tmpfs/tmp/ipykernel_19573/2879154528.py:34: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  layer = tf.layers.dense(inputs=features, units=NUM_CLASSES, activation=None)

ネットワークをトレーニングする

ラベルの構築が完了したので、test セットを使ってトレーニングし、そのパフォーマンスを確認することにしましょう。

# How long will we train the network (number of batches).
NUM_TRAIN_STEPS = 100
# How many training examples we use in each step.
TRAIN_BATCH_SIZE = 10
# How often to evaluate the model performance.
EVAL_EVERY = 10

def get_batch(batch_size=None, test=False):
  """Get a random batch of examples."""
  examples = TEST_EXAMPLES if test else TRAIN_EXAMPLES
  batch_examples = random.sample(examples, batch_size) if batch_size else examples
  return batch_examples

def get_images_and_labels(batch_examples):
  images = [get_encoded_image(e) for e in batch_examples]
  one_hot_labels = [get_label_one_hot(e) for e in batch_examples]
  return images, one_hot_labels

def get_label_one_hot(example):
  """Get the one hot encoding vector for the example."""
  one_hot_vector = np.zeros(NUM_CLASSES)
  np.put(one_hot_vector, get_label(example), 1)
  return one_hot_vector

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(NUM_TRAIN_STEPS):
    # Get a random batch of training examples.
    train_batch = get_batch(batch_size=TRAIN_BATCH_SIZE)
    batch_images, batch_labels = get_images_and_labels(train_batch)
    # Run the train_op to train the model.
    train_loss, _, train_accuracy = sess.run(
        [cross_entropy_mean, train_op, accuracy],
        feed_dict={encoded_images: batch_images, labels: batch_labels})
    is_final_step = (i == (NUM_TRAIN_STEPS - 1))
    if i % EVAL_EVERY == 0 or is_final_step:
      # Get a batch of test examples.
      test_batch = get_batch(batch_size=None, test=True)
      batch_images, batch_labels = get_images_and_labels(test_batch)
      # Evaluate how well our model performs on the test set.
      test_loss, test_accuracy, test_prediction, correct_predicate = sess.run(
        [cross_entropy_mean, accuracy, prediction, correct_prediction],
        feed_dict={encoded_images: batch_images, labels: batch_labels})
      print('Test accuracy at step %s: %.2f%%' % (i, (test_accuracy * 100)))
Test accuracy at step 0: 24.46%
Test accuracy at step 10: 50.41%
Test accuracy at step 20: 61.96%
Test accuracy at step 30: 69.02%
Test accuracy at step 40: 69.84%
Test accuracy at step 50: 72.83%
Test accuracy at step 60: 75.54%
Test accuracy at step 70: 76.63%
Test accuracy at step 80: 76.63%
Test accuracy at step 90: 78.53%
Test accuracy at step 99: 78.80%
def show_confusion_matrix(test_labels, predictions):
  """Compute confusion matrix and normalize."""
  confusion = sk_metrics.confusion_matrix(
    np.argmax(test_labels, axis=1), predictions)
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)
  axis_labels = list(CLASSES.values())
  ax = sns.heatmap(
      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

show_confusion_matrix(batch_labels, test_prediction)

png

不正確な予測

モデルが間違えた test サンプルをもう少し詳しく見てみましょう。

  • test セットに誤ってラベル付けされたサンプルがありますか?
  • test セットに、実際には花の画像でないものなど、不適切なデータがありますか?
  • モデルがミスした原因がわかる画像はありますか?
incorrect = [
    (example, CLASSES[prediction])
    for example, prediction, is_correct in zip(test_batch, test_prediction, correct_predicate)
    if not is_correct
]
display_images(
  [(get_image(example), "prediction: {0}\nlabel:{1}".format(incorrect_prediction, get_class(example)))
   for (example, incorrect_prediction) in incorrect[:20]])

png

演習: モデルを改善しましょう!

ベースラインモデルはトレーニングしたので、正確性を高められるようにそれを改善してみましょう。(変更を適用するたびに、セルを実行し直す必要があります。)

演習 1: 異なる画像モデルを試す

TF-Hub では、異なる画像モデルを簡単に試すことができます。hub.Module() 内の "https://tfhub.dev/google/imagenet/mobilenet_v2_050_128/feature_vector/2" ハンドルを別のモジュールのハンドルに置き換えて、すべてのコードを実行し直すだけで可能です。利用可能なすべての画像モジュールは、tfhub.dev にあります。

ほかの MobileNet V2 モジュールを使用すると良いでしょう。MobileNet モジュールを含む多数のモジュールは、100 万個を超える画像と 1000 個のクラスを含む ImageNet データセットを使ってトレーニングされたものです。ネットワークアーキテクチャの選択によって、速度と分類精度にトレードオフが生じます。MobileNet または NASNet Mobile は高速で小型ですが、Inception や ResNet などのより従来的なアーキテクチャは、精度を重視して設計されています。

より大規模な Inception V3 アーキテクチャの場合は、目的のタスクにより近い領域で事前にトレーニングされているというメリットを得ることができます。植物や動物の iNaturalist データセットでトレーニングされたモジュールとして提供されています。

演習 2: 非表示レイヤーを追加する

抽出した画像特徴量と線形分類器の間に非表示レイヤーをスタックします(上記の create_model() 内)。100 個などのノードで非線形の非表示レイヤーを作成するには、tf.layers.dense を使用し、ユニットを 100、アクティベーションを tf.nn.relu に設定します。非表示レイヤーのサイズを変更すると、テストの精度に影響しますか?2 つ目の非表示レイヤーを追加すると、精度が改善されますか?

演習 3: ハイパーパラメータを変更する

トレーニングステップ数を変更すると、最終精度が改善されますか?学習速度を変更すると、モデルの収束を加速化できますか?トレーニングのバッチサイズによって、モデルのパフォーマンスが変化しますか?

演習 4: 異なるオプティマイザを試す

基本の GradientDescentOptimizer を、AdagradOptimizer などのより高度なオプティマイザに置き換えてみましょう。モデルのトレーニングに変化はありますか?さまざまな最適化アルゴリズムのメリットの詳細については、こちらの記事をご覧ください。

今後の学習

このチュートリアルのさらに高度なバージョンに興味がある方は、TensorBoard を使用したトレーニングの視覚化、画像に歪みを与えてデータセットを拡張する高度なテクニック、および flowers データセットを置き換えて独自のデータセットで画像分類器を学習する方法を説明した TensorFlow 画像の再トレーニングチュートリアルをご覧ください。

tensorflow.org/hub では、TensorFlow についてさらに学習し、TF-Hub API ドキュメントを確認することができます。また、tfhub.dev では、その他の画像特徴量ベクトルモジュールやテキスト埋め込みモジュールなど、利用可能な TensorFlow Hub モジュールを検索することができます。

さらに、Google の Machine Learning Crash Course もご覧ください。機械学習の実用的な導入をテンポよく学習できます。