Keras と MultiWorkerMirroredStrategy を使用したカスタムトレーニングループ

TensorFlow.org で表示 Google Colabで実行 GitHubでソースを表示 ノートブックをダウンロード

概要

このチュートリアルでは、tf.distribute.Strategy API を使用して、Keras モデルとカスタムトレーニングループでマルチワーカー分散トレーニングを実行する方法を実演します。トレーニングループは tf.distribute.MultiWorkerMirroredStrategy を介して分散され、単一のワーカーで実行するように設計された tf.keras モデルが、最小限のコード変更で複数のワーカーでシームレスに機能します。カスタムトレーニングループは、モデルのデバッグを容易にするでけでなく、柔軟なトレーニングとより優れた制御を提供します。詳細については、基本的なトレーニングループの作成ゼロからのトレーニングループの作成カスタムトレーニングを参照してください。

tf.keras.Model.fitMultiWorkerMirroredStrategy を使用する方法については、このチュートリアルを参照してください。

tf.distribute.Strategy API の理解をさらに深めるには、TensorFlow での分散型トレーニングガイドを参照してください。このガイドでは、TensorFlow がサポートする分散ストラテジーの概要が提供されています。

セットアップ

まず、必要なものをインポートします。

import json
import os
import sys

TensorFlow をインポートする前に、環境にいくつかの変更を加えます。

  • すべての GPU を無効にします。これにより、すべてのワーカーが同じ GPU を使用しようとすることによって発生するエラーが防止されます。実際のアプリケーションでは、各ワーカーは異なるマシン上にあります。
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  • 'TF_CONFIG' 環境変数をリセットします。これについては後で詳しく説明します。
os.environ.pop('TF_CONFIG', None)
  • 現在のディレクトリが Python のパス上にあることを確認してください。これにより、ノートブックは後で %%writefile によって書き込まれたファイルをインポートできます。
if '.' not in sys.path:
  sys.path.insert(0, '.')

次に TensorFlow をインポートします。

import tensorflow as tf
2024-01-11 18:07:04.484008: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:07:04.484054: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:07:04.485585: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

データセットとモデルの定義

次に、単純なモデルとデータセットの設定を使用して mnist.py ファイルを作成します。この Python ファイルは、このチュートリアルのワーカープロセスによって使用されます。

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000)
  return train_dataset

def dataset_fn(global_batch_size, input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = mnist_dataset(batch_size)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  return dataset

def build_cnn_model():
  regularizer = tf.keras.regularizers.L2(1e-5)
  return tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128,
                            activation='relu',
                            kernel_regularizer=regularizer),
      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)
  ])
Writing mnist.py

マルチワーカー構成

これから、マルチワーカートレーニングを見ていきます。TensorFlow では、複数のマシンでのトレーニングに 'TF_CONFIG' 環境変数が必要です。各マシンは異なる役割を持つ場合があります。以下で使用される 'TF_CONFIG' 変数は、クラスタの一部である各ワーカーのクラスタ構成を指定する JSON 文字列です。これは、cluster_resolver.TFConfigClusterResolver を使用してクラスタを指定するためのデフォルトの方法ですが、distribute.cluster_resolver モジュールで利用可能な他のオプションがあります。'TF_CONFIG' 変数の設定の詳細は、分散型トレーニングガイドを参照してください。

クラスタについて説明する

以下に構成の例を示します。

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

tf_config は Python の単なるローカル変数であることに注意してください。トレーニング構成に使用するには、JSON としてシリアル化し、'TF_CONFIG' 環境変数に配置します。JSON 文字列としてシリアル化された同じ 'TF_CONFIG' を次に示します。

json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'

TF_CONFIG には、clustertask の 2 つのコンポーネントがあります。

  • 'cluster' はすべてのワーカーで同じであり、トレーニングクラスタに関する情報を提供します。これは、'worker' などのさまざまなタイプのジョブで構成されるディクショナリです。MultiWorkerMirroredStrategy を使用するマルチワーカートレーニングでは、通常、一般的な 'worker' の作業に加えて、チェックポイントの保存や TensorBoard のサマリーファイルの書き込みなど、ほかよりタスクを担う 'worker' が 1 つあります。こういったワーカーは、'chief' ワーカーと呼ばれ、'index' 0 のワーカーがチーフワーカーに指定されるようになっています。

  • 'task' は現在のタスクの情報を提供し、ワーカーごとに異なります。タスクはそのワーカーの 'type''index' を指定します。

この例では、タスクの 'type''worker'、そしてタスクの 'index'0 に指定します。つまり、このような設定を持つマシンが最初のワーカーであり、チーフワーカーとして指定されて他のワーカーよりも多くの作業を実行します。他のマシンには、'TF_CONFIG' 環境変数も設定されており、同一の 'cluster' ディクショナリも必要ですが、タスクの 'type' やタスクの 'index' は、それらのマシンの役割に応じて異なります。

