在 TensorFlow.org 查看 | 在 Google Colab 中运行 | 查看上GitHub | 下载笔记本 | 查看 TF Hub 模型 |
此笔记本演示如何使用 TensorFlow Hub 中的 CropNet 木薯病虫害分类器模型。该模型可将木薯叶的图像分为 6 类:细菌性枯萎病、褐条病毒病、绿螨、花叶病、健康或未知。
此 Colab 演示了如何执行以下操作:
- 从 TensorFlow Hub 加载 https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 模型
- 从 TensorFlow Datasets (TFDS) 加载木薯数据集
- 将木薯叶图像分为 4 种不同的木薯病虫害类别、健康或未知。
- 评估分类器的准确率,并查看将模型在应用于域外图像时的鲁棒性。
导入和设置
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
2022-12-14 21:47:23.846046: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 21:47:23.846144: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 21:47:23.846153: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Helper function for displaying examples
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
数据集
让我们从 TFDS 中加载木薯数据集
dataset, info = tfds.load('cassava', with_info=True)
我们来查看数据集信息以了解更多内容,例如描述和引用以及有关可用样本量的信息
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', file_format=tfrecord, download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
木薯数据集包含涉及 4 种不同病虫害的木薯叶图像以及健康的木薯叶图像。模型可以预测上述五个类,当模型不确定其预测结果时,会将图像分为第六个类,即“未知”类。
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
将数据馈送至模型之前,我们需要进行一些预处理。模型接受大小为 224 x 224,且 RGB 通道值范围为 [0, 1] 的图像。让我们归一化图像并调整图像大小。
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
我们看一下数据集中的一些样本
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
模型
让我们从 TF-Hub 中加载分类器并获取一些预测结果,然后查看模型对一些样本的预测
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11. WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
plot(examples, predictions)
评估和鲁棒性
我们来衡量分类器在拆分数据集上的准确率。我们还可以通过评估模型在非木薯数据集上的性能来评估其鲁棒性。对于 iNaturalist 或豆科植物等其他植物数据集中的图像,模型应几乎始终返回未知。
Parameters
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
了解更多
- 详细了解 TensorFlow Hub 上的模型:https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- 了解如何使用模型的 TensorFlow Lite 版本通过 ML Kit 构建在手机上运行的自定义图像分类器。