Overview
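The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. It lets you run distributed training with your existing models and training code, with only minimal changes.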
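This tutorial demonstrates how to use tf.distribute.MirroredStrategy to perform in-graph replication with synchronous training on multiple GPUs on a single machine. The strategy essentially copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.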
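The following minimal pure-Python sketch makes the all-reduce step concrete by running one synchronous data-parallel loop. It is not how MirroredStrategy is implemented. It assumes a hypothetical one-weight linear model with a mean-squared-error loss, and the helpers grad, all_reduce_mean, and sync_step are illustrative names. Each replica holds its own copy of the weight and computes a gradient on its shard of the global batch. All replicas then apply the same averaged gradient, so the copies stay identical.

```python
# Pure-Python sketch of synchronous data-parallel training with all-reduce.
# Hypothetical toy model: a single weight w, predictions w * x, MSE loss.

def grad(w, xs, ys):
  """Gradient of mean((w * x - y)^2) with respect to w."""
  return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
  """Stand-in for an all-reduce: every replica receives the same average."""
  avg = sum(values) / len(values)
  return [avg] * len(values)


def sync_step(replica_weights, shards, lr):
  # Each replica computes a gradient on its own shard of the global batch.
  local_grads = [grad(w, xs, ys) for w, (xs, ys) in zip(replica_weights, shards)]
  # Combine the gradients so that every replica sees the same value.
  reduced = all_reduce_mean(local_grads)
  # Each replica applies the identical update to its mirrored copy.
  return [w - lr * g for w, g in zip(replica_weights, reduced)]


# A global batch of 8 examples from y = 3x, split evenly across 4 replicas.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x for x in xs]
num_replicas = 4
per_replica = len(xs) // num_replicas
shards = [
    (xs[i * per_replica:(i + 1) * per_replica],
     ys[i * per_replica:(i + 1) * per_replica])
    for i in range(num_replicas)
]

weights = [0.0] * num_replicas  # Mirrored variable: same initial value everywhere.
for _ in range(50):
  weights = sync_step(weights, shards, lr=0.01)

print(weights)
```

With equal-sized shards, averaging the per-replica gradients gives the same result as the gradient of the mean loss over the whole global batch. That equivalence is why the replicas behave like a single model.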
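You will use the tf.keras APIs to build the model and Model.fit to train it. (To learn about distributed training with a custom training loop and MirroredStrategy, check out the Custom training with tf.distribute.Strategy tutorial.)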
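MirroredStrategy trains your model on multiple GPUs on a single machine. For synchronous training on many GPUs across multiple workers, use tf.distribute.MultiWorkerMirroredStrategy with Keras Model.fit or with a custom training loop. For other options and strategies, refer to the Distributed training with TensorFlow guide.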
Setup
import tensorflow_datasets as tfds
import tensorflow as tf
import os
# Load the TensorBoard notebook extension.
%load_ext tensorboard
print(tf.__version__)
2.15.0-rc1
Download the dataset
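Load the MNIST dataset from TensorFlow Datasets. This returns a dataset in the tf.data format.
Setting the with_info argument to True includes the metadata for the entire dataset, which is saved here to info. Among other things, this metadata object includes the number of training and test examples.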
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Define the distribution strategy
Create a MirroredStrategy object. This will handle distribution and provide a context manager (MirroredStrategy.scope) to build your model inside.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
Set up the input pipeline
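When you train a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory and tune the learning rate accordingly.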
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
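Here is a quick check of the batch arithmetic. It assumes the 4 replicas reported above and the standard MNIST split sizes of 60,000 training and 10,000 test examples. The step counts it computes match the 235/235 and 40/40 progress bars in the training and evaluation output. The scaled learning rate illustrates one common heuristic, scaling the learning rate linearly with the number of replicas. That part is an assumption for illustration only; this tutorial does not apply it.

```python
import math

# Values from this tutorial's run (4 GPUs); adjust them for your machine.
num_replicas = 4
batch_size_per_replica = 64
num_train_examples = 60000  # MNIST training split.
num_test_examples = 10000   # MNIST test split.

global_batch_size = batch_size_per_replica * num_replicas
train_steps = math.ceil(num_train_examples / global_batch_size)
eval_steps = math.ceil(num_test_examples / global_batch_size)
# The final training batch is partial because 60000 is not a multiple of 256.
last_train_batch = num_train_examples - (train_steps - 1) * global_batch_size

# Hypothetical: linear learning-rate scaling heuristic (not used in this tutorial).
base_learning_rate = 1e-3  # Hypothetical single-replica learning rate.
scaled_learning_rate = base_learning_rate * num_replicas

print(global_batch_size, train_steps, eval_steps, last_train_batch, scaled_learning_rate)
```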
Define a function that normalizes the image pixel values from the [0, 255] range to the [0, 1] range (feature scaling):
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label
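Apply this scale function to the training and test data. Then use the tf.data.Dataset APIs to shuffle the training data (Dataset.shuffle) and batch it (Dataset.batch). Notice that you also keep an in-memory cache of the training data to improve performance (Dataset.cache).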
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
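The order of the chained calls matters. Caching after map(scale) means the scaling runs only once. Shuffling after the cache means each epoch gets a fresh order instead of replaying one cached order. The pure-Python sketch below mimics map, cache, shuffle, and batch on a toy list. The functions buffered_shuffle and batch are hypothetical stand-ins that illustrate the idea of a fixed-size shuffle buffer and fixed-size batches; they are not tf.data's implementation.

```python
import random


def scale(pixel_values):
  """Map integer pixel values in [0, 255] to floats in [0, 1]."""
  return [p / 255 for p in pixel_values]


def buffered_shuffle(elements, buffer_size, rng):
  """Shuffle using a fixed-size buffer, similar in spirit to Dataset.shuffle."""
  buffer = []
  for element in elements:
    buffer.append(element)
    if len(buffer) == buffer_size:
      # Emit a random element from the full buffer.
      yield buffer.pop(rng.randrange(len(buffer)))
  # Drain whatever is left once the input is exhausted.
  while buffer:
    yield buffer.pop(rng.randrange(len(buffer)))


def batch(elements, batch_size):
  """Group elements into lists of batch_size; the last batch may be smaller."""
  current = []
  for element in elements:
    current.append(element)
    if len(current) == batch_size:
      yield current
      current = []
  if current:
    yield current


# Toy "dataset": 10 examples, each a single pixel value, labeled by index.
raw = [([i * 25], i) for i in range(10)]
# map(scale) followed by cache: the scaled examples are computed once and reused.
cached = [(scale(image), label) for image, label in raw]

rng = random.Random(0)
for epoch in range(2):
  # Shuffling happens after the cache, so each epoch sees a new order.
  batches = list(batch(buffered_shuffle(cached, buffer_size=4, rng=rng), batch_size=4))
  print(epoch, [[label for _, label in b] for b in batches])
```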
Create the model
Create and compile the Keras model in the context of Strategy.scope:
# Creating and compiling the model inside the scope places its variables
# under MirroredStrategy, so they are mirrored on every replica.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
Define the callbacks
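Define the following Keras callbacks:
- tf.keras.callbacks.TensorBoard: writes a log for TensorBoard, which allows you to visualize the graphs.
- tf.keras.callbacks.ModelCheckpoint: saves the model at a certain frequency, such as after every epoch.
- tf.keras.callbacks.BackupAndRestore: provides fault tolerance by backing up the model and the current epoch number. Learn more in the Fault tolerance section of the Multi-worker training with Keras tutorial.
- tf.keras.callbacks.LearningRateScheduler: schedules the learning rate to change after, for example, every epoch or batch.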
For illustrative purposes, add a custom callback called PrintLR to display the learning rate in the notebook.
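Note: Use the BackupAndRestore callback instead of ModelCheckpoint as the main mechanism for restoring the training state when a job restarts after a failure. Because BackupAndRestore supports only eager mode, consider using ModelCheckpoint in graph mode.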
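You can sketch the idea behind BackupAndRestore without TensorFlow. Persist the model state together with the current epoch number after every epoch. Then, on restart, resume from the saved epoch instead of epoch 0. The pure-Python sketch below is an illustration only. The train function, the JSON backup file, and the single weight value are hypothetical stand-ins; they are not the Keras callback's implementation or file format.

```python
import json
import os
import tempfile


def train(num_epochs, backup_path, crash_at_epoch=None):
  """Hypothetical toy training loop that resumes from a backup file."""
  state = {"epoch": 0, "weight": 0.0}
  if os.path.exists(backup_path):
    # Restore the model state and the epoch number saved before the failure.
    with open(backup_path) as f:
      state = json.load(f)
  completed = []
  for epoch in range(state["epoch"], num_epochs):
    if epoch == crash_at_epoch:
      raise RuntimeError(f"simulated failure at epoch {epoch}")
    state["weight"] += 1.0  # Stand-in for one epoch of weight updates.
    state["epoch"] = epoch + 1
    # Back up after every epoch, so at most one epoch of work is lost.
    with open(backup_path, "w") as f:
      json.dump(state, f)
    completed.append(epoch)
  # In this sketch, a finished run deletes its backup so the next run starts fresh.
  os.remove(backup_path)
  return state, completed


backup_path = os.path.join(tempfile.mkdtemp(), "backup.json")
try:
  train(5, backup_path, crash_at_epoch=3)
except RuntimeError as err:
  print(err)
# The restarted run picks up at epoch 3 instead of starting over.
final_state, resumed_epochs = train(5, backup_path)
print(final_state, resumed_epochs)
```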
# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # self.model is the model that Keras attaches to the callback during fit.
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.learning_rate.numpy()))

# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
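LearningRateScheduler passes the 0-based epoch index to decay, while Keras prints epochs starting from 1. As a result, the schedule switches at the 4th and 8th displayed epochs. The sketch below tabulates the schedule. It also shows why PrintLR reports values such as 0.0010000000474974513: the optimizer stores the rate as float32. The as_float32 helper is a hypothetical name; it rounds a Python float through the 32-bit format using the standard struct module.

```python
import struct


def decay(epoch):
  # Same piecewise schedule as the decay function above.
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5


def as_float32(x):
  """Hypothetical helper: round a Python float to the nearest float32 value."""
  return struct.unpack("f", struct.pack("f", x))[0]


EPOCHS = 12
for epoch in range(EPOCHS):
  # epoch is the 0-based index; Keras displays it as epoch + 1.
  print("epoch {}: lr={} (float32: {})".format(
      epoch + 1, decay(epoch), as_float32(decay(epoch))))
```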
Train and evaluate
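Now, train the model in the usual way: call Keras Model.fit on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.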
EPOCHS = 12
model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2023-11-07 23:34:25.196020: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0000 00:00:1699400072.026138 550348 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0133s). Check your callbacks.
Learning rate for epoch 1 is 0.0010000000474974513
235/235 [==============================] - 9s 9ms/step - loss: 0.3254 - accuracy: 0.9083 - lr: 0.0010
Epoch 2/12
Learning rate for epoch 2 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.1000 - accuracy: 0.9708 - lr: 0.0010
Epoch 3/12
Learning rate for epoch 3 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.0670 - accuracy: 0.9803 - lr: 0.0010
Epoch 4/12
Learning rate for epoch 4 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0470 - accuracy: 0.9876 - lr: 1.0000e-04
Epoch 5/12
Learning rate for epoch 5 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0445 - accuracy: 0.9882 - lr: 1.0000e-04
Epoch 6/12
Learning rate for epoch 6 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0431 - accuracy: 0.9884 - lr: 1.0000e-04
Epoch 7/12
Learning rate for epoch 7 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0415 - accuracy: 0.9890 - lr: 1.0000e-04
Epoch 8/12
Learning rate for epoch 8 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0396 - accuracy: 0.9897 - lr: 1.0000e-05
Epoch 9/12
Learning rate for epoch 9 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0393 - accuracy: 0.9897 - lr: 1.0000e-05
Epoch 10/12
Learning rate for epoch 10 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0391 - accuracy: 0.9898 - lr: 1.0000e-05
Epoch 11/12
Learning rate for epoch 11 is 9.999999747378752e-06
235/235 [==============================] - 2s 8ms/step - loss: 0.0390 - accuracy: 0.9898 - lr: 1.0000e-05
Epoch 12/12
Learning rate for epoch 12 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0388 - accuracy: 0.9900 - lr: 1.0000e-05
<keras.src.callbacks.History at 0x7f33cc2725e0>
Check the saved checkpoints:
# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
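tf.train.latest_checkpoint locates the newest checkpoint through the bookkeeping file named checkpoint in the listing above, so file-name ordering does not affect it. If you ever select a checkpoint by name yourself, compare the epoch numbers numerically: as plain strings, ckpt_9 sorts after ckpt_12. The checkpoint_epochs helper below is hypothetical and only parses the file names shown in the listing.

```python
import re

# File names copied from the `ls` listing above.
listing = (
    "checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index "
    "ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index "
    "ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index "
    "ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index "
    "ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index "
    "ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index "
    "ckpt_3.index"
).split()


def checkpoint_epochs(filenames):
  """Hypothetical helper: sorted epoch numbers of the ckpt_<epoch>.index files."""
  epochs = []
  for name in filenames:
    match = re.fullmatch(r"ckpt_(\d+)\.index", name)
    if match:
      epochs.append(int(match.group(1)))
  return sorted(epochs)


epochs = checkpoint_epochs(listing)
prefixes = ["ckpt_{}".format(e) for e in epochs]

# String comparison picks the wrong prefix, because "9" > "1" character by character.
print(max(prefixes))
# Numeric comparison picks the checkpoint from the last epoch.
print("ckpt_{}".format(max(epochs)))
```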
To check how well the model performs, load the latest checkpoint and call Model.evaluate on the test data:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2023-11-07 23:34:58.383301: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 8ms/step - loss: 0.0501 - accuracy: 0.9831
Eval loss: 0.05011461675167084, Eval accuracy: 0.9830999970436096
To visualize the output, launch TensorBoard and view the logs:
%tensorboard --logdir=logs
ls -sh ./logs
total 4.0K 4.0K train
Save the model
Save the model to a .keras zip archive using Model.save. After the model is saved, you can load it with or without Strategy.scope.
path = 'my_model.keras'
model.save(path)
Now, load the model without Strategy.scope:
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
40/40 [==============================] - 0s 3ms/step - loss: 0.0501 - accuracy: 0.9831
Eval loss: 0.05011461302638054, Eval Accuracy: 0.9830999970436096
Load the model with Strategy.scope:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2023-11-07 23:35:01.980132: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 5ms/step - loss: 0.0501 - accuracy: 0.9831
Eval loss: 0.05011461675167084, Eval Accuracy: 0.9830999970436096
Additional resources
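More examples that use different distribution strategies with the Keras Model.fit API:
- The Solve GLUE tasks using BERT on TPU tutorial uses tf.distribute.MirroredStrategy for training on GPUs and tf.distribute.TPUStrategy for training on TPUs.
- The Save and load a model using a distribution strategy tutorial demonstrates how to use the SavedModel APIs with tf.distribute.Strategy.
- The official TensorFlow models can be configured to run multiple distribution strategies.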
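To learn more about TensorFlow distribution strategies:
- The Custom training with tf.distribute.Strategy tutorial shows how to use tf.distribute.MirroredStrategy for single-worker training with a custom training loop.
- The Multi-worker training with Keras tutorial shows how to use MultiWorkerMirroredStrategy with Model.fit.
- The Custom training loop with Keras and MultiWorkerMirroredStrategy tutorial shows how to use MultiWorkerMirroredStrategy with Keras and a custom training loop.
- The Distributed training in TensorFlow guide provides an overview of the available distribution strategies.
- The Better performance with tf.function guide provides information about other strategies and tools, such as the TensorFlow Profiler, which you can use to optimize the performance of your TensorFlow models.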
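Note: tf.distribute.Strategy is under active development, and TensorFlow will add more examples and tutorials in the near future. Please give it a try. We welcome your feedback via issues on GitHub.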