このチュートリアルでは、例として 'localhost' 上に 2 つのワーカーを持つ 'TF_CONFIG' の設定方法を紹介します。実際には、外部 IP アドレスとポートに複数のワーカーを作成し、各ワーカーに適切な 'TF_CONFIG' を設定します。

この例では、2 つのワーカーを使用します。最初のワーカーの 'TF_CONFIG' は上に示されています。2 番目のワーカーには、tf_config['task']['index']=1 を設定します。

ノートブックの環境変数とサブプロセス

サブプロセスは、親から環境変数を継承します。したがって、この Jupyter ノートブックプロセスで環境変数を設定すると、次のようになります。

os.environ['GREETINGS'] = 'Hello TensorFlow!'

サブプロセスから環境変数にアクセスできます。

echo ${GREETINGS}
Hello TensorFlow!

次のセクションでは、これを使用して TF_CONFIG をワーカーサブプロセスに渡します。この方法で実際にジョブを起動することは決してありませんが、このチュートリアルで最小限のマルチワーカーの例を示すためには十分です。

MultiWorkerMirroredStrategy

モデルをトレーニングする前に、まず tf.distribute.MultiWorkerMirroredStrategy のインスタンスを作成します。

strategy = tf.distribute.MultiWorkerMirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO
2024-01-11 18:07:06.593731: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

注意: tf.distribute.MultiWorkerMirroredStrategy を呼び出すと、'TF_CONFIG' が解析され、TensorFlow の GRPC サーバーが開始されます。したがって、tf.distribute.Strategy をインスタンス化する前に、'TF_CONFIG' 環境変数を設定する必要があります。時間を節約するために、このチュートリアルではサーバーを起動する必要がありません。完全な例は、このチュートリアルの最後のセクションにあります。

tf.distribute.Strategy.scope を使用して、モデルの構築時にストラテジーを使用する必要があることを指定します。これにより、ストラテジーは変数の配置などを制御できます。すべてのワーカーの各デバイスのモデルのレイヤーにすべての変数のコピーが作成されます。

import mnist
with strategy.scope():
  # Model building needs to be within `strategy.scope()`.
  multi_worker_model = mnist.build_cnn_model()

ワーカー間でデータを自動シャーディングする

マルチワーカー トレーニングでは、収束と再現性を確保するためにデータセットのシャーディングが必要です。シャーディングとは、各ワーカーにデータセット全体のサブセットを渡すことを意味し、単一のワーカーでのトレーニングに似たエクスペリエンスを作成するのに役立ちます。以下の例では、tf.distribute のデフォルトの自動シャーディング ポリシーを使用しています。tf.data.experimental.DistributeOptionstf.data.experimental.AutoShardPolicy を設定してカスタマイズすることもできます。詳細については、分散入力チュートリアルシャーディングセクションを参照してください。

per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers

with strategy.scope():
  multi_worker_dataset = strategy.distribute_datasets_from_function(
      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))

カスタムトレーニングループを定義してモデルをトレーニングする

オプティマイザーを指定します。

with strategy.scope():
  # The creation of optimizer and train_accuracy needs to be in
  # `strategy.scope()` as well, since they create variables.
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')

tf.function でトレーニング ステップを定義します。

@tf.function
def train_step(iterator):
  """Training step function."""

  def step_fn(inputs):
    """Per-Replica step function."""
    x, y = inputs
    with tf.GradientTape() as tape:
      predictions = multi_worker_model(x, training=True)
      per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
      loss = tf.nn.compute_average_loss(per_example_loss)
      model_losses = multi_worker_model.losses
      if model_losses:
        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))

    grads = tape.gradient(loss, multi_worker_model.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, multi_worker_model.trainable_variables))
    train_accuracy.update_state(y, predictions)
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

チェックポイントの保存と復元

カスタムトレーニングループを作成するときは、Keras コールバックに依存するのではなく、チェックポイントの保存を手動で処理する必要があります。MultiWorkerMirroredStrategy の場合、チェックポイントまたは完全なモデルを保存するには、すべてのワーカーが含まれる必要があることに注意してください。チーフワーカーだけを保存しようとすると、デッドロックが発生する可能性があるためです。ワーカーは、相互に上書きしないように、異なるパスに書き込む必要もあります。ディレクトリを構成する方法の例を次に示します。

from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')

def _is_chief(task_type, task_id, cluster_spec):
  return (task_type is None
          or task_type == 'chief'
          or (task_type == 'worker'
              and task_id == 0
              and "chief" not in cluster_spec.as_dict()))

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id, cluster_spec):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id, cluster_spec):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

モデルをトラッキングする tf.train.Checkpoint を 1 つ作成します。これは tf.train.CheckpointManager により管理されるため、最新のチェックポイントのみが保存されます。

epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64),
    name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this
# illustrative example you did not set `'TF_CONFIG'` before initializing the
# strategy. Check out the next section for "real-world" usage.
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])

