CropNet: rilevamento della malattia della manioca

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino Vedi il modello del mozzo TF

Questa mostra notebook come usare il CropNet malattia classificatore manioca modello da tensorflow Hub. Le modello classifica immagini di foglie di manioca in una delle 6 classi: ruggine batterica, malattia striscia marrone, acari verde, la malattia del mosaico, in buona salute, o sconosciute.

Questa collaborazione mostra come:

  • Caricare il https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modello da tensorflow Hub
  • Caricare la manioca set di dati da tensorflow Datasets (TFDS)
  • Classifica le immagini delle foglie di manioca in 4 distinte categorie di malattie della manioca o come sane o sconosciute.
  • Valutare l'accuratezza del classificatore e sguardo a come robusto il modello è quando applicata a di immagini di dominio.

Importazioni e configurazione

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Funzione di supporto per la visualizzazione di esempi

set di dati

Carico di lasciare che il set di dati da manioca TFDS

dataset, info = tfds.load('cassava', with_info=True)

Diamo un'occhiata alle informazioni sul set di dati per saperne di più, come la descrizione e la citazione e le informazioni su quanti esempi sono disponibili

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

Il set di dati manioca ha immagini di foglie di manioca con 4 malattie distinte e foglie di manioca sani. Il modello può prevedere tutte queste classi e la sesta classe per "sconosciuto" quando il modello non è sicuro della sua previsione.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Prima di poter fornire i dati al modello, è necessario eseguire un po' di pre-elaborazione. Il modello prevede immagini 224 x 224 con valori del canale RGB in [0, 1]. Normalizziamo e ridimensioniamo le immagini.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Diamo un'occhiata ad alcuni esempi dal set di dati

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Modello

Carichiamo il classificatore da TF Hub e otteniamo alcune previsioni e vediamo le previsioni del modello su alcuni esempi

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Valutazione e robustezza

Diamo misurare l'accuratezza del nostro classificatore su una scissione del set di dati. Possiamo anche guardare la robustezza del modello valutando le sue prestazioni su un set di dati non manioca. Per l'immagine di altre serie di dati vegetali come iNaturalist o fagioli, il modello dovrebbe quasi sempre tornare sconosciuta.

Parametri

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Scopri di più