Overview
This tutorial demonstrates multi-worker training with a custom training loop, distributed via MultiWorkerMirroredStrategy, so that a Keras model designed to run on a single worker can seamlessly work on multiple workers with minimal code changes.
Custom training loops are used to train the model because they give you flexibility and greater control over training, and they also make it easier to debug the model and the training loop. More detailed information is available in Writing a training loop from scratch.
If you are looking for how to use MultiWorkerMirroredStrategy with Keras Model.fit, refer to this tutorial instead.
The Distributed Training in TensorFlow guide provides an overview of the distribution strategies TensorFlow supports, for those interested in a deeper understanding of the tf.distribute.Strategy APIs.
Setup
First, some necessary imports.
import json
import os
import sys
Before importing TensorFlow, make a few changes to the environment.
Disable all GPUs. This prevents errors caused by the workers all trying to use the same GPU. For a real application each worker would be on a different machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Reset the TF_CONFIG environment variable (you'll see more about this later):
os.environ.pop('TF_CONFIG', None)
Be sure that the current directory is on Python's path. This allows the notebook to import the files written by %%writefile later.
if '.' not in sys.path:
sys.path.insert(0, '.')
Now import TensorFlow.
import tensorflow as tf
Dataset and model definition
Next, create an mnist.py file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:
%%writefile mnist.py
import os
import tensorflow as tf
import numpy as np
def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the range [0, 255].
# You need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000)
return train_dataset
def dataset_fn(global_batch_size, input_context):
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = mnist_dataset(batch_size)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
return dataset
def build_cnn_model():
return tf.keras.Sequential([
tf.keras.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
Writing mnist.py
Multi-worker Configuration
Now let's enter the world of multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which possibly has a different role. TF_CONFIG, used below, is a JSON string that specifies the cluster configuration on each worker that is part of the cluster. This is the default method for specifying a cluster, using cluster_resolver.TFConfigClusterResolver, but there are other options available in the distribute.cluster_resolver module.
Describe your cluster
Here is an example configuration:
tf_config = {
'cluster': {
'worker': ['localhost:12345', 'localhost:23456']
},
'task': {'type': 'worker', 'index': 0}
}
Here is the same TF_CONFIG serialized as a JSON string:
json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
There are two components of TF_CONFIG: cluster and task.
- cluster is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs such as worker. In multi-worker training with MultiWorkerMirroredStrategy, there is usually one worker that takes on a little more responsibility, like saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. Such a worker is referred to as the chief worker, and it is customary that the worker with index 0 is appointed as the chief worker (in fact, this is how tf.distribute.Strategy is implemented).
- task provides information about the current task and is different on each worker. It specifies the type and index of that worker.
In this example, you set the task type to "worker" and the task index to 0. This machine is the first worker and will be appointed as the chief worker and do more work than the others. Note that other machines will need to have the TF_CONFIG environment variable set as well, and it should have the same cluster dict, but different task type or task index depending on what the roles of those machines are.
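To see how a worker process might consume this configuration, here is a minimal sketch (assuming TF_CONFIG has already been set in that process's environment) that parses the variable and checks whether the current task is the chief worker:
import json
import os
# A minimal sketch: parse TF_CONFIG and decide whether this process is the chief.
# Assumes TF_CONFIG has already been exported in the environment.
parsed_config = json.loads(os.environ['TF_CONFIG'])
task_type = parsed_config['task']['type']
task_id = parsed_config['task']['index']
# By convention, the worker with index 0 acts as the chief.
is_chief = task_type == 'chief' or (task_type == 'worker' and task_id == 0)
print(f'Task: {task_type}:{task_id}, chief: {is_chief}')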
For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports, and set TF_CONFIG on each worker appropriately.
In this example you will use 2 workers. The first worker's TF_CONFIG is shown above; for the second worker, you would set tf_config['task']['index']=1.
Above, tf_config is just a local variable in Python. To actually use it to configure training, this dictionary needs to be serialized as JSON and placed in the TF_CONFIG environment variable.
Environment variables and subprocesses in notebooks
Subprocesses inherit environment variables from their parent. So if you set an environment variable in this Jupyter notebook process:
os.environ['GREETINGS'] = 'Hello TensorFlow!'
You can access the environment variable from a subprocess:
echo ${GREETINGS}
Hello TensorFlow!
In the next section, you'll use this to pass the TF_CONFIG to the worker subprocesses. You would never really launch your jobs this way, but it's sufficient for the purposes of this tutorial: to demonstrate a minimal multi-worker example.
MultiWorkerMirroredStrategy
To train the model, use an instance of tf.distribute.MultiWorkerMirroredStrategy, which creates copies of all variables in the model's layers on each device across all workers. The tf.distribute.Strategy guide has more details about this strategy.
strategy = tf.distribute.MultiWorkerMirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',) INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO
Use tf.distribute.Strategy.scope to specify that a strategy should be used when building your model. This puts you in the "cross-replica context" for this strategy, which means the strategy is put in control of things like variable placement.
import mnist
with strategy.scope():
# Model building needs to be within `strategy.scope()`.
multi_worker_model = mnist.build_cnn_model()
Auto-shard your data across workers
In multi-worker training, dataset sharding is not strictly necessary; however, it gives you exactly-once semantics, which makes training more reproducible, i.e. training on multiple workers should be the same as training on one worker. Note: performance can be affected in some cases.
See: distribute_datasets_from_function
per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers
with strategy.scope():
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: mnist.dataset_fn(global_batch_size, input_context))
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step
Define Custom Training Loop and Train the model
Specify an optimizer
with strategy.scope():
# The creation of optimizer and train_accuracy will need to be in
# `strategy.scope()` as well, since they create variables.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
Define a training step with tf.function
@tf.function
def train_step(iterator):
"""Training step function."""
def step_fn(inputs):
"""Per-Replica step function."""
x, y = inputs
with tf.GradientTape() as tape:
predictions = multi_worker_model(x, training=True)
per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
loss = tf.nn.compute_average_loss(
per_batch_loss, global_batch_size=global_batch_size)
grads = tape.gradient(loss, multi_worker_model.trainable_variables)
optimizer.apply_gradients(
zip(grads, multi_worker_model.trainable_variables))
train_accuracy.update_state(y, predictions)
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
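Before the full training loop below, a single distributed training step could be run like this (a minimal sketch, assuming multi_worker_dataset was created with distribute_datasets_from_function as above):
# A minimal sketch of running one distributed training step.
iterator = iter(multi_worker_dataset)
loss = train_step(iterator)
print('Single-step loss:', loss.numpy())
print('Accuracy so far:', train_accuracy.result().numpy())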
Checkpoint saving and restoring
In a custom training loop, you need to handle checkpoint saving and restoring yourself instead of relying on a Keras callback. Checkpointing allows you to save the model's weights and restore them without having to save the whole model.
from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')
def _is_chief(task_type, task_id):
return task_type is None or task_type == 'chief' or (task_type == 'worker' and
task_id == 0)
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
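As a quick sanity check of these helpers, this sketch shows how write_filepath maps the same checkpoint path for a chief and a non-chief worker (the path and task values here are hypothetical):
# Illustrative only: how the helpers route checkpoints for different workers.
# The chief writes to the shared directory; other workers get a temporary subdirectory.
print(write_filepath('/tmp/ckpt', 'worker', 0))  # '/tmp/ckpt' (chief)
print(write_filepath('/tmp/ckpt', 'worker', 1))  # '/tmp/workertemp_1/ckpt' (non-chief)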
Here, you'll create one tf.train.Checkpoint that tracks the model, which is managed by a tf.train.CheckpointManager so that only the latest checkpoint is preserved.
epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64),
name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
checkpoint = tf.train.Checkpoint(
model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
Now, when you need to restore, you can find the latest checkpoint saved using the convenient tf.train.latest_checkpoint function.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
After restoring the checkpoint, you can continue with training your custom training loop.
num_epochs = 3
num_steps_per_epoch = 70
while epoch.numpy() < num_epochs:
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
while step_in_epoch.numpy() < num_steps_per_epoch:
total_loss += train_step(iterator)
num_batches += 1
step_in_epoch.assign_add(1)
train_loss = total_loss / num_batches
print('Epoch: %d, accuracy: %f, train_loss: %f.'
%(epoch.numpy(), train_accuracy.result(), train_loss))
train_accuracy.reset_states()
# Once the `CheckpointManager` is set up, you're now ready to save, and remove
# the checkpoints non-chief workers saved.
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(write_checkpoint_dir)
epoch.assign_add(1)
step_in_epoch.assign(0)
Epoch: 0, accuracy: 0.845312, train_loss: 0.486392. Epoch: 1, accuracy: 0.935938, train_loss: 0.211542. Epoch: 2, accuracy: 0.961272, train_loss: 0.133918.
Full code setup on workers
To actually run with MultiWorkerMirroredStrategy you'll need to run worker processes and pass a TF_CONFIG to them.
Like the mnist.py file written earlier, here is the main.py file. It contains the same code you walked through step by step previously in this colab; it's just written to a file so each of the workers will run it:
File: main.py
%%writefile main.py
import os
import json
import tensorflow as tf
import mnist
from multiprocessing import util
per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers
num_epochs = 3
num_steps_per_epoch=70
# Checkpoint saving and restoring
def _is_chief(task_type, task_id):
return task_type is None or task_type == 'chief' or (task_type == 'worker' and
task_id == 0)
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')
# Define Strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = mnist.build_cnn_model()
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: mnist.dataset_fn(global_batch_size, input_context))
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
@tf.function
def train_step(iterator):
"""Training step function."""
def step_fn(inputs):
"""Per-Replica step function."""
x, y = inputs
with tf.GradientTape() as tape:
predictions = multi_worker_model(x, training=True)
per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
loss = tf.nn.compute_average_loss(
per_batch_loss, global_batch_size=global_batch_size)
grads = tape.gradient(loss, multi_worker_model.trainable_variables)
optimizer.apply_gradients(
zip(grads, multi_worker_model.trainable_variables))
train_accuracy.update_state(y, predictions)
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64),
name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
checkpoint = tf.train.Checkpoint(
model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
# Restoring the checkpoint
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
# Resume our CTL training
while epoch.numpy() < num_epochs:
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
while step_in_epoch.numpy() < num_steps_per_epoch:
total_loss += train_step(iterator)
num_batches += 1
step_in_epoch.assign_add(1)
train_loss = total_loss / num_batches
print('Epoch: %d, accuracy: %f, train_loss: %f.'
%(epoch.numpy(), train_accuracy.result(), train_loss))
train_accuracy.reset_states()
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(write_checkpoint_dir)
epoch.assign_add(1)
step_in_epoch.assign(0)
Writing main.py
Train and Evaluate
The current directory now contains both Python files:
ls *.py
main.py mnist.py
So serialize the TF_CONFIG as JSON and add it to the environment variables:
os.environ['TF_CONFIG'] = json.dumps(tf_config)
Now, you can launch a worker process that will run main.py and use the TF_CONFIG:
# first kill any previous runs
%killbgscripts
All background processes were killed.
%%bash --bg
python main.py &> job_0.log
There are a few things to note about the above command:
- It uses %%bash, which is a notebook "magic", to run some bash commands.
- It uses the --bg flag to run the bash process in the background, because this worker will not terminate. It waits for all the workers before it starts.
The backgrounded worker process won't print output to this notebook, so the &> redirects its output to a file so that you can see what happened.
So, wait a few seconds for the process to start up:
import time
time.sleep(20)
Now look what's been output to the worker's logfile so far:
cat job_0.log
2021-03-31 01:27:04.715571: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0 2021-03-31 01:27:06.241941: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:06.242810: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1 2021-03-31 01:27:07.555688: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected 2021-03-31 01:27:07.555766: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1943026002 2021-03-31 01:27:07.555775: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1943026002 2021-03-31 01:27:07.555890: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.51.5 2021-03-31 01:27:07.555925: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.51.5 2021-03-31 01:27:07.555932: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.51.5 2021-03-31 01:27:07.556628: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2021-03-31 01:27:07.557250: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:07.557914: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:07.561940: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456} 2021-03-31 01:27:07.562404: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:411] Started server with target: grpc://localhost:12345
The last line of the log file should say: Started server with target: grpc://localhost:12345. The first worker is now ready, and is waiting for all the other worker(s) to be ready to proceed.
So update the tf_config for the second worker's process to pick up:
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)
Now launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):
python main.py > /dev/null 2>&1
Now if you recheck the logs written by the first worker you'll see that it participated in training that model:
cat job_0.log
2021-03-31 01:27:04.715571: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0 2021-03-31 01:27:06.241941: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:06.242810: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1 2021-03-31 01:27:07.555688: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected 2021-03-31 01:27:07.555766: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1943026002 2021-03-31 01:27:07.555775: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1943026002 2021-03-31 01:27:07.555890: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.51.5 2021-03-31 01:27:07.555925: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.51.5 2021-03-31 01:27:07.555932: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.51.5 2021-03-31 01:27:07.556628: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2021-03-31 01:27:07.557250: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:07.557914: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set 2021-03-31 01:27:07.561940: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456} 2021-03-31 01:27:07.562404: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:411] Started server with target: grpc://localhost:12345 2021-03-31 01:27:28.584989: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2) 2021-03-31 01:27:28.585344: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2000189999 Hz Epoch: 0, accuracy: 0.843750, train_loss: 0.499716. Epoch: 1, accuracy: 0.940513, train_loss: 0.197907. Epoch: 2, accuracy: 0.962835, train_loss: 0.132063.
# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.
Multi-worker training in depth
This tutorial has demonstrated a custom training loop workflow for the multi-worker setup. A detailed description of other topics is available in the Model.fit guide for the multi-worker setup, and much of it is also applicable to custom training loops.
See also
- Distributed Training in TensorFlow guide provides an overview of the available distribution strategies.
- Official models, many of which can be configured to run multiple distribution strategies.
- The Performance section in the guide provides information about other strategies and tools you can use to optimize the performance of your TensorFlow models.