عرض على TensorFlow.org | تشغيل في Google Colab | عرض على جيثب | تحميل دفتر | انظر نموذج TF Hub |
يظهر دفتر هذا كيفية استخدام CropNet الكسافا المصنف المرض نموذج من TensorFlow المحور. ويصنف نموذج صور لأوراق الكسافا في واحدة من 6 فصول: اللفحة البكتيرية والأمراض خط البني، الأخضر سوس، وأمراض الفسيفساء وصحية، أو غير معروف.
يوضح هذا الكولاب كيفية:
- تحميل https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 نموذج من TensorFlow المحور
- تحميل الكسافا بيانات من TensorFlow مجموعات البيانات (TFDS)
- صنف صور أوراق الكسافا إلى 4 فئات متميزة من مرض الكسافا أو صور صحية أو غير معروفة.
- تقييم دقة المصنف ونظرة على مدى قوة هذا النموذج هو عندما يطبق على من الصور المجال.
الواردات والإعداد
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
وظيفة المساعد لعرض الأمثلة
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
مجموعة البيانات
تحميل دعونا مجموعة البيانات الكسافا من TFDS
dataset, info = tfds.load('cassava', with_info=True)
دعنا نلقي نظرة على معلومات مجموعة البيانات لمعرفة المزيد عنها ، مثل الوصف والاستشهاد والمعلومات حول عدد الأمثلة المتاحة
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
مجموعة البيانات الكسافا لديها صورا لأوراق الكسافا مع 4 أمراض متميزة وكذلك أوراق الكسافا صحية. يمكن للنموذج أن يتنبأ بكل هذه الفئات بالإضافة إلى الفئة السادسة لـ "غير معروف" عندما لا يكون النموذج واثقًا من تنبؤاته.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
قبل أن نتمكن من تغذية النموذج بالبيانات ، نحتاج إلى القيام ببعض المعالجة المسبقة. يتوقع النموذج صور 224 × 224 بقيم قناة RGB في [0 ، 1]. لنقم بتطبيع الصور وتغيير حجمها.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
دعنا نلقي نظرة على بعض الأمثلة من مجموعة البيانات
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
نموذج
دعنا نحمل المصنف من TF Hub ونحصل على بعض التنبؤات ونرى تنبؤات النموذج في بعض الأمثلة
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
التقييم والمتانة
دعونا قياس دقة المصنف لدينا على تقسيم مجموعة البيانات. يمكننا أيضا أن ننظر إلى متانة نموذج من خلال تقييم أدائها على مجموعة بيانات غير الكسافا. للحصول على صورة من مجموعات البيانات النباتية الأخرى مثل iNaturalist أو الفاصوليا، ينبغي للنموذج يعود دائما تقريبا غير معروف.
المعلمات
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
يتعلم أكثر
- تعرف على المزيد حول نموذج على TensorFlow المحور: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- تعلم كيفية بناء صورة مخصصة المصنف تشغيل على الهاتف المحمول مع ML كيت مع نسخة TensorFlow لايت من هذا النموذج .