Addestrare una rete neurale su MNIST con Keras

Questo semplice esempio mostra come collegare TensorFlow Datasets (TFDS) in un modello Keras.

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno
import tensorflow as tf
import tensorflow_datasets as tfds

Passaggio 1: crea la pipeline di input

Inizia creando una pipeline di input efficiente utilizzando i consigli di:

Carica un set di dati

Carica il set di dati MNIST con i seguenti argomenti:

  • shuffle_files=True : i dati MNIST sono archiviati solo in un singolo file, ma per set di dati più grandi con più file su disco, è buona norma mescolarli durante l'allenamento.
  • as_supervised=True : restituisce una tupla (img, label) invece di un dizionario {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Costruisci una pipeline di formazione

Applicare le seguenti trasformazioni:

  • tf.data.Dataset.map : TFDS fornisce immagini di tipo tf.uint8 , mentre il modello si aspetta tf.float32 . Pertanto, è necessario normalizzare le immagini.
  • tf.data.Dataset.cache Quando si inserisce il set di dati in memoria, memorizzarlo nella cache prima di mescolarlo per ottenere prestazioni migliori.
    Nota: le trasformazioni casuali devono essere applicate dopo la memorizzazione nella cache.
  • tf.data.Dataset.shuffle : per una vera casualità, imposta il buffer shuffle sulla dimensione completa del set di dati.
    Nota: per set di dati di grandi dimensioni che non possono stare in memoria, usa buffer_size=1000 se il tuo sistema lo consente.
  • tf.data.Dataset.batch : batch di elementi del set di dati dopo aver mescolato per ottenere batch univoci in ogni epoca.
  • tf.data.Dataset.prefetch : è buona norma terminare la pipeline precaricando le prestazioni .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Costruisci una pipeline di valutazione

La pipeline di test è simile alla pipeline di formazione con piccole differenze:

  • Non è necessario chiamare tf.data.Dataset.shuffle .
  • La memorizzazione nella cache viene eseguita dopo il batch perché i batch possono essere gli stessi tra epoche.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Passaggio 2: crea e addestra il modello

Collega la pipeline di input TFDS a un semplice modello Keras, compila il modello e addestralo.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>