在 tensorflow.google.cn 上查看 | 在 Google Colab 运行 | 在 Github 上查看源代码 | 下载此 notebook |
概述
本教程使用 tf.distribute.Strategy
API 演示了使用 Keras 模型的多工作器(worker)分布式培训。借助专为多工作器(worker)训练而设计的策略,设计在单一工作器(worker)上运行的 Keras 模型可以在最少的代码更改的情况下无缝地处理多个工作器。
TensorFlow 中的分布式培训指南可用于概述TensorFlow支持的分布式策略,并想要更深入理解tf.distribute.Strategy
API 感兴趣的人。
配置
首先,设置 TensorFlow 和必要的导入。
!pip install tf-nightly
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
准备数据集
现在,让我们从 TensorFlow 数据集 中准备MNIST数据集。 MNIST 数据集 包括60,000个训练样本和10,000个手写数字0-9的测试示例,格式为28x28像素单色图像。
BUFFER_SIZE = 10000
BATCH_SIZE = 64
def make_datasets_unbatched():
# 将 MNIST 数据从 (0, 255] 缩放到 (0., 1.]
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
datasets, info = tfds.load(name='mnist',
with_info=True,
as_supervised=True)
return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)
train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)
构建 Keras 模型
在这里,我们使用tf.keras.Sequential
API来构建和编译一个简单的卷积神经网络 Keras 模型,用我们的 MNIST 数据集进行训练。
注意:有关构建 Keras 模型的详细训练说明,请参阅TensorFlow Keras 指南。
def build_and_compile_cnn_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=['accuracy'])
return model
让我们首先尝试用少量的 epoch 来训练模型,并在单个工作器(worker)中观察结果,以确保一切正常。 随着训练的迭代,您应该会看到损失(loss)下降和准确度(accuracy)接近1.0。
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
Epoch 1/3 5/5 [==============================] - 2s 5ms/step - loss: 2.3034 - accuracy: 0.0875 Epoch 2/3 5/5 [==============================] - 0s 4ms/step - loss: 2.3089 - accuracy: 0.0844 Epoch 3/3 5/5 [==============================] - 0s 4ms/step - loss: 2.3126 - accuracy: 0.0875 2022-06-03 20:07:36.608025: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 2022-06-03 20:07:36.610575: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <keras.callbacks.History at 0x7fc29c18fb20>
多工作器(worker)配置
现在让我们进入多工作器(worker)训练的世界。在 TensorFlow 中,需要 TF_CONFIG
环境变量来训练多台机器,每台机器可能具有不同的角色。 TF_CONFIG
用于指定作为集群一部分的每个 worker 的集群配置。
TF_CONFIG
有两个组件:cluster
和 task
。 cluster
提供有关训练集群的信息,这是一个由不同类型的工作组成的字典,例如 worker
。在多工作器(worker)培训中,除了常规的“工作器”之外,通常还有一个“工人”承担更多责任,比如保存检查点和为 TensorBoard 编写摘要文件。这样的工作器(worker)被称为“主要”工作者,习惯上worker
中 index
0被指定为主要的 worker
(事实上这就是tf.distribute.Strategy
的实现方式)。 另一方面,task
提供当前任务的信息。
在这个例子中,我们将任务 type
设置为 "worker"
并将任务 index
设置为 0
。这意味着具有这种设置的机器是第一个工作器,它将被指定为主要工作器并且要比其他工作器做更多的工作。请注意,其他机器也需要设置 TF_CONFIG
环境变量,它应该具有相同的 cluster
字典,但是不同的任务type
或 index
取决于这些机器的角色。
为了便于说明,本教程展示了如何在 localhost
上设置一个带有2个工作器的TF_CONFIG
。 实际上,用户会在外部IP地址/端口上创建多个工作器,并在每个工作器上适当地设置TF_CONFIG
。
警告:不要在 Colab 中执行以下代码。TensorFlow 的运行时将尝试在指定的IP地址和端口创建 gRPC 服务器,这可能会失败。
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:12345", "localhost:23456"]
},
'task': {'type': 'worker', 'index': 0}
})
注意,虽然在该示例中学习速率是固定的,但是通常可能需要基于全局批量大小来调整学习速率。
选择正确的策略
在 TensorFlow 中,分布式训练包括同步训练(其中训练步骤跨工作器和副本同步)、异步训练(训练步骤未严格同步)。
MultiWorkerMirroredStrategy
是同步多工作器训练的推荐策略,将在本指南中进行演示。
要训练模型,请使用 tf.distribute.experimental.MultiWorkerMirroredStrategy
的实例。 MultiWorkerMirroredStrategy
在所有工作器的每台设备上创建模型层中所有变量的副本。 它使用 CollectiveOps
,一个用于集体通信的 TensorFlow 操作,来聚合梯度并使变量保持同步。 tf.distribute.Strategy
指南有关于此策略的更多详细信息。
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version. Instructions for updating: use distribute.MultiWorkerMirroredStrategy instead WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version. Instructions for updating: use distribute.MultiWorkerMirroredStrategy instead WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled. WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled. INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO
注意:解析 TF_CONFIG
并且在调用 MultiWorkerMirroredStrategy.init()
时启动 TensorFlow 的 GRPC 服务器,因此必须在创建tf.distribute.Strategy
实例之前设置 TF_CONFIG
环境变量。
MultiWorkerMirroredStrategy
通过CollectiveCommunication
参数提供多个实现。RING
使用 gRPC 作为跨主机通信层实现基于环的集合。NCCL
使用Nvidia 的 NCCL来实现集体。 AUTO
将选择推迟到运行时。 集体实现的最佳选择取决于GPU的数量和种类以及群集中的网络互连。
使用 MultiWorkerMirroredStrategy 训练模型
通过将 tf.distribute.Strategy
API集成到 tf.keras
中,将训练分发给多人的唯一更改就是将模型进行构建和 model.compile()
调用封装在 strategy.scope()
内部。 分发策略的范围决定了如何创建变量以及在何处创建变量,对于 MultiWorkerMirroredStrategy 而言,创建的变量为 MirroredVariable ,并且将它们复制到每个工作器上。
注意:在此Colab中,以下代码可以按预期结果运行,但是由于未设置TF_CONFIG
,因此这实际上是单机训练。 在您自己的示例中设置了 TF_CONFIG
后,您应该期望在多台机器上进行培训可以提高速度。
NUM_WORKERS = 2
# 由于 `tf.data.Dataset.batch` 需要全局的批处理大小,
# 因此此处的批处理大小按工作器数量增加。
# 以前我们使用64,现在变成128。
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
# 创建数据集需要在 MultiWorkerMirroredStrategy 对象
# 实例化后。
train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)
with strategy.scope():
# 模型的建立/编译需要在 `strategy.scope()` 内部。
multi_worker_model = build_and_compile_cnn_model()
# Keras 的 `model.fit()` 以特定的时期数和每时期的步数训练模型。
# 注意此处的数量仅用于演示目的,并不足以产生高质量的模型。
multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
2022-06-03 20:07:38.409119: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. Epoch 1/3 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 5/5 [==============================] - 8s 15ms/step - loss: 2.3121 - accuracy: 0.0656 Epoch 2/3 5/5 [==============================] - 0s 14ms/step - loss: 2.3034 - accuracy: 0.0922 Epoch 3/3 5/5 [==============================] - 0s 14ms/step - loss: 2.3047 - accuracy: 0.1047 2022-06-03 20:07:46.666346: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 2022-06-03 20:07:46.669124: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <keras.callbacks.History at 0x7fc2842d79d0>
数据集分片和批(batch)大小
在多工作器训练中,需要将数据分片为多个部分,以确保融合和性能。 但是,请注意,在上面的代码片段中,数据集直接发送到model.fit()
,而无需分片; 这是因为tf.distribute.Strategy
API在多工作器训练中会自动处理数据集分片。
如果您喜欢手动分片进行训练,则可以通过tf.data.experimental.DistributeOptions
API关闭自动分片。
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_datasets_no_auto_shard = train_datasets.with_options(options)
要注意的另一件事是 datasets
的批处理大小。 在上面的代码片段中,我们使用 GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
,这是单个工作器的大小的 NUM_WORKERS
倍,因为每个工作器的有效批量大小是全局批量大小(参数从 tf.data.Dataset.batch()
传入)除以工作器的数量,通过此更改,我们使每个工作器的批处理大小与以前相同。
性能
现在,您已经有了一个Keras模型,该模型全部通过 MultiWorkerMirroredStrategy
运行在多个工作器中。 您可以尝试以下技术来调整多工作器训练的效果。
MultiWorkerMirroredStrategy
提供了多个[集体通信实现]collective communication implementations.RING
使用gRPC作为跨主机通信层实现基于环的集合。NCCL
使用 Nvidia's NCCL 来实现集合。AUTO
将推迟到运行时选择。集体实施的最佳选择取决于GPU的数量和种类以及集群中的网络互连。 要覆盖自动选择,请为MultiWorkerMirroredStrategy
的构造函数的communication
参数指定一个有效值,例如:communication=tf.distribute.experimental.CollectiveCommunication.NCCL
.- 如果可能的话,将变量强制转换为
tf.float
。ResNet 的官方模型包括如何完成此操作的示例。
容错能力
在同步训练中,如果其中一个工作器出现故障并且不存在故障恢复机制,则集群将失败。 在工作器退出或不稳定的情况下,将 Keras 与 tf.distribute.Strategy
一起使用会具有容错的优势。 我们通过在您选择的分布式文件系统中保留训练状态来做到这一点,以便在重新启动先前失败或被抢占的实例后,将恢复训练状态。
由于所有工作器在训练 epochs 和 steps 方面保持同步,因此其他工作器将需要等待失败或被抢占的工作器重新启动才能继续。
ModelCheckpoint 回调
要在多工作器训练中利用容错功能,请在调用 tf.keras.Model.fit()
时提供一个 tf.keras.callbacks.ModelCheckpoint
实例。 回调会将检查点和训练状态存储在与 ModelCheckpoint
的 filepath
参数相对应的目录中。
# 将 `filepath` 参数替换为在文件系统中所有工作器都能访问的路径。
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets,
epochs=3,
steps_per_epoch=5,
callbacks=callbacks)
2022-06-03 20:07:46.828733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. Epoch 1/3 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1 5/5 [==============================] - ETA: 0s - loss: 2.3035 - accuracy: 0.1531 WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading. WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11. WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11. INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets 5/5 [==============================] - 6s 202ms/step - loss: 2.3035 - accuracy: 0.1531 Epoch 2/3 5/5 [==============================] - ETA: 0s - loss: 2.3003 - accuracy: 0.1656 WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets 5/5 [==============================] - 1s 166ms/step - loss: 2.3003 - accuracy: 0.1656 Epoch 3/3 5/5 [==============================] - ETA: 0s - loss: 2.2974 - accuracy: 0.1406 WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets 5/5 [==============================] - 1s 167ms/step - loss: 2.2974 - accuracy: 0.1406 2022-06-03 20:07:54.384936: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 2022-06-03 20:07:54.388366: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <keras.callbacks.History at 0x7fc2747485b0>
如果某个工作线程被抢占,则整个集群将暂停,直到重新启动被抢占的工作线程为止。工作器重新加入集群后,其他工作器也将重新启动。 现在,每个工作器都将读取先前保存的检查点文件,并获取其以前的状态,从而使群集能够恢复同步,然后继续训练。
如果检查包含在ModelCheckpoint
中指定的 filepath
的目录,则可能会注意到一些临时生成的检查点文件。 这些文件是恢复以前丢失的实例所必需的,并且在成功退出多工作器训练后,这些文件将在 tf.keras.Model.fit()
的末尾被库删除。
您可以查阅
- Distributed Training in TensorFlow 该指南概述了可用的分布式策略。
- ResNet50 官方模型,该模型可以使用
MirroredStrategy
或MultiWorkerMirroredStrategy
进行训练