Training a neural network on MNIST with Keras
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.
import tensorflow as tf
import tensorflow_datasets as tfds
Step 1: Create your input pipeline
Start by building an efficient input pipeline, using the advice from the Performance tips guide (https://www.tensorflow.org/datasets/performances) and the Better performance with the tf.data API guide (https://www.tensorflow.org/guide/data_performance#optimize_performance).
Load a dataset
Load the MNIST dataset with the following arguments:
shuffle_files=True: the MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it is good practice to shuffle them when training.
as_supervised=True: returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
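As a quick sanity check (not part of the original notebook), you can inspect ds_info and a single element to confirm that as_supervised=True yields (image, label) tuples of uint8 images with shape (28, 28, 1):

# Sketch: inspect dataset metadata and one raw example before any transformation.
print(ds_info.features)              # image: (28, 28, 1) uint8, label: 10 classes
for image, label in ds_train.take(1):
  print(image.shape, image.dtype)    # (28, 28, 1) tf.uint8
  print(label.numpy())               # an integer class id in [0, 9]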
Build a training pipeline
Apply the following transformations:
tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32, so you need to normalize the images.
tf.data.Dataset.cache: as the dataset fits in memory, cache it before shuffling for better performance. Note: random transformations should be applied after caching.
tf.data.Dataset.shuffle: for true randomness, set the shuffle buffer to the full dataset size. Note: for large datasets that can't fit in memory, use buffer_size=1000 if your system allows it.
tf.data.Dataset.batch: batch elements of the dataset after shuffling to get unique batches at each epoch.
tf.data.Dataset.prefetch: it is good practice to end the pipeline by prefetching for performance.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
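To verify that the training pipeline produces normalized, batched tensors, you can check its element_spec; the optional tfds.benchmark call below is only a suggestion and iterates over the whole pipeline once:

# Sketch: confirm the batched pipeline output.
print(ds_train.element_spec)  # ((None, 28, 28, 1) float32, (None,) int64)
# Optional rough throughput measurement (consumes one full pass over the data):
# tfds.benchmark(ds_train, batch_size=128)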
Build an evaluation pipeline
Your testing pipeline is similar to the training pipeline, with small differences:
You don't need to call tf.data.Dataset.shuffle.
Caching is done after batching because batches can be the same between epochs.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
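If you prefer a single chained expression, the two pipelines above can also be written with a small helper. The make_pipeline name below is only illustrative (it is not part of TFDS or the original example); it is intended to match the behavior of the step-by-step version:

def make_pipeline(ds, training, num_examples=None):
  """Hypothetical helper: applies the same transformations as above, chained."""
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  if training:
    # Cache before shuffling so each epoch sees a fresh random order.
    ds = ds.cache().shuffle(num_examples).batch(128)
  else:
    # Cache after batching because evaluation batches can stay identical.
    ds = ds.batch(128).cache()
  return ds.prefetch(tf.data.AUTOTUNE)

# Usage (instead of the step-by-step pipelines above):
# ds_train = make_pipeline(ds_train, training=True,
#                          num_examples=ds_info.splits['train'].num_examples)
# ds_test = make_pipeline(ds_test, training=False)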
Step 2: Create and train the model
Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
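After training, you will typically evaluate on the held-out split and turn logits into class predictions. This sketch goes beyond the original example; since the Dense(10) head outputs logits, argmax (or softmax) is applied explicitly:

# Sketch: evaluate the trained model and predict classes for one test batch.
test_loss, test_acc = model.evaluate(ds_test)
print('test accuracy:', test_acc)

for images, labels in ds_test.take(1):
  logits = model(images, training=False)    # shape (batch, 10), unnormalized scores
  predictions = tf.argmax(logits, axis=-1)  # predicted class ids
  print(predictions[:10].numpy(), labels[:10].numpy())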