使用 Keras 和 MultiWorkerMirroredStrategy 的自定义训练循环

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

概述

本教程演示如何使用 Keras 模型和 tf.distribute.Strategy API 的自定义训练循环执行多工作进程分布式训练。训练循环通过 tf.distribute.MultiWorkerMirroredStrategy 进行分布。这样,设计为在单个工作进程上运行的 tf.keras 模型即可通过最少的代码更改无缝地在多个工作进程上运行。自定义训练循环提供了灵活性和更好的训练控制,同时也使模型的调试更加容易。请详细了解有关编写基本训练循环从头开始编写训练循环自定义训练的信息。

如果您正在寻找如何将 MultiWorkerMirroredStrategytf.keras.Model.fit 一起使用,请参阅此教程

TensorFlow 中的分布式训练指南概述了 TensorFlow 支持的分布式策略,并适用于想要更深入了解 tf.distribute.Strategy API 的人。

安装

首先,进行一些必要的导入。

import json
import os
import sys

在导入 TensorFlow 之前,需要对环境进行一些变更:

  • 停用所有 GPU。这可以防止所有工作进程都尝试使用同一个 GPU 而导致的错误。对于真实应用,每个工作进程都将在不同的计算机上运行。
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  • 重置 'TF_CONFIG' 环境变量(稍后您将看到更多相关信息)。
os.environ.pop('TF_CONFIG', None)
  • 确保当前目录位于 Python 的路径上。这样,笔记本可以导入稍后由 %%writefile 写入的文件。
if '.' not in sys.path:
  sys.path.insert(0, '.')

现在导入 TensorFlow。

import tensorflow as tf
2023-11-07 23:15:16.986830: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:15:16.986881: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:15:16.988638: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

数据集和模型定义

接下来,使用简单的模型和数据集设置创建 mnist.py 文件。本教程中的工作进程将使用此 Python 文件:

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000)
  return train_dataset

