Training a neural network on MNIST with Keras
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.
import tensorflow as tf
import tensorflow_datasets as tfds
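Step 1: Create your input pipeline
Start by building an efficient input pipeline, following the advice in these guides:
- The Performance tips guide (https://www.tensorflow.org/datasets/performances)
- The Better performance with the tf.data API guide (https://www.tensorflow.org/guide/data_performance#optimize_performance)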
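Load a dataset
Load the MNIST dataset with the following arguments:
- shuffle_files=True: The MNIST data is stored in a single file. For larger datasets stored as multiple files on disk, it is good practice to shuffle the files when training.
- as_supervised=True: Returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.
- with_info=True: Also returns a tfds.core.DatasetInfo object (ds_info). The training pipeline uses it to look up the number of training examples.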
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
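To see what as_supervised=True changes, the sketch below builds a tiny in-memory stand-in for the dictionary-style elements that TFDS yields by default. It then converts them to (image, label) tuples with Dataset.map. The random images and labels are hypothetical placeholders shaped like MNIST (28x28x1, uint8). The map call only illustrates the resulting element structure; it is not how tfds.load implements the option internally.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for MNIST-shaped data: 4 random 28x28x1 uint8 images.
images = np.random.randint(0, 256, size=(4, 28, 28, 1), dtype=np.uint8)
labels = np.array([3, 1, 4, 1], dtype=np.int64)

# Default TFDS-style elements: a dict with 'image' and 'label' keys.
ds_dict = tf.data.Dataset.from_tensor_slices({'image': images, 'label': labels})

# Same data restructured as (image, label) tuples, the structure that
# as_supervised=True gives you directly.
ds_tuple = ds_dict.map(lambda ex: (ex['image'], ex['label']))

print(ds_dict.element_spec)   # dict of TensorSpecs
print(ds_tuple.element_spec)  # tuple of TensorSpecs
```

The tuple form is what Model.fit expects for (features, labels) pairs. That is why the tutorial loads the data with as_supervised=True.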
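Build a training pipeline
Apply the following transformations:
- tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32, so the images are cast and scaled to the range [0, 1].
- tf.data.Dataset.cache: Because the dataset fits in memory, cache it before shuffling for better performance. Note: Random transformations should be applied after caching; otherwise the same random result is reused in every epoch.
- tf.data.Dataset.shuffle: For true randomness, set the shuffle buffer to the full dataset size. Note: For large datasets that can't fit in memory, use buffer_size=1000 if your system allows it.
- tf.data.Dataset.batch: Batch elements after shuffling, so that each epoch gets different batches.
- tf.data.Dataset.prefetch: It is good practice to end the pipeline with prefetching, for performance (https://www.tensorflow.org/guide/data_performance#prefetching).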
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32` in [0, 1]."""
    return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()  # cache the normalized data before shuffling
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)  # full-size buffer
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
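The map → cache → shuffle → batch → prefetch order can be checked on a small synthetic dataset. The sketch below repeats normalize_img so it runs on its own. It uses 10 hypothetical random images in place of MNIST and a batch size of 4 instead of 128. Each epoch should see every example exactly once, in an order that is reshuffled per epoch (shuffle reshuffles on each iteration by default).

```python
import numpy as np
import tensorflow as tf

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32` in [0, 1]."""
    return tf.cast(image, tf.float32) / 255., label

# Hypothetical small dataset: 10 random 28x28x1 uint8 images, labels 0..9.
num_examples = 10
images = np.random.randint(0, 256, size=(num_examples, 28, 28, 1), dtype=np.uint8)
labels = np.arange(num_examples, dtype=np.int64)

ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()                # deterministic work is done once and cached
ds = ds.shuffle(num_examples)  # full-size buffer, reshuffled every epoch
ds = ds.batch(4)               # batching after shuffling -> new batches per epoch
ds = ds.prefetch(tf.data.AUTOTUNE)

epoch_orders = []
for epoch in range(2):
    order = []
    for image_batch, label_batch in ds:
        order.extend(label_batch.numpy().tolist())
    epoch_orders.append(order)
    print(f"epoch {epoch}: label order {order}")
```

The printed orders will usually differ between the two epochs, but each contains all ten labels. The last batch holds the 2 leftover examples.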
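Build an evaluation pipeline
The testing pipeline is similar to the training pipeline, with small differences:
- You don't need to call tf.data.Dataset.shuffle.
- Caching is done after batching. Without shuffling, the batches are identical in every epoch, so the batched result can be cached as-is.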
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
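Because the evaluation pipeline never shuffles, caching after batch yields exactly the same batches on every pass. The sketch below checks this on 10 hypothetical random images with a batch size of 4, rather than on the MNIST test split.

```python
import numpy as np
import tensorflow as tf

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32` in [0, 1]."""
    return tf.cast(image, tf.float32) / 255., label

# Hypothetical evaluation data: 10 random images with labels 0..9.
images = np.random.randint(0, 256, size=(10, 28, 28, 1), dtype=np.uint8)
labels = np.arange(10, dtype=np.int64)

ds_eval = tf.data.Dataset.from_tensor_slices((images, labels))
ds_eval = ds_eval.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_eval = ds_eval.batch(4)
ds_eval = ds_eval.cache()  # caches whole batches; valid because order never changes
ds_eval = ds_eval.prefetch(tf.data.AUTOTUNE)

def collect_labels(ds):
    """Returns the labels of each batch, in iteration order."""
    return [label_batch.numpy().tolist() for _, label_batch in ds]

first_pass = collect_labels(ds_eval)   # fills the cache
second_pass = collect_labels(ds_eval)  # read back from the cache
print(first_pass)  # → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```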
Step 2: Create and train the model
Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),  # TFDS MNIST images are 28x28x1
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)  # one logit per digit class
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
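The final Dense(10) layer has no activation, so the model outputs unnormalized scores (logits). That is why the loss is built with from_logits=True. The sketch below rebuilds the same architecture with untrained weights and feeds it a hypothetical batch of random inputs, not MNIST. It checks that the loss computed from logits matches the loss computed from softmax probabilities, and shows one way to turn the outputs into class probabilities and predicted digits.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10),  # no activation: outputs logits
])

# Hypothetical batch of 8 normalized images and integer labels.
x = np.random.rand(8, 28, 28, 1).astype(np.float32)
y = np.random.randint(0, 10, size=(8,))

logits = model(x)              # shape (8, 10), unnormalized scores
probs = tf.nn.softmax(logits)  # each row now sums to 1

# Same loss, computed from logits vs. from probabilities.
loss_from_logits = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True)(y, logits)
loss_from_probs = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False)(y, probs)

predicted_classes = tf.argmax(probs, axis=-1)  # predicted digit per image
print(float(loss_from_logits), float(loss_from_probs))
```

Passing logits with from_logits=True lets Keras compute the loss directly from the logits, which is generally more numerically stable than applying softmax first. Softmax is then only needed when you want probabilities at prediction time.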
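Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-01-16 UTC.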