TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード | TF Hub モデルを参照 |
このノートブックは、TFDS のデータセットまたは独自の作物病害検出データセットで TensorFlow Hub の CropNet モデルを微調整する方法を説明しています。
ここでは次を行います。
- TFDS キャッサバデータセットまたは独自データをロードする
- 未知の(負の)例でデータを強化して、より堅牢なモデルを取得する
- データに画像拡張を適用する
- TF Hub から CropNet モデルを読み込んで微調整する
- TFLite モデルをエクスポートし、タスクライブラリ、MLKit、または TFLite を使用してアプリに直接デプロイできるようにする
インポートと依存関係
開始する前に、Model Maker や最新バージョンの TensorFlow データセットなどの必要な依存関係のいくつかをインストールする必要があります。
sudo apt install -q libportaudio2
## image_classifier library requires numpy <= 1.23.5
pip install "numpy<=1.23.5"
pip install --use-deprecated=legacy-resolver tflite-model-maker-nightly
pip install -U tensorflow-datasets
## scann library requires tensorflow < 2.9.0
pip install "tensorflow<2.9.0"
pip install "tensorflow-datasets~=4.8.0" # protobuf>=3.12.2
pip install tensorflow-metadata~=1.10.0 # protobuf>=3.13
## tensorflowjs requires packaging < 20.10
pip install "packaging<20.10"
import matplotlib.pyplot as plt
import os
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
from tensorflow_examples.lite.model_maker.core.task import image_preprocessing
from tflite_model_maker import image_classifier
from tflite_model_maker import ImageClassifierDataLoader
from tflite_model_maker.image_classifier import ModelSpec
2024-01-11 20:32:29.477272: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807 warnings.warn( /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_addons/utils/ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.13.0 and strictly below 2.16.0 (nightly versions are not supported). The versions of TensorFlow you are currently using is 2.8.4 and is not supported. Some things might work, some things might not. If you were to encounter a bug, do not file an issue. If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. You can find the compatibility matrix in TensorFlow Addon's readme: https://github.com/tensorflow/addons warnings.warn(
TFDS データセットをロードして微調整する
TFDS から公開されているキャッサバの葉の病害のデータセットを使用してみましょう。
tfds_name = 'cassava'
(ds_train, ds_validation, ds_test), ds_info = tfds.load(
name=tfds_name,
split=['train', 'validation', 'test'],
with_info=True,
as_supervised=True)
TFLITE_NAME_PREFIX = tfds_name
2024-01-11 20:32:33.037914: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory 2024-01-11 20:32:33.038041: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory 2024-01-11 20:32:33.038121: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory 2024-01-11 20:32:33.038198: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory 2024-01-11 20:32:33.094283: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory 2024-01-11 20:32:33.094511: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices...
または、独自データを読み込んで微調整する
TFDS データセットを使用する代わりに、独自のデータでトレーニングすることもできます。このコードスニペットは、独自のカスタムデータセットをロードする方法を示しています。サポートされているデータの構造については、このリンクをご覧ください。ここでは、公開されているキャッサバの葉の病害のデータセットを使用した例を示します。
# data_root_dir = tf.keras.utils.get_file(
# 'cassavaleafdata.zip',
# 'https://storage.googleapis.com/emcassavadata/cassavaleafdata.zip',
# extract=True)
# data_root_dir = os.path.splitext(data_root_dir)[0] # Remove the .zip extension
# builder = tfds.ImageFolder(data_root_dir)
# ds_info = builder.info
# ds_train = builder.as_dataset(split='train', as_supervised=True)
# ds_validation = builder.as_dataset(split='validation', as_supervised=True)
# ds_test = builder.as_dataset(split='test', as_supervised=True)
train split からのサンプルを視覚化する
画像サンプルとそのラベルのクラス ID とクラス名を含むデータセットのいくつかの例を見てみましょう。
_ = tfds.show_examples(ds_train, ds_info)
TFDS データセットから未知の例として使用する画像を追加する
未知の(負の)例をトレーニングデータセットに追加し、それらに新しい未知のクラスラベル番号を割り当てます。目標は、実際に(たとえばフィールドで)使用される際に、予期しないものが見つかった場合に「未知」を予測するオプションを持つモデルを作成することです。
以下に、追加の未知の画像をサンプリングするために使用されるデータセットのリストを示します。多様性を高めるために、3 つの完全に異なるデータセットが含まれています。それらの 1 つは豆の葉の病害のデータセットであるため、モデルはキャッサバ以外の罹病植物にさらされています。
UNKNOWN_TFDS_DATASETS = [{
'tfds_name': 'imagenet_v2/matched-frequency',
'train_split': 'test[:80%]',
'test_split': 'test[80%:]',
'num_examples_ratio_to_normal': 1.0,
}, {
'tfds_name': 'oxford_flowers102',
'train_split': 'train',
'test_split': 'test',
'num_examples_ratio_to_normal': 1.0,
}, {
'tfds_name': 'beans',
'train_split': 'train',
'test_split': 'test',
'num_examples_ratio_to_normal': 1.0,
}]
UNKNOWN データセットも TFDS からロードされます。
# Load unknown datasets.
weights = [
spec['num_examples_ratio_to_normal'] for spec in UNKNOWN_TFDS_DATASETS
]
num_unknown_train_examples = sum(
int(w * ds_train.cardinality().numpy()) for w in weights)
ds_unknown_train = tf.data.Dataset.sample_from_datasets([
tfds.load(
name=spec['tfds_name'], split=spec['train_split'],
as_supervised=True).repeat(-1) for spec in UNKNOWN_TFDS_DATASETS
], weights).take(num_unknown_train_examples)
ds_unknown_train = ds_unknown_train.apply(
tf.data.experimental.assert_cardinality(num_unknown_train_examples))
ds_unknown_tests = [
tfds.load(
name=spec['tfds_name'], split=spec['test_split'], as_supervised=True)
for spec in UNKNOWN_TFDS_DATASETS
]
ds_unknown_test = ds_unknown_tests[0]
for ds in ds_unknown_tests[1:]:
ds_unknown_test = ds_unknown_test.concatenate(ds)
# All examples from the unknown datasets will get a new class label number.
num_normal_classes = len(ds_info.features['label'].names)
unknown_label_value = tf.convert_to_tensor(num_normal_classes, tf.int64)
ds_unknown_train = ds_unknown_train.map(lambda image, _:
(image, unknown_label_value))
ds_unknown_test = ds_unknown_test.map(lambda image, _:
(image, unknown_label_value))
# Merge the normal train dataset with the unknown train dataset.
weights = [
ds_train.cardinality().numpy(),
ds_unknown_train.cardinality().numpy()
]
ds_train_with_unknown = tf.data.Dataset.sample_from_datasets(
[ds_train, ds_unknown_train], [float(w) for w in weights])
ds_train_with_unknown = ds_train_with_unknown.apply(
tf.data.experimental.assert_cardinality(sum(weights)))
print((f"Added {ds_unknown_train.cardinality().numpy()} negative examples."
f"Training dataset has now {ds_train_with_unknown.cardinality().numpy()}"
' examples in total.'))
Added 16968 negative examples.Training dataset has now 22624 examples in total.
拡張を適用する
すべての画像に対して、それらをより多様化するために、次の点での変更など、いくつかの拡張を適用します。
- 明るさ
- コントラスト
- 彩度
- 色合い
- クロップ
これらのタイプの拡張は、モデルを画像入力の変動に対してより堅牢にするのに役立ちます。
def random_crop_and_random_augmentations_fn(image):
# preprocess_for_train does random crop and resize internally.
image = image_preprocessing.preprocess_for_train(image)
image = tf.image.random_brightness(image, 0.2)
image = tf.image.random_contrast(image, 0.5, 2.0)
image = tf.image.random_saturation(image, 0.75, 1.25)
image = tf.image.random_hue(image, 0.1)
return image
def random_crop_fn(image):
# preprocess_for_train does random crop and resize internally.
image = image_preprocessing.preprocess_for_train(image)
return image
def resize_and_center_crop_fn(image):
image = tf.image.resize(image, (256, 256))
image = image[16:240, 16:240]
return image
no_augment_fn = lambda image: image
train_augment_fn = lambda image, label: (
random_crop_and_random_augmentations_fn(image), label)
eval_augment_fn = lambda image, label: (resize_and_center_crop_fn(image), label)
拡張を適用するには、Dataset クラスの map
メソッドを使用します。
ds_train_with_unknown = ds_train_with_unknown.map(train_augment_fn)
ds_validation = ds_validation.map(eval_augment_fn)
ds_test = ds_test.map(eval_augment_fn)
ds_unknown_test = ds_unknown_test.map(eval_augment_fn)
INFO:tensorflow:Use default resize_bicubic. INFO:tensorflow:Use default resize_bicubic. INFO:tensorflow:Use customized resize method bilinear INFO:tensorflow:Use customized resize method bilinear
データを Model Maker に適した形式にラップする
これらのデータセットを Model Maker で使用するには、ImageClassifierDataLoader クラスに含まれている必要があります。
label_names = ds_info.features['label'].names + ['UNKNOWN']
train_data = ImageClassifierDataLoader(ds_train_with_unknown,
ds_train_with_unknown.cardinality(),
label_names)
validation_data = ImageClassifierDataLoader(ds_validation,
ds_validation.cardinality(),
label_names)
test_data = ImageClassifierDataLoader(ds_test, ds_test.cardinality(),
label_names)
unknown_test_data = ImageClassifierDataLoader(ds_unknown_test,
ds_unknown_test.cardinality(),
label_names)
トレーニングを実行する
TensorFlow Hub には、転移学習に利用できる複数のモデルがあります。
ここでは 1 つを選択でき、引き続き他のものを試して、より良い結果を得るようにすることもできます。
さらに多くのモデルを試してみたい場合は、このコレクションからモデルを追加できます。
Choose a base model
model_name = 'mobilenet_v3_large_100_224'
map_model_name = {
'cropnet_cassava':
'https://tfhub.dev/google/cropnet/feature_vector/cassava_disease_V1/1',
'cropnet_concat':
'https://tfhub.dev/google/cropnet/feature_vector/concat/1',
'cropnet_imagenet':
'https://tfhub.dev/google/cropnet/feature_vector/imagenet/1',
'mobilenet_v3_large_100_224':
'https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5',
}
model_handle = map_model_name[model_name]
モデルを微調整するには、Model Maker を使用します。これにより、モデルのトレーニング後にモデルが TFLite に変換されるため、ソリューション全体が簡単になります。
Model Maker は、この変換を可能な限り最良のものにし、後でモデルをデバイスに簡単に展開するために必要なすべての情報を提供します。
モデル仕様は、使用する基本モデルを Model Maker に指示する方法です。
image_model_spec = ModelSpec(uri=model_handle)
ここでの重要な詳細の 1 つは、 train_whole_model
を設定することで、トレーニング中にベースモデルが微調整されます。これによりプロセスは遅くなりますが、最終的なモデルの精度は高くなります。 shuffle
を設定すると、モデルがランダムにシャッフルされた順序でデータを確認できるようになります。これは、モデル学習のベストプラクティスです。
model = image_classifier.create(
train_data,
model_spec=image_model_spec,
batch_size=128,
learning_rate=0.03,
epochs=5,
shuffle=True,
train_whole_model=True,
validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2 (HubKe (None, 1280) 4226432 rasLayerV1V2) dropout (Dropout) (None, 1280) 0 dense (Dense) (None, 6) 7686 ================================================================= Total params: 4,234,118 Trainable params: 4,209,718 Non-trainable params: 24,400 _________________________________________________________________ None Epoch 1/5 176/176 [==============================] - 531s 3s/step - loss: 0.8837 - accuracy: 0.9174 - val_loss: 1.1820 - val_accuracy: 0.7787 Epoch 2/5 176/176 [==============================] - 512s 3s/step - loss: 0.7951 - accuracy: 0.9527 - val_loss: 1.0481 - val_accuracy: 0.8332 Epoch 3/5 176/176 [==============================] - 511s 3s/step - loss: 0.7746 - accuracy: 0.9581 - val_loss: 1.0474 - val_accuracy: 0.8311 Epoch 4/5 176/176 [==============================] - 510s 3s/step - loss: 0.7616 - accuracy: 0.9641 - val_loss: 1.0096 - val_accuracy: 0.8460 Epoch 5/5 176/176 [==============================] - 509s 3s/step - loss: 0.7565 - accuracy: 0.9647 - val_loss: 0.9942 - val_accuracy: 0.8555
test split でモデルを評価する
model.evaluate(test_data)
59/59 [==============================] - 7s 109ms/step - loss: 0.9798 - accuracy: 0.8695 [0.9798129796981812, 0.8694960474967957]
微調整されたモデルをさらによく理解するには、混同行列を分析することをお勧めします。これは、あるクラスが別のクラスとして予測される頻度を示します。
def predict_class_label_number(dataset):
"""Runs inference and returns predictions as class label numbers."""
rev_label_names = {l: i for i, l in enumerate(label_names)}
return [
rev_label_names[o[0][0]]
for o in model.predict_top_k(dataset, batch_size=128)
]
def show_confusion_matrix(cm, labels):
plt.figure(figsize=(10, 8))
sns.heatmap(cm, xticklabels=labels, yticklabels=labels,
annot=True, fmt='g')
plt.xlabel('Prediction')
plt.ylabel('Label')
plt.show()
confusion_mtx = tf.math.confusion_matrix(
list(ds_test.map(lambda x, y: y)),
predict_class_label_number(test_data),
num_classes=len(label_names))
show_confusion_matrix(confusion_mtx, label_names)
未知のテストデータでモデルを評価する
この評価では、モデルの精度はほぼ 1 であると予想されます。モデルがテストされるすべての画像は通常のデータセットに関連していないため、モデルは「未知の」クラスラベルを予測すると予想されます。
model.evaluate(unknown_test_data)
259/259 [==============================] - 31s 113ms/step - loss: 0.6768 - accuracy: 0.9999 [0.6767985224723816, 0.9998791813850403]
混同行列を印刷します。
unknown_confusion_mtx = tf.math.confusion_matrix(
list(ds_unknown_test.map(lambda x, y: y)),
predict_class_label_number(unknown_test_data),
num_classes=len(label_names))
show_confusion_matrix(unknown_confusion_mtx, label_names)
モデルを TFLite および SavedModel としてエクスポートする
これで、トレーニング済みモデルを TFLite および SavedModel 形式でエクスポートして、デバイスにデプロイし、TensorFlow で推論に使用できるようになりました。
tflite_filename = f'{TFLITE_NAME_PREFIX}_model_{model_name}.tflite'
model.export(export_dir='.', tflite_filename=tflite_filename)
2024-01-11 21:17:15.976215: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxxl8kj6b/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxxl8kj6b/assets 2024-01-11 21:17:24.591512: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices... /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:746: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. warnings.warn("Statistics for quantized inputs were expected, but not " 2024-01-11 21:17:26.069851: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format. 2024-01-11 21:17:26.069899: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency. INFO:tensorflow:Label file is inside the TFLite model with metadata. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpe4byyqnx/labels.txt INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpe4byyqnx/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./cassava_model_mobilenet_v3_large_100_224.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./cassava_model_mobilenet_v3_large_100_224.tflite
# Export saved model version.
model.export(export_dir='.', export_format=ExportFormat.SAVED_MODEL)
INFO:tensorflow:Assets written to: ./saved_model/assets INFO:tensorflow:Assets written to: ./saved_model/assets
次のステップ
トレーニングしたばかりのモデルは、モバイルデバイスで使用でき、フィールドに展開することもできます。
モデルをダウンロードするには、colab の左側にある [ファイル] メニューのフォルダーアイコンをクリックして、ダウンロードオプションを選択します。
ここで使用されているものと同じ手法は、ユースケースや他のタイプの画像分類タスクにより適している可能性がある他の植物病害タスクに適用できます。フォローアップして Android アプリにデプロイする場合は、この Android クイックスタートガイド を続行できます。