Using DTensors with Keras


Overview

In this tutorial, you will learn how to use DTensor with Keras.

Through DTensor integration with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.

You will train a multi-layer classification model with the MNIST data. Setting the layout for subclassing model, Sequential model, and functional model will be demonstrated.

This tutorial assumes that you have already read the DTensor programming guide, and are familiar with basic DTensor concepts like Mesh and Layout.

This tutorial is based on https://www.tensorflow.org/datasets/keras_example.

Setup

DTensor is part of the TensorFlow 2.9.0 release.

pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import tensorflow and tensorflow.experimental.dtensor, and configure TensorFlow to use 8 virtual CPUs.

Even though this example uses CPUs, DTensor works the same way on CPU, GPU or TPU devices.

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.experimental import dtensor
def configure_virtual_cpus(ncpu):
  # Split the single physical CPU into `ncpu` virtual (logical) CPU devices.
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(
      phy_devices[0],
      [tf.config.LogicalDeviceConfiguration()] * ncpu)

configure_virtual_cpus(8)
tf.config.list_logical_devices('CPU')

devices = [f'CPU:{i}' for i in range(8)]

Deterministic pseudo-random number generators

One thing you should note is that the DTensor API requires each of the running clients to have the same random seeds, so that weight initialization is deterministic across clients. You can achieve this by setting the global seeds in Keras via tf.keras.utils.set_random_seed().

tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1337)

Creating a Data Parallel Mesh

This tutorial demonstrates Data Parallel training. Adapting to Model Parallel training and Spatial Parallel training can be as simple as switching to a different set of Layout objects. Refer to DTensor in-depth ML Tutorial for more information on distributed training beyond Data Parallel.
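As a rough illustration of that claim (a minimal sketch, not used in the rest of this tutorial; the mesh_2d and model_parallel_kernel_layout names are only illustrative), a model-parallel setup would use a Mesh with an extra dimension and shard some weight layouts along it:

# Minimal sketch (not used in this tutorial): a 2D mesh for model-parallel layouts.
# The same 8 virtual CPUs are split into 4 'batch' replicas x 2 'model' shards.
mesh_2d = dtensor.create_mesh([("batch", 4), ("model", 2)], devices=devices)

# A rank-2 kernel sharded along its output dimension over the 'model' mesh axis,
# instead of being fully replicated as in the data-parallel case below.
model_parallel_kernel_layout = dtensor.Layout([dtensor.UNSHARDED, "model"], mesh_2d)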

Data Parallel training is a commonly used parallel training scheme, also used, for example, by tf.distribute.MirroredStrategy.

With DTensor, a Data Parallel training loop uses a Mesh that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch.

mesh = dtensor.create_mesh([("batch", 8)], devices=devices)

As each device runs a full replica of the model, the model variables should be fully replicated across the Mesh (unsharded). As an example, a fully replicated Layout for a rank-2 weight on this Mesh would be as follows:

example_weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)  # or
example_weight_layout = dtensor.Layout.replicated(mesh, rank=2)

A Layout for a rank-2 data tensor on this Mesh would be sharded along the first (batch) dimension (sometimes known as batch_sharded):

example_data_layout = dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)  # or
example_data_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)

Create Keras layers with layout

In the data parallel scheme, you usually create your model weights with a fully replicated layout, so that each replica of the model can do calculations with the sharded input data.

In order to configure the layout information for your layers' weights, Keras has exposed an extra parameter in the layer constructor for most of the built-in layers.

The following example builds a small image classification model with a fully replicated weight layout. You can specify the layout information for the kernel and bias in tf.keras.layers.Dense via the kernel_layout and bias_layout arguments. Most of the built-in Keras layers are ready for explicitly specifying the Layout for the layer weights.

unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)
unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, 
                        activation='relu',
                        name='d1',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d),
  tf.keras.layers.Dense(10,
                        name='d2',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d)
])

You can check the layout information by examining the layout property on the weights.

for weight in model.weights:
  print(f'Weight name: {weight.name} with layout: {weight.layout}')
  break
Weight name: d1/kernel:0 with layout: sharding_specs {
  sharding_spec: "unsharded"
}
sharding_specs {
  sharding_spec: "unsharded"
}
mesh_config {
  mesh_dimensions {
    name: "batch"
    size: 8
  }
  global_device_ids: 0
  global_device_ids: 1
  global_device_ids: 2
  global_device_ids: 3
  global_device_ids: 4
  global_device_ids: 5
  global_device_ids: 6
  global_device_ids: 7
  local_device_ids: 0
  local_device_ids: 1
  local_device_ids: 2
  local_device_ids: 3
  local_device_ids: 4
  local_device_ids: 5
  local_device_ids: 6
  local_device_ids: 7
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:0"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:1"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:2"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:3"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:4"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:5"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:6"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:7"
}

Load a dataset and build input pipeline

Load the MNIST dataset and configure a pre-processing input pipeline for it. The dataset itself is not associated with any DTensor layout information. There are plans to improve the DTensor Keras integration with tf.data in future TensorFlow releases.

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label
batch_size = 128

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Define the training logic for the model

Next, define the training and evaluation logic for the model.

As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. This is needed to pack the input data with the proper layout information, and this packing is not integrated with the standard tf.keras.Model.fit() or tf.keras.Model.evaluate() functions from Keras. You will get more tf.data support in an upcoming release.

@tf.function
def train_step(model, x, y, optimizer, metrics):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'loss': loss_per_sample}
  return results
@tf.function
def eval_step(model, x, y, metrics):
  logits = model(x, training=False)
  loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'eval_loss': loss_per_sample}
  return results
def pack_dtensor_inputs(images, labels, image_layout, label_layout):
  # Split the regular (global) batch into one shard per local device, then pack
  # the shards into DTensors that carry the given batch-sharded layouts.
  num_local_devices = image_layout.mesh.num_local_devices()
  images = tf.split(images, num_local_devices)
  labels = tf.split(labels, num_local_devices)
  images = dtensor.pack(images, image_layout)
  labels = dtensor.pack(labels, label_layout)
  return images, labels
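As a quick sanity check (a minimal sketch; the sample_* names are only illustrative, and the layouts mirror the ones defined in the training section below), you can pack one batch and confirm that the result is a DTensor with a batch-sharded layout:

# Minimal sketch: pack one batch and inspect the resulting DTensor layout.
sample_image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)
sample_label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

sample_images, sample_labels = next(iter(ds_train))
packed_images, packed_labels = pack_dtensor_inputs(
    sample_images, sample_labels, sample_image_layout, sample_label_layout)

# The packed tensor carries DTensor layout information; its first axis is
# sharded across the 'batch' mesh dimension.
print(dtensor.fetch_layout(packed_images))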

Metrics and Optimizers

When using the DTensor API with a Keras Metric or Optimizer, you will need to provide the extra mesh information, so that any internal state variables and tensors can work with the variables in the model.

  • For an optimizer, DTensor introduces a new experimental namespace keras.dtensor.experimental.optimizers, where many existing Keras Optimizers are extended to receive an additional mesh argument. In future releases, this namespace may be merged into the Keras core optimizers.

  • For metrics, you can directly pass the mesh to the constructor as an argument to make it a DTensor-compatible Metric.

optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)
metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
eval_metrics = {'eval_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}

Train the model

The following example shards the data from the input pipeline along the batch dimension, and trains the model, which has fully replicated weights.

After 3 epochs, the model should achieve about 97% accuracy.

num_epochs = 3

image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)
label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

for epoch in range(num_epochs):
  print("============================") 
  print("Epoch: ", epoch)
  for metric in metrics.values():
    metric.reset_state()
  step = 0
  results = {}
  pbar = tf.keras.utils.Progbar(target=None, stateful_metrics=[])
  for input in ds_train:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)

    results.update(train_step(model, images, labels, optimizer, metrics))
    for metric_name, metric in metrics.items():
      results[metric_name] = metric.result()

    pbar.update(step, values=results.items(), finalize=False)
    step += 1
  pbar.update(step, values=results.items(), finalize=True)

  for metric in eval_metrics.values():
    metric.reset_state()
  for input in ds_test:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)
    results.update(eval_step(model, images, labels, eval_metrics))

  for metric_name, metric in eval_metrics.items():
    results[metric_name] = metric.result()

  for metric_name, metric in results.items():
    print(f"{metric_name}: {metric.numpy()}")
============================
Epoch:  0
    469/Unknown - 7s 15ms/step - loss: 0.2162 - accuracy: 0.8865
loss: 0.14368298649787903
accuracy: 0.9343500137329102
eval_loss: 0.007671673782169819
eval_accuracy: 0.9649999737739563
============================
Epoch:  1
    469/Unknown - 4s 9ms/step - loss: 0.1084 - accuracy: 0.9658
loss: 0.12148343771696091
accuracy: 0.966783344745636
eval_loss: 0.11394938826560974
eval_accuracy: 0.9652000069618225
============================
Epoch:  2
    469/Unknown - 4s 9ms/step - loss: 0.0851 - accuracy: 0.9755
loss: 0.034798476845026016
accuracy: 0.9736833572387695
eval_loss: 0.13114827871322632
eval_accuracy: 0.9664999842643738

Specify Layout for existing model code

Often you have models that work well for your use case. Specifying Layout information for each individual layer within the model would be a large amount of work requiring many edits.

To help you easily convert your existing Keras model to work with the DTensor API, you can use the new dtensor.LayoutMap API, which allows you to specify the Layout from a global point of view.

First, you need to create a LayoutMap instance, which is a dictionary-like object that contains all the Layouts you would like to specify for your model weights.

LayoutMap needs a Mesh instance at init, which is used to provide a default replicated Layout for any weights that don't have a Layout configured. If you would like all your model weights to be fully replicated, you can provide an empty LayoutMap, and the default Mesh will be used to create replicated Layouts.

LayoutMap uses a string as the key and a Layout as the value. There is a behavior difference between a normal Python dict and this class: the string key is treated as a regex when retrieving the value.
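For example (a minimal sketch of the lookup behavior; demo_layout_map and the key and path strings are only illustrative):

# Minimal sketch: the string keys act as regular expressions on retrieval.
demo_layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)
demo_layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)

# A concrete weight path such as 'feature/kernel' matches the 'feature.*kernel'
# regex, so the retrieval returns the batch-sharded Layout assigned above.
print(demo_layout_map['feature/kernel'])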

Subclassed Model

Consider the following model defined using the Keras subclassing Model syntax.

class SubclassedModel(tf.keras.Model):

  def __init__(self, name=None):
    super().__init__(name=name)
    self.feature = tf.keras.layers.Dense(16)
    self.feature_2 = tf.keras.layers.Dense(24)
    self.dropout = tf.keras.layers.Dropout(0.1)

  def call(self, inputs, training=None):
    x = self.feature(inputs)
    x = self.dropout(x, training=training)
    return self.feature_2(x)

There are 4 weights in this model: the kernel and bias for the two Dense layers. Each of them is mapped based on the object path:

  • model.feature.kernel
  • model.feature.bias
  • model.feature_2.kernel
  • model.feature_2.bias

Now define the following LayoutMap and apply it to the model.

layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

with tf.keras.dtensor.experimental.layout_map_scope(layout_map):
  subclassed_model = SubclassedModel()

The model weights are created on the first call, so call the model with a DTensor input and confirm the weights have the expected layouts.

dtensor_input = dtensor.copy_to_mesh(tf.zeros((16, 16)), layout=unsharded_layout_2d)
# Trigger the weights creation for subclass model
subclassed_model(dtensor_input)

print(subclassed_model.feature.kernel.layout)
sharding_specs {
  sharding_spec: "batch"
}
sharding_specs {
  sharding_spec: "unsharded"
}
mesh_config {
  mesh_dimensions {
    name: "batch"
    size: 8
  }
  global_device_ids: 0
  global_device_ids: 1
  global_device_ids: 2
  global_device_ids: 3
  global_device_ids: 4
  global_device_ids: 5
  global_device_ids: 6
  global_device_ids: 7
  local_device_ids: 0
  local_device_ids: 1
  local_device_ids: 2
  local_device_ids: 3
  local_device_ids: 4
  local_device_ids: 5
  local_device_ids: 6
  local_device_ids: 7
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:0"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:1"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:2"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:3"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:4"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:5"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:6"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:7"
}

With this, you can quickly map the Layout to your models without updating any of your existing code.

Sequential and Functional Models

For Keras Functional and Sequential models, you can use LayoutMap as well.

layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)
with tf.keras.dtensor.experimental.layout_map_scope(layout_map):
  inputs = tf.keras.Input((16,), batch_size=16)
  x = tf.keras.layers.Dense(16, name='feature')(inputs)
  x = tf.keras.layers.Dropout(0.1)(x)
  output = tf.keras.layers.Dense(32, name='feature_2')(x)
  model = tf.keras.Model(inputs, output)

print(model.layers[1].kernel.layout)
sharding_specs {
  sharding_spec: "batch"
}
sharding_specs {
  sharding_spec: "unsharded"
}
mesh_config {
  mesh_dimensions {
    name: "batch"
    size: 8
  }
  global_device_ids: 0
  global_device_ids: 1
  global_device_ids: 2
  global_device_ids: 3
  global_device_ids: 4
  global_device_ids: 5
  global_device_ids: 6
  global_device_ids: 7
  local_device_ids: 0
  local_device_ids: 1
  local_device_ids: 2
  local_device_ids: 3
  local_device_ids: 4
  local_device_ids: 5
  local_device_ids: 6
  local_device_ids: 7
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:0"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:1"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:2"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:3"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:4"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:5"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:6"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:7"
}
with tf.keras.dtensor.experimental.layout_map_scope(layout_map):
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, name='feature', input_shape=(16,)),
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(32, name='feature_2')
  ])

print(model.layers[2].kernel.layout)
sharding_specs {
  sharding_spec: "batch"
}
sharding_specs {
  sharding_spec: "unsharded"
}
mesh_config {
  mesh_dimensions {
    name: "batch"
    size: 8
  }
  global_device_ids: 0
  global_device_ids: 1
  global_device_ids: 2
  global_device_ids: 3
  global_device_ids: 4
  global_device_ids: 5
  global_device_ids: 6
  global_device_ids: 7
  local_device_ids: 0
  local_device_ids: 1
  local_device_ids: 2
  local_device_ids: 3
  local_device_ids: 4
  local_device_ids: 5
  local_device_ids: 6
  local_device_ids: 7
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:0"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:1"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:2"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:3"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:4"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:5"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:6"
  local_devices: "/job:localhost/replica:0/task:0/device:CPU:7"
}