CropNet: Detecção da doença da mandioca

Este notebook mostra como usar o CropNet doença classificador mandioca modelo de TensorFlow Hub. Os classifica modelo imagens de folhas de mandioca em uma das 6 classes: ferrugem bacteriana, doença podridão radicular, ácaro verde, doença do mosaico, saudável, ou desconhecida.

Esta colab demonstra como:

  • Carregar o https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modelo de TensorFlow Hub
  • Carregar a mandioca conjunto de dados do TensorFlow conjuntos de dados (TFDS)
  • Classifique as imagens de folhas de mandioca em 4 categorias distintas de doenças da mandioca ou como saudáveis ​​ou desconhecidas.
  • Avaliar a precisão do classificador e ver como robusto o modelo é quando aplicado a partir de imagens de domínio.

Importações e configuração

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Função auxiliar para exibir exemplos

def plot(examples, predictions=None):
 
# Get the images, labels, and optionally predictions
  images
= examples['image']
  labels
= examples['label']
  batch_size
= len(images)
 
if predictions is None:
    predictions
= batch_size * [None]

 
# Configure the layout of the grid
  x
= np.ceil(np.sqrt(batch_size))
  y
= np.ceil(batch_size / x)
  fig
= plt.figure(figsize=(x * 6, y * 7))

 
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
   
# Render the image
    ax
= fig.add_subplot(x, y, i+1)
    ax
.imshow(image, aspect='auto')
    ax
.grid(False)
    ax
.set_xticks([])
    ax
.set_yticks([])

   
# Display the label and optionally prediction
    x_label
= 'Label: ' + name_map[class_names[label]]
   
if prediction is not None:
      x_label
= 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax
.xaxis.label.set_color('green' if label == prediction else 'red')
    ax
.set_xlabel(x_label)

  plt
.show()

Conjunto de dados

Vamos carga do conjunto de dados de mandioca a partir TFDS

dataset, info = tfds.load('cassava', with_info=True)

Vamos dar uma olhada nas informações do conjunto de dados para aprender mais sobre ele, como a descrição e citação e informações sobre quantos exemplos estão disponíveis

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

O conjunto de dados de mandioca tem imagens de folhas de mandioca com 4 doenças distintas, bem como folhas de mandioca saudáveis. O modelo pode prever todas essas classes, bem como a sexta classe para "desconhecido", quando o modelo não está confiante em sua previsão.

# Extend the cassava dataset classes with 'unknown'
class_names
= info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map
= dict(
    cmd
='Mosaic Disease',
    cbb
='Bacterial Blight',
    cgm
='Green Mite',
    cbsd
='Brown Streak Disease',
    healthy
='Healthy',
    unknown
='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Antes de alimentarmos o modelo com os dados, precisamos fazer um pouco de pré-processamento. O modelo espera 224 x 224 imagens com valores de canal RGB em [0, 1]. Vamos normalizar e redimensionar as imagens.

def preprocess_fn(data):
  image
= data['image']

 
# Normalize [0, 255] to [0, 1]
  image
= tf.cast(image, tf.float32)
  image
= image / 255.

 
# Resize the images to 224 x 224
  image
= tf.image.resize(image, (224, 224))

  data
['image'] = image
 
return data

Vamos dar uma olhada em alguns exemplos do conjunto de dados

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples
= next(batch)
plot
(examples)

png

Modelo

Vamos carregar o classificador do TF Hub e obter algumas previsões e ver as previsões do modelo em alguns exemplos

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities
= classifier(examples['image'])
predictions
= tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Avaliação e robustez

Vamos medir a precisão do nosso classificador em uma divisão do conjunto de dados. Podemos também olhar para a robustez do modelo de avaliação do seu desempenho em um conjunto de dados não-mandioca. Para a imagem de outros conjuntos de dados de plantas como iNaturalist ou feijão, o modelo deve quase sempre retornar desconhecido.

Parâmetros

DATASET = 'cassava' 
DATASET_SPLIT
= 'test'
BATCH_SIZE
=  32
MAX_EXAMPLES
= 1000

def label_to_unknown_fn(data):
  data
['label'] = 5  # Override label to unknown.
 
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds
= tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description
= DATASET
if DATASET != 'cassava':
  ds
= ds.map(label_to_unknown_fn)
  dataset_description
+= ' (labels mapped to unknown)'
ds
= ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric
= tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities
= classifier(examples['image'])
  predictions
= tf.math.argmax(probabilities, axis=-1)
  labels
= examples['label']
  metric
.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Saber mais