Using DTensors with Keras



In this tutorial, you will learn how to use DTensors with Keras.

Through DTensor integration with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.

You will train a multi-layer classification model with the MNIST data. The tutorial also demonstrates how to set the layout for a subclassed model, a Sequential model, and a functional model.

This tutorial assumes that you have already read the DTensor programming guide, and are familiar with basic DTensor concepts like Mesh and Layout.

This tutorial is based on Training a neural network on MNIST with Keras.


DTensor (tf.experimental.dtensor) has been part of TensorFlow since the 2.9.0 release.

First, install or upgrade TensorFlow Datasets:

pip install --quiet --upgrade tensorflow-datasets

Next, import TensorFlow and dtensor, and configure TensorFlow to use 8 virtual CPUs.

Even though this example uses virtual CPUs, DTensor works the same way on CPU, GPU or TPU devices.

import tensorflow as tf
import tensorflow_datasets as tfds
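from tensorflow.experimental import dtensor

def configure_virtual_cpus(ncpu):
  phy_devices = tf.config.list_physical_devices('CPU')
  # Split the first physical CPU into `ncpu` logical CPU devices.
  tf.config.set_logical_device_configuration(
        phy_devices[0],
        [tf.config.LogicalDeviceConfiguration()] * ncpu)

configure_virtual_cpus(8)
tf.config.list_logical_devices('CPU')

devices = [f'CPU:{i}' for i in range(8)]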

Deterministic pseudo-random number generators

Note that the DTensor API requires every running client to use the same random seeds, so that weight initialization is deterministic across clients. You can achieve this by setting the global seeds in Keras via tf.keras.utils.set_random_seed().
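To see why the seeds must match, here is a stdlib-only analogy; it does not use TensorFlow's random number generators. Each simulated client runs a hypothetical init_weights initializer on its own random.Random generator. Clients that agree on the seed build identical copies of what should be one replicated variable, while a client with a different seed does not. In the tutorial itself, the equivalent step is a call to tf.keras.utils.set_random_seed() with a fixed integer of your choosing.

```python
import random

def init_weights(seed, shape=(4, 3)):
    """Hypothetical stand-in for a layer initializer running on one client."""
    rng = random.Random(seed)  # each client owns an explicitly seeded generator
    rows, cols = shape
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

# Two clients that agree on the seed build identical "replicated" weights.
client_a = init_weights(seed=1337)
client_b = init_weights(seed=1337)
print(client_a == client_b)  # True

# A client with a different seed would hold a different copy of the same variable.
client_c = init_weights(seed=42)
print(client_a == client_c)  # False
```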


Creating a Data Parallel Mesh

This tutorial demonstrates Data Parallel training. Adapting to Model Parallel training and Spatial Parallel training can be as simple as switching to a different set of Layout objects. Refer to the Distributed training with DTensors tutorial for more information on distributed training beyond Data Parallel.

Data Parallel training is a commonly used parallel training scheme, also used by, for example, tf.distribute.MirroredStrategy.

With DTensor, a Data Parallel training loop uses a Mesh that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch.

mesh = dtensor.create_mesh([("batch", 8)], devices=devices)

As each device runs a full replica of the model, the model variables should be fully replicated across the mesh (unsharded). As an example, a fully replicated Layout for a rank-2 weight on this Mesh would be as follows:

example_weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)  # or
example_weight_layout = dtensor.Layout.replicated(mesh, rank=2)

A layout for a rank-2 data tensor on this Mesh would be sharded along the first dimension (sometimes known as batch_sharded):

example_data_layout = dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)  # or
example_data_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
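A minimal pure-Python sketch of what these two layouts mean for where data lives, assuming 8 devices and nested lists in place of tensors. The helpers split_rows, batch_sharded and replicated are hypothetical and only mimic the placement, not DTensor's implementation: a batch-sharded tensor gives each device a contiguous slice of rows, while a replicated tensor gives every device the whole thing.

```python
def split_rows(tensor, num_devices):
    """Mimics an even split along the first (batch) axis, like tf.split(x, n)."""
    n = len(tensor)
    if n % num_devices != 0:
        raise ValueError(f"batch of {n} is not divisible by {num_devices} devices")
    size = n // num_devices
    return [tensor[i * size:(i + 1) * size] for i in range(num_devices)]

def batch_sharded(tensor, num_devices):
    """Like Layout(['batch', UNSHARDED]): each device holds a slice of the rows."""
    return split_rows(tensor, num_devices)

def replicated(tensor, num_devices):
    """Like Layout([UNSHARDED, UNSHARDED]): each device holds the full tensor."""
    return [tensor for _ in range(num_devices)]

global_batch = [[float(r), float(r) * 2] for r in range(128)]  # shape (128, 2)
weight = [[0.5, -0.5], [1.0, 0.0]]                               # shape (2, 2)

data_shards = batch_sharded(global_batch, 8)
weight_copies = replicated(weight, 8)
print([len(s) for s in data_shards])             # [16, 16, 16, 16, 16, 16, 16, 16]
print(all(w == weight for w in weight_copies))   # True
```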

Create Keras layers with layout

In the data parallel scheme, you usually create your model weights with a fully replicated layout, so that each replica of the model can do calculations with the sharded input data.

In order to configure the layout information for your layers' weights, Keras has exposed an extra parameter in the layer constructor for most of the built-in layers.

The following example builds a small image classification model with a fully replicated weight layout. You can specify the layout of the kernel and bias in tf.keras.layers.Dense via the kernel_layout and bias_layout arguments. Most of the built-in Keras layers support explicitly specifying the Layout of the layer weights.

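unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)
unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  # Each Dense layer gets a fully replicated kernel (rank 2) and bias (rank 1).
  tf.keras.layers.Dense(128, activation='relu', kernel_layout=unsharded_layout_2d, bias_layout=unsharded_layout_1d),
  tf.keras.layers.Dense(10, kernel_layout=unsharded_layout_2d, bias_layout=unsharded_layout_1d),
])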

You can check the layout information by examining the layout property on the weights.

for weight in model.weights:
  print(f'Weight name: {weight.name} with layout: {weight.layout}')

Load a dataset and build input pipeline

Load the MNIST dataset and configure a pre-processing input pipeline for it. The dataset itself is not associated with any DTensor layout information.

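(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,  # yields (image, label) pairs
    with_info=True,      # also returns ds_info
)

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

batch_size = 128

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)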
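The training loop splits every batch into one equal shard per local device, and tf.split with an integer number of splits requires the batch dimension to divide evenly. Because this pipeline batches without drop_remainder, the final batch of each split is smaller than 128. A quick arithmetic check, assuming the standard MNIST split sizes (60,000 training and 10,000 test images) and the 8 devices used here; batch_sizes is a hypothetical helper:

```python
def batch_sizes(num_examples, batch_size):
    """Batch sizes produced by Dataset.batch(batch_size) without drop_remainder."""
    full, rest = divmod(num_examples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

NUM_DEVICES = 8
for split, num_examples in [('train', 60_000), ('test', 10_000)]:
    sizes = batch_sizes(num_examples, 128)
    ok = all(s % NUM_DEVICES == 0 for s in sizes)
    print(split, len(sizes), 'batches, last batch', sizes[-1], '-> divisible by 8:', ok)
# train 469 batches, last batch 96 -> divisible by 8: True
# test 79 batches, last batch 16 -> divisible by 8: True
```

If you change batch_size or the number of devices, rerunning this check (or passing drop_remainder=True to batch()) avoids an error from tf.split on the last batch.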

Define the training logic for the model

Next, define the training and evaluation logic for the model.

As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. This is needed to pack the input data with the proper layout information, which is not integrated with the standard tf.keras.Model.fit() or tf.keras.Model.evaluate() functions from Keras. More support is planned for upcoming releases.

def train_step(model, x, y, optimizer, metrics):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'loss': loss_per_sample}
  return results
def eval_step(model, x, y, metrics):
  logits = model(x, training=False)
  loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'eval_loss': loss_per_sample}
  return results
def pack_dtensor_inputs(images, labels, image_layout, label_layout):
  num_local_devices = image_layout.mesh.num_local_devices()
  images = tf.split(images, num_local_devices)
  labels = tf.split(labels, num_local_devices)
  images = dtensor.pack(images, image_layout)
  labels = dtensor.pack(labels, label_layout)
  return  images, labels
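The comment in train_step says that tf.reduce_sum turns the batch-sharded per-example loss into a replicated global scalar. A pure-Python sketch of that arithmetic, using made-up per-example loss values: each of the 8 devices sums its own shard, the partial sums add up to the global sum, and dividing by len(x), which for a packed DTensor is the global batch size, gives the mean loss per example. The shard helper is hypothetical and mirrors the even split done by pack_dtensor_inputs.

```python
import math

def shard(values, num_devices):
    """Even contiguous split, like tf.split(values, num_devices)."""
    size = len(values) // num_devices
    return [values[i * size:(i + 1) * size] for i in range(num_devices)]

# Made-up per-example losses for one global batch of 128 examples.
per_example_loss = [0.01 * (i % 17) for i in range(128)]

# Each device reduces its own shard; adding the partial sums gives the global sum.
partial_sums = [math.fsum(s) for s in shard(per_example_loss, 8)]
global_loss = math.fsum(partial_sums)

# Dividing by the global batch size gives the mean loss per example.
loss_per_sample = global_loss / len(per_example_loss)
print(math.isclose(global_loss, math.fsum(per_example_loss)))      # True
print(math.isclose(loss_per_sample, sum(per_example_loss) / 128))  # True
```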

Metrics and optimizers

When using the DTensor API with Keras metrics and optimizers, you need to provide extra mesh information, so that any internal state variables and tensors can work with the variables in the model.

  • For an optimizer, DTensor introduces a new experimental namespace, tf.keras.dtensor.experimental.optimizers, where many existing Keras optimizers are extended to receive an additional mesh argument. In future releases, it may be merged with the Keras core optimizers.

  • For metrics, you can pass the mesh directly to the constructor as an argument to make it a DTensor-compatible Metric.

optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)
metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
eval_metrics = {'eval_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}

Train the model

The following example demonstrates how to shard the data from the input pipeline along the batch dimension and train the model, which has fully replicated weights.

After 3 epochs, the model should achieve about 97% accuracy:

num_epochs = 3

image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)
label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

for epoch in range(num_epochs):
  print("Epoch: ", epoch)
  for metric in metrics.values():
    metric.reset_state()
  step = 0
  results = {}
  pbar = tf.keras.utils.Progbar(target=None, stateful_metrics=[])
  for input in ds_train:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)

    results.update(train_step(model, images, labels, optimizer, metrics))
    for metric_name, metric in metrics.items():
      results[metric_name] = metric.result()

    pbar.update(step, values=results.items(), finalize=False)
    step += 1
  pbar.update(step, values=results.items(), finalize=True)

  for metric in eval_metrics.values():
    metric.reset_state()
  for input in ds_test:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)
    results.update(eval_step(model, images, labels, eval_metrics))

  for metric_name, metric in eval_metrics.items():
    results[metric_name] = metric.result()

  for metric_name, metric in results.items():
    print(f"{metric_name}: {metric.numpy()}")

Specify Layout for existing model code

Often you have models that already work well for your use case. Specifying Layout information for each individual layer within the model would take a lot of work and many edits.

To help you easily convert your existing Keras model to work with the DTensor API, you can use the new tf.keras.dtensor.experimental.LayoutMap API, which allows you to specify the Layouts from a global point of view.

First, create a LayoutMap instance, a dictionary-like object that contains all the Layouts you would like to specify for your model weights.

LayoutMap needs a Mesh instance at init, which it uses to provide a default replicated Layout for any weights that don't have a Layout configured. If you would like all your model weights to be fully replicated, you can provide an empty LayoutMap, and the default mesh will be used to create replicated Layouts.

LayoutMap uses a string as the key and a Layout as the value. Unlike a normal Python dict, the string key is treated as a regular expression when retrieving a value.
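To illustrate the regex lookup, here is a hypothetical RegexLayoutMap in plain Python, with strings standing in for Layout objects. It assumes first-match semantics: an exact key wins, otherwise the first pattern (in insertion order) that matches from the start of the weight path via re.match. For brevity it also folds the replicated default into the lookup itself. Treat all of this as an assumption about, not a copy of, the Keras implementation.

```python
import re

class RegexLayoutMap:
    """Hypothetical dict-like map whose string keys act as regexes on lookup."""

    def __init__(self, default='replicated'):
        self._layouts = {}        # dicts keep insertion order (Python 3.7+)
        self._default = default   # stands in for the mesh's replicated Layout

    def __setitem__(self, pattern, layout):
        self._layouts[pattern] = layout

    def __getitem__(self, path):
        if path in self._layouts:          # an exact key match wins first
            return self._layouts[path]
        for pattern, layout in self._layouts.items():
            if re.match(pattern, path):    # then the first pattern matching from the start
                return layout
        return self._default               # unmatched weights fall back to replicated

layout_map = RegexLayoutMap()
layout_map['feature.*kernel'] = 'batch_sharded(rank=2)'
layout_map['feature.*bias'] = 'batch_sharded(rank=1)'

for path in ['feature.kernel', 'feature_2.kernel', 'feature.bias', 'dense.kernel']:
    print(path, '->', layout_map[path])
```

Note that a pattern such as 'feature.*kernel' also matches 'feature_2.kernel', because .* absorbs the '_2.' in between.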

Subclassed Model

Consider the following model defined using the Keras subclassing Model syntax.

class SubclassedModel(tf.keras.Model):

  def __init__(self, name=None):
    super().__init__(name=name)
    self.feature = tf.keras.layers.Dense(16)
    self.feature_2 = tf.keras.layers.Dense(24)
    self.dropout = tf.keras.layers.Dropout(0.1)

  def call(self, inputs, training=None):
    x = self.feature(inputs)
    x = self.dropout(x, training=training)
    return self.feature_2(x)

There are 4 weights in this model: the kernel and bias of each of the two Dense layers. Each of them is mapped based on its object path:

  • model.feature.kernel
  • model.feature.bias
  • model.feature_2.kernel
  • model.feature_2.bias

Now define the following LayoutMap and apply it to the model:

layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

with layout_map.scope():
  subclassed_model = SubclassedModel()

The model weights are created on the first call, so call the model with a DTensor input and confirm the weights have the expected layouts:

dtensor_input = dtensor.copy_to_mesh(tf.zeros((16, 16)), layout=unsharded_layout_2d)
# Trigger the weights creation for subclass model
subclassed_model(dtensor_input)

print(subclassed_model.feature.kernel.layout)

With this, you can quickly map the Layout to your models without updating any of your existing code.

Sequential and Functional Models

For Keras Functional and Sequential models, you can use tf.keras.dtensor.experimental.LayoutMap as well.

layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)
with layout_map.scope():
  inputs = tf.keras.Input((16,), batch_size=16)
  x = tf.keras.layers.Dense(16, name='feature')(inputs)
  x = tf.keras.layers.Dropout(0.1)(x)
  output = tf.keras.layers.Dense(32, name='feature_2')(x)
  model = tf.keras.Model(inputs, output)

with layout_map.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, name='feature', input_shape=(16,)),
      tf.keras.layers.Dense(32, name='feature_2')
  ])