def dataset_fn(global_batch_size, input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = mnist_dataset(batch_size)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  return dataset

def build_cnn_model():
  return tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
Writing mnist.py

多工作进程配置

接下来,我们进入多工作进程训练的世界。在 TensorFlow 中,在多台计算机上进行训练需要 'TF_CONFIG' 环境变量。每台计算机可能有不同的角色。下面使用的 'TF_CONFIG' 变量是一个 JSON 字符串,它指定集群中每个工作进程的集群配置。这是使用 cluster_resolver.TFConfigClusterResolver 指定集群的默认方法,但在 distribute.cluster_resolver 模块中还有其他可用选项。请在分布式训练指南中了解有关设置 'TF_CONFIG' 变量的更多信息。

描述您的集群

下面是一个示例配置:

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

请注意,tf_config 只是 Python 中的局部变量。要将其用于训练配置,请将其序列化为 JSON 并将其放在 'TF_CONFIG' 环境变量中。这是序列化为 JSON 字符串的相同 'TF_CONFIG'

json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'

'TF_CONFIG' 有两个组件:'cluster''task'

  • 'cluster' 对所有工作进程都相同,并提供有关训练集群的信息,这是一个由不同类型的作业组成的字典,例如 'worker' 。在使用 MultiWorkerMirroredStrategy 进行的多工作进程训练中,除了普通的 'worker' 之外,通常还有一个 'worker' 承担更多的责任,例如保存检查点和为 TensorBoard 编写摘要文件。这样的工作进程被称为 'chief' 工作进程,习惯上将 'index' 为 0 的 'worker' 指定为首席 worker

  • 'task' 提供当前任务的信息,并且在每个工作进程上都不相同。它指定该工作进程的 'type''index'

在本例中,您会将任务 'type' 设置为 'worker',将任务 'index' 设置为 0。这台计算机是首个工作进程,将被指定为首席工作进程,并需要比其他工作进程承担更多的工作。请注意,其他计算机也需要设置 'TF_CONFIG' 环境变量,且应该具有相同的 'cluster' 字典,但要根据这些计算机的具体角色来设置不同的任务 'type' 或任务 'index'

出于演示的目的,本教程将展示如何在 'localhost' 上设置具有两个工作进程的 'TF_CONFIG'。在实践中,用户会在外部 IP 地址/端口上创建多个工作进程,并为每个工作进程正确设置 'TF_CONFIG'

本示例使用两个工作进程,第一个工作进程的 'TF_CONFIG' 如上所示。对于第二个工作进程,设置 tf_config['task']['index']=1

笔记本中的环境变量和子进程

子进程会从其父进程继承环境变量。因此,如果您在此 Jupyter Notebook 进程中设置环境变量:

os.environ['GREETINGS'] = 'Hello TensorFlow!'

然后,您可以从子进程访问环境变量:

echo ${GREETINGS}
Hello TensorFlow!

在下一部分中,您将使用它来将 'TF_CONFIG' 传递给工作进程子进程。实际上,您永远不会以这种方式启动您的作业,但这完全可以满足此教程的演示目的:呈现最简单的多工作进程示例。

MultiWorkerMirroredStrategy

在训练模型之前,首先创建一个 tf.distribute.MultiWorkerMirroredStrategy 的实例:

strategy = tf.distribute.MultiWorkerMirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO
2023-11-07 23:15:19.298847: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

注:在您调用 tf.distribute.MultiWorkerMirroredStrategy 时,会解析 'TF_CONFIG' 并启动 TensorFlow 的 GRPC 服务器。因此,您必须在实例化 tf.distribute.Strategy 之前设置 'TF_CONFIG' 环境变量。为了在这个说明性示例中节省时间,本教程中没有对此进行演示,因此不需要启动服务器。您可以在本教程的最后一个部分中找到完整的示例。

使用 tf.distribute.Strategy.scope 指定构建模型时应使用的策略。这使得该策略可以控制变量放置之类的事情,它将在所有工作进程的每个设备上,在模型的层中创建所有变量的副本。

import mnist
with strategy.scope():
  # Model building needs to be within `strategy.scope()`.
  multi_worker_model = mnist.build_cnn_model()

在工作进程之间对数据进行自动分片

在多工作进程训练中,需要通过数据集分片来确保收敛性和可重复性。分片意味着将整个数据集的一个子集交给每个工作进程,这有助于创造类似于对单个工作进程进行训练的体验。在下面的示例中,您依赖于 tf.distribute 的默认自动分片策略。您还可以通过设置 tf.data.experimental.DistributeOptionstf.data.experimental.AutoShardPolicy 来对其进行自定义。要了解更多信息,请参阅分布式输入教程分片部分。

per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers

with strategy.scope():
  multi_worker_dataset = strategy.distribute_datasets_from_function(
      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))

定义自定义训练循环并训练模型

指定优化器:

with strategy.scope():
  # The creation of optimizer and train_accuracy needs to be in
  # `strategy.scope()` as well, since they create variables.
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')

使用 tf.function 定义训练步骤:

@tf.function
def train_step(iterator):
  """Training step function."""

  def step_fn(inputs):
    """Per-Replica step function."""
    x, y = inputs
    with tf.GradientTape() as tape:
      predictions = multi_worker_model(x, training=True)
      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
      loss = tf.nn.compute_average_loss(
          per_batch_loss, global_batch_size=global_batch_size)

    grads = tape.gradient(loss, multi_worker_model.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, multi_worker_model.trainable_variables))
    train_accuracy.update_state(y, predictions)
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

检查点保存和恢复

在编写自定义训练循环时,您需要手动处理检查点保存,而不是依赖 Keras 回调。请注意,对于 MultiWorkerMirroredStrategy,保存检查点或完整模型需要所有工作进程的参与,因为尝试仅在首席工作进程上进行保存可能会导致死锁。工作进程还需要写入不同的路径以避免相互重写。以下是如何配置目录的示例:

from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')

def _is_chief(task_type, task_id, cluster_spec):
  return (task_type is None
          or task_type == 'chief'
          or (task_type == 'worker'
              and task_id == 0
              and "chief" not in cluster_spec.as_dict()))

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id, cluster_spec):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id, cluster_spec):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

创建一个跟踪模型的 tf.train.Checkpoint,由 tf.train.CheckpointManager 管理,以便仅保留最新的检查点:

epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64),
    name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this 
# illustrative example you did not set `'TF_CONFIG'` before initializing the
# strategy. Check out the next section for "real-world" usage.
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])