checkpoint = tf.train.Checkpoint(
    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)

write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
                                      cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

チェックポイントを復元する必要があれば、便利な tf.train.latest_checkpoint 関数を使用して、保存された最新のチェックポイントを見つけることができます (または tf.train.CheckpointManager.restore_or_initialize を呼び出します)。

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
  checkpoint.restore(latest_checkpoint)

チェックポイントを復元した後、カスタムトレーニングループのトレーニングを続行できます。

num_epochs = 3
num_steps_per_epoch = 70

while epoch.numpy() < num_epochs:
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  while step_in_epoch.numpy() < num_steps_per_epoch:
    total_loss += train_step(iterator)
    num_batches += 1
    step_in_epoch.assign_add(1)

  train_loss = total_loss / num_batches
  print('Epoch: %d, accuracy: %f, train_loss: %f.'
                %(epoch.numpy(), train_accuracy.result(), train_loss))

  train_accuracy.reset_states()

  # Once the `CheckpointManager` is set up, you're now ready to save, and remove
  # the checkpoints non-chief workers saved.
  checkpoint_manager.save()
  if not _is_chief(task_type, task_id, cluster_spec):
    tf.io.gfile.rmtree(write_checkpoint_dir)

  epoch.assign_add(1)
  step_in_epoch.assign(0)
2024-01-11 18:07:07.596472: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch: 0, accuracy: 0.828348, train_loss: 0.563694.
Epoch: 1, accuracy: 0.928125, train_loss: 0.244830.
Epoch: 2, accuracy: 0.945201, train_loss: 0.188251.

完全なコードの概要

これまでに説明したすべての手順の概要は、以下のとおりです。

  1. ワーカープロセスを作成します。
  2. 'TF_CONFIG' をワーカープロセスに渡します。
  3. 各ワークプロセスで、トレーニングコードを含む以下のスクリプトを実行します。

File: main.py

Writing main.py

現在のディレクトリには、両方の Python ファイルが含まれています。

ls *.py
main.py
mnist.py

JSON は TF_CONFIG をシリアル化し、環境変数に追加します。

os.environ['TF_CONFIG'] = json.dumps(tf_config)

これで、main.py を実行し、'TF_CONFIG' を使用するワーカープロセスを起動できます。

# first kill any previous runs
%killbgscripts
All background processes were killed.
python main.py &> job_0.log

上記のコマンドについて注意すべき点がいくつかあります。

  1. ノートブック「マジック」である %%bash を使用して、いくつかの bash コマンドを実行します。
  2. このワーカーは終了しないため、--bg フラグを使用して bash プロセスをバックグラウンドで実行します。このワーカーは始める前にすべてのワーカーを待ちます。

バックグラウンドのワーカープロセスはこのノートブックに出力を出力しないため、&> で出力をファイルにリダイレクトし、何が起こったかを検査できます。

プロセスが開始するまで数秒待ちます。

import time
time.sleep(20)

これまでにワーカーのログファイルに出力されたものを検査します。

cat job_0.log
2024-01-11 18:07:12.173295: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:07:12.173354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:07:12.174760: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-11 18:07:14.114778: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

ログファイルの最後の行は Started server with target: grpc://localhost:12345 であるはずです。最初のワーカーは準備が整い、他のすべてのワーカーの準備が整うのを待っています。

2 番目のワーカーのプロセスを始めるように tf_config を更新します。

tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

次に、2 番目のワーカーを起動します。すべてのワーカーがアクティブであるため、これによりトレーニングが開始されます(したがって、このプロセスをバックグラウンドで実行する必要はありません)。

python main.py > /dev/null 2>&1

最初のワーカーにより書き込まれたログを再確認すると、そのモデルのトレーニングに参加していることがわかります。

cat job_0.log
2024-01-11 18:07:12.173295: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:07:12.173354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:07:12.174760: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-11 18:07:14.114778: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2024-01-11 18:07:35.302084: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
Epoch: 0, accuracy: 0.823103, train_loss: 0.584635.
Epoch: 1, accuracy: 0.923103, train_loss: 0.254365.
Epoch: 2, accuracy: 0.949888, train_loss: 0.180697.
# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.

マルチワーカートレーニングの詳細

このチュートリアルでは、マルチワーカーセットアップのカスタムトレーニングループのワークフローを紹介しました。他のトピックの詳細な説明は、カスタムトレーニングループ向けの Keras を使用したマルチワーカートレーニング (tf.keras.Model.fit) チュートリアルを参照してください。

詳細情報

  1. TensorFlow での分散型トレーニングガイドでは、利用可能な分散ストラテジーの概要が説明されています。
  2. 公式モデル。この多くは、複数の分散ストラテジーで実行するように構成できます。
  3. tf.function ガイドのパフォーマンスの改善では、その他のストラテジーや、TensorFlow モデルのパフォーマンスを最適化するために使用できるツールに関する情報が提供されています。