在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 Github 上查看源代码 | 下载笔记本 |
概述
本教程演示如何使用 Keras 模型和 tf.distribute.Strategy
API 的自定义训练循环执行多工作进程分布式训练。训练循环通过 tf.distribute.MultiWorkerMirroredStrategy
进行分布。这样,设计为在单个工作进程上运行的 tf.keras
模型即可通过最少的代码更改无缝地在多个工作进程上运行。自定义训练循环提供了灵活性和更好的训练控制,同时也使模型的调试更加容易。请详细了解有关编写基本训练循环、 从头开始编写训练循环和自定义训练的信息。
如果您正在寻找如何将 MultiWorkerMirroredStrategy
与 tf.keras.Model.fit
一起使用,请参阅此教程。
TensorFlow 中的分布式训练指南概述了 TensorFlow 支持的分布式策略,并适用于想要更深入了解 tf.distribute.Strategy
API 的人。
安装
首先,进行一些必要的导入。
import json
import os
import sys
在导入 TensorFlow 之前,需要对环境进行一些变更:
- 停用所有 GPU。这可以防止所有工作进程都尝试使用同一个 GPU 而导致的错误。对于真实应用,每个工作进程都将在不同的计算机上运行。
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
- 重置
'TF_CONFIG'
环境变量(稍后您将看到更多相关信息)。
os.environ.pop('TF_CONFIG', None)
- 确保当前目录位于 Python 的路径上。这样,笔记本可以导入稍后由
%%writefile
写入的文件。
if '.' not in sys.path:
sys.path.insert(0, '.')
现在导入 TensorFlow。
import tensorflow as tf
2023-11-07 23:15:16.986830: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-07 23:15:16.986881: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-07 23:15:16.988638: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
数据集和模型定义
接下来,使用简单的模型和数据集设置创建 mnist.py
文件。本教程中的工作进程将使用此 Python 文件:
%%writefile mnist.py
import os
import tensorflow as tf
import numpy as np
def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the range [0, 255].
# You need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000)
return train_dataset
def dataset_fn(global_batch_size, input_context):
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = mnist_dataset(batch_size)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
return dataset
def build_cnn_model():
return tf.keras.Sequential([
tf.keras.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
Writing mnist.py
多工作进程配置
接下来,我们进入多工作进程训练的世界。在 TensorFlow 中,在多台计算机上进行训练需要 'TF_CONFIG'
环境变量。每台计算机可能有不同的角色。下面使用的 'TF_CONFIG'
变量是一个 JSON 字符串,它指定集群中每个工作进程的集群配置。这是使用 cluster_resolver.TFConfigClusterResolver
指定集群的默认方法,但在 distribute.cluster_resolver
模块中还有其他可用选项。请在分布式训练指南中了解有关设置 'TF_CONFIG'
变量的更多信息。
描述您的集群
下面是一个示例配置:
tf_config = {
'cluster': {
'worker': ['localhost:12345', 'localhost:23456']
},
'task': {'type': 'worker', 'index': 0}
}
请注意,tf_config
只是 Python 中的局部变量。要将其用于训练配置,请将其序列化为 JSON 并将其放在 'TF_CONFIG'
环境变量中。这是序列化为 JSON 字符串的相同 'TF_CONFIG'
:
json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
'TF_CONFIG'
有两个组件:'cluster'
和 'task'
。
'cluster'
对所有工作进程都相同,并提供有关训练集群的信息,这是一个由不同类型的作业组成的字典,例如'worker'
。在使用MultiWorkerMirroredStrategy
进行的多工作进程训练中,除了普通的'worker'
之外,通常还有一个'worker'
承担更多的责任,例如保存检查点和为 TensorBoard 编写摘要文件。这样的工作进程被称为'chief'
工作进程,习惯上将'index'
为 0 的'worker'
指定为首席worker
。'task'
提供当前任务的信息,并且在每个工作进程上都不相同。它指定该工作进程的'type'
和'index'
。
在本例中,您会将任务 'type'
设置为 'worker'
,将任务 'index'
设置为 0
。这台计算机是首个工作进程,将被指定为首席工作进程,并需要比其他工作进程承担更多的工作。请注意,其他计算机也需要设置 'TF_CONFIG'
环境变量,且应该具有相同的 'cluster'
字典,但要根据这些计算机的具体角色来设置不同的任务 'type'
或任务 'index'
。
出于演示的目的,本教程将展示如何在 'localhost'
上设置具有两个工作进程的 'TF_CONFIG'
。在实践中,用户会在外部 IP 地址/端口上创建多个工作进程,并为每个工作进程正确设置 'TF_CONFIG'
。
本示例使用两个工作进程,第一个工作进程的 'TF_CONFIG'
如上所示。对于第二个工作进程,设置 tf_config['task']['index']=1
。
笔记本中的环境变量和子进程
子进程会从其父进程继承环境变量。因此,如果您在此 Jupyter Notebook 进程中设置环境变量:
os.environ['GREETINGS'] = 'Hello TensorFlow!'
然后,您可以从子进程访问环境变量:
echo ${GREETINGS}
Hello TensorFlow!
在下一部分中,您将使用它来将 'TF_CONFIG'
传递给工作进程子进程。实际上,您永远不会以这种方式启动您的作业,但这完全可以满足此教程的演示目的:呈现最简单的多工作进程示例。
MultiWorkerMirroredStrategy
在训练模型之前,首先创建一个 tf.distribute.MultiWorkerMirroredStrategy
的实例:
strategy = tf.distribute.MultiWorkerMirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',) INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO 2023-11-07 23:15:19.298847: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
注:在您调用 tf.distribute.MultiWorkerMirroredStrategy
时,会解析 'TF_CONFIG'
并启动 TensorFlow 的 GRPC 服务器。因此,您必须在实例化 tf.distribute.Strategy
之前设置 'TF_CONFIG'
环境变量。为了在这个说明性示例中节省时间,本教程中没有对此进行演示,因此不需要启动服务器。您可以在本教程的最后一个部分中找到完整的示例。
使用 tf.distribute.Strategy.scope
指定构建模型时应使用的策略。这使得该策略可以控制变量放置之类的事情,它将在所有工作进程的每个设备上,在模型的层中创建所有变量的副本。
import mnist
with strategy.scope():
# Model building needs to be within `strategy.scope()`.
multi_worker_model = mnist.build_cnn_model()
在工作进程之间对数据进行自动分片
在多工作进程训练中,需要通过数据集分片来确保收敛性和可重复性。分片意味着将整个数据集的一个子集交给每个工作进程,这有助于创造类似于对单个工作进程进行训练的体验。在下面的示例中,您依赖于 tf.distribute
的默认自动分片策略。您还可以通过设置 tf.data.experimental.DistributeOptions
的 tf.data.experimental.AutoShardPolicy
来对其进行自定义。要了解更多信息,请参阅分布式输入教程的分片部分。
per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers
with strategy.scope():
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: mnist.dataset_fn(global_batch_size, input_context))
定义自定义训练循环并训练模型
指定优化器:
with strategy.scope():
# The creation of optimizer and train_accuracy needs to be in
# `strategy.scope()` as well, since they create variables.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
使用 tf.function
定义训练步骤:
@tf.function
def train_step(iterator):
"""Training step function."""
def step_fn(inputs):
"""Per-Replica step function."""
x, y = inputs
with tf.GradientTape() as tape:
predictions = multi_worker_model(x, training=True)
per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
loss = tf.nn.compute_average_loss(
per_batch_loss, global_batch_size=global_batch_size)
grads = tape.gradient(loss, multi_worker_model.trainable_variables)
optimizer.apply_gradients(
zip(grads, multi_worker_model.trainable_variables))
train_accuracy.update_state(y, predictions)
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
检查点保存和恢复
在编写自定义训练循环时,您需要手动处理检查点保存,而不是依赖 Keras 回调。请注意,对于 MultiWorkerMirroredStrategy
,保存检查点或完整模型需要所有工作进程的参与,因为尝试仅在首席工作进程上进行保存可能会导致死锁。工作进程还需要写入不同的路径以避免相互重写。以下是如何配置目录的示例:
from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')
def _is_chief(task_type, task_id, cluster_spec):
return (task_type is None
or task_type == 'chief'
or (task_type == 'worker'
and task_id == 0
and "chief" not in cluster_spec.as_dict()))
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id, cluster_spec):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id, cluster_spec):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
创建一个跟踪模型的 tf.train.Checkpoint
,由 tf.train.CheckpointManager
管理,以便仅保留最新的检查点:
epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64),
name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this
# illustrative example you did not set `'TF_CONFIG'` before initializing the
# strategy. Check out the next section for "real-world" usage.
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])
checkpoint = tf.train.Checkpoint(
model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
现在,当需要恢复检查点时,您可以方便地使用 tf.train.latest_checkpoint
函数(或通过调用 tf.train.CheckpointManager.restore_or_initialize
)找到最新的已保存检查点。
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
恢复检查点后,您可以继续训练自定义训练循环。
num_epochs = 3
num_steps_per_epoch = 70
while epoch.numpy() < num_epochs:
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
while step_in_epoch.numpy() < num_steps_per_epoch:
total_loss += train_step(iterator)
num_batches += 1
step_in_epoch.assign_add(1)
train_loss = total_loss / num_batches
print('Epoch: %d, accuracy: %f, train_loss: %f.'
%(epoch.numpy(), train_accuracy.result(), train_loss))
train_accuracy.reset_states()
# Once the `CheckpointManager` is set up, you're now ready to save, and remove
# the checkpoints non-chief workers saved.
checkpoint_manager.save()
if not _is_chief(task_type, task_id, cluster_spec):
tf.io.gfile.rmtree(write_checkpoint_dir)
epoch.assign_add(1)
step_in_epoch.assign(0)
2023-11-07 23:15:20.366756: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations. Epoch: 0, accuracy: 0.807366, train_loss: 0.621664. Epoch: 1, accuracy: 0.926786, train_loss: 0.255375. Epoch: 2, accuracy: 0.947656, train_loss: 0.172921.
完整代码一览
总结一下到目前为止讨论的所有程序:
- 创建工作进程。
- 将
'TF_CONFIG'
传递给工作进程。 - 让每个工作进程运行下面包含训练代码的脚本。
File: main.py
%%writefile main.py
import os
import json
import tensorflow as tf
import mnist
from multiprocessing import util
per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers
num_epochs = 3
num_steps_per_epoch=70
# Checkpoint saving and restoring
def _is_chief(task_type, task_id, cluster_spec):
return (task_type is None
or task_type == 'chief'
or (task_type == 'worker'
and task_id == 0
and 'chief' not in cluster_spec.as_dict()))
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id, cluster_spec):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id, cluster_spec):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')
# Define Strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
# Model building/compiling need to be within `tf.distribute.Strategy.scope`.
multi_worker_model = mnist.build_cnn_model()
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: mnist.dataset_fn(global_batch_size, input_context))
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
@tf.function
def train_step(iterator):
"""Training step function."""
def step_fn(inputs):
"""Per-Replica step function."""
x, y = inputs
with tf.GradientTape() as tape:
predictions = multi_worker_model(x, training=True)
per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
loss = tf.nn.compute_average_loss(
per_batch_loss, global_batch_size=global_batch_size)
grads = tape.gradient(loss, multi_worker_model.trainable_variables)
optimizer.apply_gradients(
zip(grads, multi_worker_model.trainable_variables))
train_accuracy.update_state(y, predictions)
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64),
name='step_in_epoch')
task_type, task_id, cluster_spec = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id,
strategy.cluster_resolver.cluster_spec())
checkpoint = tf.train.Checkpoint(
model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
# Restoring the checkpoint
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
# Resume our CTL training
while epoch.numpy() < num_epochs:
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
while step_in_epoch.numpy() < num_steps_per_epoch:
total_loss += train_step(iterator)
num_batches += 1
step_in_epoch.assign_add(1)
train_loss = total_loss / num_batches
print('Epoch: %d, accuracy: %f, train_loss: %f.'
%(epoch.numpy(), train_accuracy.result(), train_loss))
train_accuracy.reset_states()
checkpoint_manager.save()
if not _is_chief(task_type, task_id, cluster_spec):
tf.io.gfile.rmtree(write_checkpoint_dir)
epoch.assign_add(1)
step_in_epoch.assign(0)
Writing main.py
当前目录现包含两个 Python 文件:
ls *.py
main.py mnist.py
因此,对 'TF_CONFIG'
执行 JSON 序列化,然后将其添加到环境变量:
os.environ['TF_CONFIG'] = json.dumps(tf_config)
现在,您可以启动一个将运行 main.py
并使用 'TF_CONFIG'
的工作进程:
# first kill any previous runs
%killbgscripts
All background processes were killed.
python main.py &> job_0.log
以上命令有几点需要注意:
- 它使用
%%bash
,这是一项用于运行一些 bash 命令的笔记本“魔术命令”。 - 它使用
--bg
标志在后台运行bash
进程,因为此工作进程不会终止。它在开始之前会等待所有工作进程。
后台工作进程不会将输出打印到此笔记本。&>
会将其输出重定向到一个文件,以便您可以查看所发生的情况。
等待几秒钟以启动该进程:
import time
time.sleep(20)
接下来,检查一下目前为止输出到工作进程日志文件的内容:
cat job_0.log
2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
日志文件的最后一行内容应为:Started server with target: grpc://localhost:12345
。第一个工作进程现已准备就绪,正在等待所有其他工作进程准备就绪以继续。
更新 tf_config
以供第二个工作进程取用:
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)
现在,启动第二个工作进程。这将开始训练,因为所有工作进程都已处于活动状态(因此无需在后台执行此进程):
python main.py > /dev/null 2>&1
如果您重新检查第一个工作进程编写的日志,您会看到它参与了该模型的训练:
cat job_0.log
2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected 2023-11-07 23:15:48.287770: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations. Epoch: 0, accuracy: 0.804129, train_loss: 0.624825. Epoch: 1, accuracy: 0.920201, train_loss: 0.276320. Epoch: 2, accuracy: 0.946429, train_loss: 0.194815.
# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.
深入了解多工作进程训练
本教程演示了多工作进程设置的自定义训练循环工作流程。有关其他主题的详细描述可在适用于自定义训练循环的使用 Keras 进行多工作进程训练 (tf.keras.Model.fit
) 教程中找到。
了解更多
- TensorFlow 中的分布式训练指南概述了可用的分布式策略。
- 官方模型,其中许多模型可以配置为运行多个分布式策略。
tf.function
指南中的“性能”部分提供了有关其他策略和工具的信息,您可以使用它们来优化 TensorFlow 模型的性能。