checkpoint = tf.train.Checkpoint(
    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)

write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
                                      cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

现在,当需要恢复检查点时,您可以方便地使用 tf.train.latest_checkpoint 函数(或通过调用 tf.train.CheckpointManager.restore_or_initialize )找到最新的已保存检查点。

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
  checkpoint.restore(latest_checkpoint)

恢复检查点后,您可以继续训练自定义训练循环。

num_epochs = 3
num_steps_per_epoch = 70

while epoch.numpy() < num_epochs:
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  while step_in_epoch.numpy() < num_steps_per_epoch:
    total_loss += train_step(iterator)
    num_batches += 1
    step_in_epoch.assign_add(1)

  train_loss = total_loss / num_batches
  print('Epoch: %d, accuracy: %f, train_loss: %f.'
                %(epoch.numpy(), train_accuracy.result(), train_loss))

  train_accuracy.reset_states()

  # Once the `CheckpointManager` is set up, you're now ready to save, and remove
  # the checkpoints non-chief workers saved.
  checkpoint_manager.save()
  if not _is_chief(task_type, task_id, cluster_spec):
    tf.io.gfile.rmtree(write_checkpoint_dir)

  epoch.assign_add(1)
  step_in_epoch.assign(0)
2023-11-07 23:15:20.366756: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch: 0, accuracy: 0.807366, train_loss: 0.621664.
Epoch: 1, accuracy: 0.926786, train_loss: 0.255375.
Epoch: 2, accuracy: 0.947656, train_loss: 0.172921.

完整代码一览

总结一下到目前为止讨论的所有程序:

  1. 创建工作进程。
  2. 'TF_CONFIG' 传递给工作进程。
  3. 让每个工作进程运行下面包含训练代码的脚本。

File: main.py

Writing main.py

当前目录现包含两个 Python 文件:

ls *.py
main.py
mnist.py

因此,对 'TF_CONFIG' 执行 JSON 序列化,然后将其添加到环境变量:

os.environ['TF_CONFIG'] = json.dumps(tf_config)

现在,您可以启动一个将运行 main.py 并使用 'TF_CONFIG' 的工作进程:

# first kill any previous runs
%killbgscripts
All background processes were killed.
python main.py &> job_0.log

以上命令有几点需要注意:

  1. 它使用 %%bash,这是一项用于运行一些 bash 命令的笔记本“魔术命令”
  2. 它使用 --bg 标志在后台运行 bash 进程,因为此工作进程不会终止。它在开始之前会等待所有工作进程。

后台工作进程不会将输出打印到此笔记本。&> 会将其输出重定向到一个文件,以便您可以查看所发生的情况。

等待几秒钟以启动该进程:

import time
time.sleep(20)

接下来,检查一下目前为止输出到工作进程日志文件的内容:

cat job_0.log
2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

日志文件的最后一行内容应为:Started server with target: grpc://localhost:12345。第一个工作进程现已准备就绪,正在等待所有其他工作进程准备就绪以继续。

更新 tf_config 以供第二个工作进程取用:

tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

现在,启动第二个工作进程。这将开始训练,因为所有工作进程都已处于活动状态(因此无需在后台执行此进程):

python main.py > /dev/null 2>&1

如果您重新检查第一个工作进程编写的日志,您会看到它参与了该模型的训练:

cat job_0.log
2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-11-07 23:15:48.287770: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch: 0, accuracy: 0.804129, train_loss: 0.624825.
Epoch: 1, accuracy: 0.920201, train_loss: 0.276320.
Epoch: 2, accuracy: 0.946429, train_loss: 0.194815.
# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.

深入了解多工作进程训练

本教程演示了多工作进程设置的自定义训练循环工作流程。有关其他主题的详细描述可在适用于自定义训练循环的使用 Keras 进行多工作进程训练 (tf.keras.Model.fit) 教程中找到。

了解更多

  1. TensorFlow 中的分布式训练指南概述了可用的分布式策略。
  2. 官方模型,其中许多模型可以配置为运行多个分布式策略。
  3. tf.function 指南中的“性能”部分提供了有关其他策略和工具的信息,您可以使用它们来优化 TensorFlow 模型的性能。