マルチワーカー CPU/GPU トレーニングを移行する

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

このガイドでは、マルチワーカーの分散トレーニングワークフローを TensorFlow 1 から TensorFlow 2 に移行する方法を実演します。

CPU/GPU を使用してマルチワーカートレーニングを実行するには

セットアップ

必要とされるインポートとデモ用の単純なデータセットから始めます。

# The notebook uses a dataset instance for `Model.fit` with
# `ParameterServerStrategy`, which depends on symbols in TF 2.7.
# Install a utility needed for this demonstration
!pip install portpicker

import tensorflow as tf
import tensorflow.compat.v1 as tf1
2022-12-14 22:37:03.051332: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:37:03.051429: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:37:03.051438: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow で複数のマシンでトレーニングするには、'TF_CONFIG' 構成環境変数が必要になります。'TF_CONFIG' を使用して、'cluster''task' のアドレスを指定します。(詳しくは Distributed_training ガイドを参照してください)。

import json
import os

tf_config = {
    'cluster': {
        'chief': ['localhost:11111'],
        'worker': ['localhost:12345', 'localhost:23456', 'localhost:21212'],
        'ps': ['localhost:12121', 'localhost:13131'],
    },
    'task': {'type': 'chief', 'index': 0}
}

os.environ['TF_CONFIG'] = json.dumps(tf_config)

注意: 残念ながら、TensorFlow 1 の tf.estimator API を使用したマルチワーカートレーニングには複数のクライアントが必要になるため、(この Colab ノートブックでこれを実行するのが特に難しくなります)、ローカルトレーニングにフォールバックするように 'TF_CONFIG' 環境変数なしでノートブックが実行できるようにします。(詳細については、TensorFlow を使用した分散トレーニング ガイドの 'TF_CONFIG' 環境変数の設定セクションを参照してください。)

del ステートメントを使用して変数を削除します(ただし、TensorFlow 1 での実際のマルチワーカートレーニングでは、これを行う必要はありません)。

del os.environ['TF_CONFIG']

TensorFlow 1: tf.estimator API を使用したマルチワーカー分散トレーニング

次のコードスニペットは、TF1 でのマルチワーカートレーニングの正規のワークフローを示しています。tf.estimator.Estimatortf.estimator.TrainSpectf.estimator.EvalSpec、およびトレーニングを分散する tf.estimator.train_and_evaluate API を使用します。

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpa8ntmn5y
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpa8ntmn5y', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpa8ntmn5y/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.25171334, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpa8ntmn5y/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T22:37:08
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpa8ntmn5y/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.25907s
INFO:tensorflow:Finished evaluation at 2022-12-14-22:37:08
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 1.3012573
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmpa8ntmn5y/model.ckpt-3
INFO:tensorflow:Loss for final step: 0.9806307.
({'loss': 1.3012573, 'global_step': 3}, [])

TensorFlow 2: 分散ストラテジーによるマルチワーカートレーニング

TensorFlow 2 では、tf.distribute.Strategy を介して、CPU、GPU、および TPU を使用する複数のワーカーでの分散トレーニングが行われます。

次の例は、2 つのストラテジー tf.distribute.experimental.ParameterServerStrategytf.distribute.MultiWorkerMirroredStrategy の使用方法を示しています。どちらも複数のワーカーによる CPU/GPU トレーニングのために設計されています。

ParameterServerStrategy は、コーディネーター'chief')を採用しているため、この Colab ノートブックの環境により使いやすくなっています。ここでは、いくつかのユーティリティを使用して、実行可能なエクスペリエンスに不可欠なサポート要素をセットアップします。スレッドを使用してパラメータサーバー('ps')とワーカー('worker')をシミュレートする インプロセスクラスタを作成します。パラメータサーバーのトレーニングの詳細については、ParameterServerStrategy を使用したパラメータサーバーのトレーニングのチュートリアルを参照してください。

この例では、まず 'TF_CONFIG' 環境変数を tf.distribute.cluster_resolver.TFConfigClusterResolver で定義して、クラスター情報を提供します。分散トレーニングにクラスター管理システムを使用している場合は、すでに 'TF_CONFIG' が提供されているかどうかを確認してください。提供されている場合、この環境変数を明示的に設定する必要はありません。(詳細については、TensorFlow を使用した分散トレーニングガイドの 'TF_CONFIG' 環境変数の設定セクションを参照してください。)

# Find ports that are available for the `'chief'` (the coordinator),
# `'worker'`s, and `'ps'` (parameter servers).
import portpicker

chief_port = portpicker.pick_unused_port()
worker_ports = [portpicker.pick_unused_port() for _ in range(3)]
ps_ports = [portpicker.pick_unused_port() for _ in range(2)]

# Dump the cluster information to `'TF_CONFIG'`.
tf_config = {
    'cluster': {
        'chief': ["localhost:%s" % chief_port],
        'worker': ["localhost:%s" % port for port in worker_ports],
        'ps':  ["localhost:%s" % port for port in ps_ports],
    },
    'task': {'type': 'chief', 'index': 0}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)

