CropNet: تشخیص بیماری کاساوا

این نوت بوک نشان می دهد که چگونه به استفاده از CropNet این گونه گیاهان به طبقه بندی بیماری مدل از TensorFlow توپی. طبقه بندی مدل تصاویر از برگ کاساوا را به یکی از 6 کلاس: سوختگی باکتریایی، بیماری رگه های قهوه ای، کرم های سبز، بیماری موزاییک، سالم، و یا ناشناخته است.

این مجموعه نشان می دهد که چگونه:

  • بارگذاری https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 مدل از TensorFlow توپی
  • بارگذاری این گونه گیاهان به مجموعه داده از TensorFlow مجموعه داده (TFDS)
  • تصاویر برگ های کاساوا را به 4 دسته بیماری کاساوا مجزا یا سالم یا ناشناخته طبقه بندی کنید.
  • بررسی دقت طبقه بندی و نگاه کنید که چگونه قوی مدل است که به خارج از تصاویر مالکیت اعمال می شود.

واردات و راه اندازی

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

تابع کمکی برای نمایش نمونه ها

def plot(examples, predictions=None):
 
# Get the images, labels, and optionally predictions
  images
= examples['image']
  labels
= examples['label']
  batch_size
= len(images)
 
if predictions is None:
    predictions
= batch_size * [None]

 
# Configure the layout of the grid
  x
= np.ceil(np.sqrt(batch_size))
  y
= np.ceil(batch_size / x)
  fig
= plt.figure(figsize=(x * 6, y * 7))

 
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
   
# Render the image
    ax
= fig.add_subplot(x, y, i+1)
    ax
.imshow(image, aspect='auto')
    ax
.grid(False)
    ax
.set_xticks([])
    ax
.set_yticks([])

   
# Display the label and optionally prediction
    x_label
= 'Label: ' + name_map[class_names[label]]
   
if prediction is not None:
      x_label
= 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax
.xaxis.label.set_color('green' if label == prediction else 'red')
    ax
.set_xlabel(x_label)

  plt
.show()

مجموعه داده

بار بیایید مجموعه داده مانیوک از TFDS

dataset, info = tfds.load('cassava', with_info=True)

بیایید نگاهی به اطلاعات مجموعه داده بیندازیم تا درباره آن اطلاعات بیشتری کسب کنیم، مانند توضیحات و نقل قول و اطلاعات در مورد تعداد نمونه های موجود

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

مجموعه داده کاساوا دارای تصاویر از برگ کاساوا با 4 بیماری مجزا نیز برگ کاساوا سالم است. مدل می تواند تمام این کلاس ها و همچنین کلاس ششم را برای "ناشناخته" پیش بینی کند، زمانی که مدل به پیش بینی خود اطمینان ندارد.

# Extend the cassava dataset classes with 'unknown'
class_names
= info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map
= dict(
    cmd
='Mosaic Disease',
    cbb
='Bacterial Blight',
    cgm
='Green Mite',
    cbsd
='Brown Streak Disease',
    healthy
='Healthy',
    unknown
='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

قبل از اینکه بتوانیم داده ها را به مدل تغذیه کنیم، باید کمی پیش پردازش انجام دهیم. مدل انتظار دارد تصاویر 224 x 224 با مقادیر کانال RGB در [0، 1] باشد. بیایید تصاویر را عادی و اندازه آنها را تغییر دهیم.

def preprocess_fn(data):
  image
= data['image']

 
# Normalize [0, 255] to [0, 1]
  image
= tf.cast(image, tf.float32)
  image
= image / 255.

 
# Resize the images to 224 x 224
  image
= tf.image.resize(image, (224, 224))

  data
['image'] = image
 
return data

بیایید به چند نمونه از مجموعه داده نگاهی بیندازیم

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples
= next(batch)
plot
(examples)

png

مدل

بیایید طبقه بندی کننده را از TF Hub بارگذاری کنیم و چند پیش بینی دریافت کنیم و ببینیم که پیش بینی های مدل در چند نمونه است.

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities
= classifier(examples['image'])
predictions
= tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

ارزیابی و استحکام

بیایید اندازه گیری دقت طبقه بندی ما در یک تقسیم از مجموعه داده. ما همچنین می توانید در نیرومندی از مدل های ارزیابی عملکرد آن بر یک مجموعه داده غیر این گونه گیاهان به نگاه. برای تصویر دیگر مجموعه داده گیاهی مانند iNaturalist یا لوبیا، مدل تقریبا همیشه باید ناشناخته بازگشت.

مولفه های

DATASET = 'cassava' 
DATASET_SPLIT
= 'test'
BATCH_SIZE
=  32
MAX_EXAMPLES
= 1000

def label_to_unknown_fn(data):
  data
['label'] = 5  # Override label to unknown.
 
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds
= tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description
= DATASET
if DATASET != 'cassava':
  ds
= ds.map(label_to_unknown_fn)
  dataset_description
+= ' (labels mapped to unknown)'
ds
= ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric
= tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities
= classifier(examples['image'])
  predictions
= tf.math.argmax(probabilities, axis=-1)
  labels
= examples['label']
  metric
.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

بیشتر بدانید