# Use a cluster resolver to bridge the information to the strategy created below.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

次に、ワーカーとパラメータサーバーの tf.distribute.Server を 1 つずつ作成します。

# Workers need some inter_ops threads to work properly.
# This is only needed for this notebook to demo. Real servers
# should not need this.
worker_config = tf.compat.v1.ConfigProto()
worker_config.inter_op_parallelism_threads = 4

for i in range(3):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="worker",
      task_index=i,
      config=worker_config)

for i in range(2):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="ps",
      task_index=i)

実際の分散トレーニングでは、コーディネーターですべての tf.distribute.Server を開始せずに、複数のマシンを使用し、"worker""ps"(パラメータサーバー)は、それぞれ tf.distribute.Server を実行します。詳細については、パラメータサーバのトレーニング チュートリアルの実世界のクラスタセクションを参照してください。

すべての準備が整ったら、ParameterServerStrategy オブジェクトを作成します。

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'chief': ['localhost:36219'], 'ps': ['localhost:34739', 'localhost:45569'], 'worker': ['localhost:42897', 'localhost:35265', 'localhost:42741']})
INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'chief': ['localhost:36219'], 'ps': ['localhost:34739', 'localhost:45569'], 'worker': ['localhost:42897', 'localhost:35265', 'localhost:42741']})
INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'
INFO:tensorflow:Number of GPUs on workers: 4

ストラテジーオブジェクトを作成したら、モデル、オプティマイザ、およびその他の変数を定義し、Strategy.scope API 内で Keras Model.compile を呼び出してトレーニングを分散します。(詳細については、Strategy.scope API ドキュメントを参照してください。)

例えば、フォワードパスとバックワードパスを定義してトレーニングをカスタマイズする場合は、詳細について、パラメータサーバートレーニングチュートリアルのカスタムトレーニングループを使用したトレーニングセクションを参照してください。

dataset = tf.data.Dataset.from_tensor_slices(
      (features, labels)).shuffle(10).repeat().batch(64)

eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).repeat().batch(1)

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05)
  model.compile(optimizer, "mse")

model.fit(dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
2022-12-14 22:37:09.525475: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:6"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-14 22:37:09.525532: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:6"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-14 22:37:09.525737: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:6"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 4s - loss: 12.5010 - 4s/epoch - 418ms/step
Epoch 2/5
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 0s - loss: 6.7175 - 68ms/epoch - 7ms/step
Epoch 3/5
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 0s - loss: 4.3748 - 66ms/epoch - 7ms/step
Epoch 4/5
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 0s - loss: 3.0396 - 66ms/epoch - 7ms/step
Epoch 5/5
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 0s - loss: 2.1569 - 65ms/epoch - 6ms/step
<keras.callbacks.History at 0x7f02f0068ca0>
model.evaluate(eval_dataset, steps=10, return_dict=True)
2022-12-14 22:37:14.088484: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\025TensorSliceDataset:10"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-14 22:37:14.088543: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\025TensorSliceDataset:10"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-14 22:37:14.088606: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\025TensorSliceDataset:10"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Waiting for all global closures to be finished.
10/10 - 1s - loss: 12.8000 - 1s/epoch - 120ms/step
{'loss': 12.800030708312988}

パーティショナ(tf.distribute.experimental.practitioners{/code 0})

TensorFlow 2 の ParameterServerStrategy は、変数のパーティショニングをサポートし、TensorFlow 1 と同じパーティショナを提供しますが、紛らわしい名前はありません。

または、MultiWorkerMirroredStrategy オブジェクトを使用できます。

# To clean up the `TF_CONFIG` used for `ParameterServerStrategy`.
del os.environ['TF_CONFIG']
strategy = tf.distribute.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO

上記で使用したストラテジーを MultiWorkerMirroredStrategy オブジェクトに置き換えて、このストラテジーでトレーニングを実行できます。

tf.estimator API と同様に、MultiWorkerMirroredStrategy はマルチクライアントストラテジーであるため、この Colab ノートブックでは簡単に分散トレーニングを実行できません。したがって、上記のコードをこのストラテジーに置き換えると、ローカルで実行されることになります。マルチワーカートレーニング Keras Model.fit を使用/カスタムトレーニングループのチュートリアルでは、Colab のローカルホストで 2 つのワーカーを使用して変数を設定し、'TF_CONFIG' を使用してマルチワーカートレーニングを実行する方法を示しています。実際には、外部 IP アドレス/ポートに複数のワーカーを作成し、'TF_CONFIG' 変数を使用して各ワーカーのクラスター構成を指定します。

次のステップ

TensorFlow 2 の tf.distribute.experimental.ParameterServerStrategytf.distribute.MultiWorkerMirroredStrategy を使用したマルチワーカー分散トレーニングの詳細については、次のリソースを参照してください。