
tf.distribute.Strategy with Training Loops


This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60,000 training images and 10,000 test images, each of size 28 x 28.

We are using custom training loops to train our model because they give us flexibility and greater control over training. They also make it easier to debug the model and the training loop.

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
!pip install -q tf-nightly-gpu
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
1.14.1-dev20190625

Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

train_labels = train_labels.astype('int64')
test_labels = test_labels.astype('int64')
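The effect of this preprocessing can be sanity-checked on synthetic data (a hypothetical snippet, not part of the original notebook — it stands in for the Fashion MNIST arrays):

```python
import numpy as np

# Synthetic stand-ins for the Fashion MNIST arrays (uint8 images in [0, 255]).
images = np.random.randint(0, 256, size=(8, 28, 28), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(8,)).astype('int64')

# Same transformations as above: add a channel axis, scale to [0, 1].
images = images[..., None] / np.float32(255)

print(images.shape)   # (8, 28, 28, 1)
print(images.dtype)   # float32
```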

Create a strategy to distribute the variables and the graph

How does tf.distribute.MirroredStrategy strategy work?

  • All the variables and the model graph are replicated on the replicas.
  • Input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.
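The steps above can be sketched with plain NumPy (a toy simulation with 2 replicas and a 1-parameter model; names like num_replicas are illustrative, not part of the tf.distribute API):

```python
import numpy as np

# Toy mirrored data-parallel SGD: fit y = w * x (target w = 3.0) with
# squared error, simulating 2 replicas that each hold a copy of `w`.
x = np.array([-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0])
y = 3.0 * x
w = 0.0                        # replicated variable (same value everywhere)
lr, num_replicas = 0.1, 2

for _ in range(100):
    # Input is split evenly across the replicas.
    shards = np.split(np.arange(len(x)), num_replicas)
    grads = []
    for idx in shards:
        # Each replica computes the gradient for its shard, scaled by
        # the GLOBAL batch size (len(x)), not the per-replica size.
        err = w * x[idx] - y[idx]
        grads.append(np.sum(2.0 * err * x[idx]) / len(x))
    # Gradients are synced across replicas by summing them ...
    total_grad = np.sum(grads)
    # ... and the same update is applied to every copy of the variable.
    w -= lr * total_grad

print(round(w, 4))  # → 3.0
```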
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

tf.distribute.Strategy.experimental_distribute_dataset evenly distributes the dataset across all the replicas.

with strategy.scope():
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  train_ds = strategy.experimental_distribute_dataset(train_dataset)

  test_dataset = tf.data.Dataset.from_tensor_slices(
      (test_images, test_labels)).batch(BATCH_SIZE)
  test_ds = strategy.experimental_distribute_dataset(test_dataset)

Model Creation

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu',
                             input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)  # outputs logits; softmax is applied by the loss
    ])
  optimizer = tf.train.GradientDescentOptimizer(0.001)

Define the loss function

Normally, on a single machine with one GPU/CPU, the loss is divided by the number of examples in the batch of input.

So, how should the loss be calculated when using a tf.distribute.Strategy?

  • For example, let's say you have 4 GPUs and a batch size of 64. One batch of input is distributed across the replicas (4 GPUs), each replica getting an input of size 16.

  • The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).

Why do this?

  • This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them.

How to do this in TensorFlow?

  • If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE). Alternatively, you can use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.

  • If you are using regularization losses in your model, then you need to scale the loss value by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function.

  • Using tf.reduce_mean is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

  • This reduction and scaling is done automatically in Keras with model.compile and model.fit.

  • If using tf.keras.losses classes, the loss reduction needs to be explicitly specified as one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should explicitly think about which reduction is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by the per-replica batch size, leaving the division by the number of replicas to the user, which is easy to miss. So instead we ask the user to do the reduction themselves explicitly.
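A quick numeric sketch of the scaling rule above (hypothetical loss values, with 2 simulated replicas):

```python
import numpy as np

# Hypothetical per-example losses for a global batch of 8, split across
# 2 replicas (4 examples each).
per_example_loss = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
GLOBAL_BATCH_SIZE = 8
replica_losses = np.split(per_example_loss, 2)

# Correct: each replica divides its sum by the GLOBAL batch size; summing
# across replicas then recovers the true mean over the global batch.
scaled = [np.sum(l) / GLOBAL_BATCH_SIZE for l in replica_losses]
print(sum(scaled))   # 2.25, equals per_example_loss.mean()

# Incorrect: dividing by the per-replica batch size and then summing
# over-counts by the number of replicas.
wrong = [np.mean(l) for l in replica_losses]
print(sum(wrong))    # 4.5, twice the true mean
```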

Training loop

with strategy.scope():
  def train_step(dist_inputs):
    def step_fn(inputs):
      images, labels = inputs
      logits = model(images)
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
      loss = tf.nn.compute_average_loss(cross_entropy, global_batch_size=BATCH_SIZE)
      train_op = optimizer.minimize(loss)
      with tf.control_dependencies([train_op]):
        return tf.identity(loss)

    per_replica_losses = strategy.experimental_run_v2(
        step_fn, args=(dist_inputs,))
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    return mean_loss
with strategy.scope():
  train_iterator = train_ds.make_initializable_iterator()
  iterator_init = train_iterator.initialize()
  var_init = tf.global_variables_initializer()
  loss = train_step(next(train_iterator))
  # With BATCH_SIZE = 64 the training set yields ~938 batches per epoch,
  # so cap the inner loop accordingly to avoid an OutOfRangeError.
  steps_per_epoch = len(train_images) // BATCH_SIZE
  with tf.Session() as sess:
    sess.run([var_init])
    for epoch in range(EPOCHS):
      sess.run([iterator_init])
      for step in range(steps_per_epoch):
        # Running `loss` also runs the train op via the control dependency.
        curr_loss = sess.run(loss)
        if step % 100 == 0:
          print('Epoch {} Step {} Loss {:.4f}'.format(epoch + 1,
                                                      step,
                                                      curr_loss))

Epoch 1 Step 0 Loss 2.3023
Epoch 1 Step 1000 Loss 2.3014
Epoch 1 Step 2000 Loss 2.3038
Epoch 1 Step 3000 Loss 2.3024
Epoch 1 Step 4000 Loss 2.3013
Epoch 1 Step 5000 Loss 2.3028
Epoch 1 Step 6000 Loss 2.3034
Epoch 1 Step 7000 Loss 2.3030
Epoch 1 Step 8000 Loss 2.3039
Epoch 1 Step 9000 Loss 2.3029
Epoch 2 Step 0 Loss 2.3025
Epoch 2 Step 1000 Loss 2.3026
Epoch 2 Step 2000 Loss 2.3043
Epoch 2 Step 3000 Loss 2.3023
Epoch 2 Step 4000 Loss 2.3032
Epoch 2 Step 5000 Loss 2.3038
Epoch 2 Step 6000 Loss 2.3025
Epoch 2 Step 7000 Loss 2.3047
Epoch 2 Step 8000 Loss 2.3036
Epoch 2 Step 9000 Loss 2.3028
Epoch 3 Step 0 Loss 2.3041
Epoch 3 Step 1000 Loss 2.3042
Epoch 3 Step 2000 Loss 2.3036
Epoch 3 Step 3000 Loss 2.3023
Epoch 3 Step 4000 Loss 2.3034
Epoch 3 Step 5000 Loss 2.3025
Epoch 3 Step 6000 Loss 2.3030
Epoch 3 Step 7000 Loss 2.3043
Epoch 3 Step 8000 Loss 2.3028
Epoch 3 Step 9000 Loss 2.3031
Epoch 4 Step 0 Loss 2.3021
Epoch 4 Step 1000 Loss 2.3016
Epoch 4 Step 2000 Loss 2.3019
Epoch 4 Step 3000 Loss 2.3026
Epoch 4 Step 4000 Loss 2.3025
Epoch 4 Step 5000 Loss 2.3043
Epoch 4 Step 6000 Loss 2.3036
Epoch 4 Step 7000 Loss 2.3028
Epoch 4 Step 8000 Loss 2.3018
Epoch 4 Step 9000 Loss 2.3035
Epoch 5 Step 0 Loss 2.3026
Epoch 5 Step 1000 Loss 2.3008
Epoch 5 Step 2000 Loss 2.3046
Epoch 5 Step 3000 Loss 2.3013
Epoch 5 Step 4000 Loss 2.3026
Epoch 5 Step 5000 Loss 2.3022
Epoch 5 Step 6000 Loss 2.3038
Epoch 5 Step 7000 Loss 2.3027
Epoch 5 Step 8000 Loss 2.3032
Epoch 5 Step 9000 Loss 2.3029
Epoch 6 Step 0 Loss 2.3044
Epoch 6 Step 1000 Loss 2.3030
Epoch 6 Step 2000 Loss 2.3038
Epoch 6 Step 3000 Loss 2.3037
Epoch 6 Step 4000 Loss 2.3021
Epoch 6 Step 5000 Loss 2.3032
Epoch 6 Step 6000 Loss 2.3027
Epoch 6 Step 7000 Loss 2.3027
Epoch 6 Step 8000 Loss 2.3019
Epoch 6 Step 9000 Loss 2.3023
Epoch 7 Step 0 Loss 2.3031
Epoch 7 Step 1000 Loss 2.3031
Epoch 7 Step 2000 Loss 2.3027
Epoch 7 Step 3000 Loss 2.3008
Epoch 7 Step 4000 Loss 2.3034
Epoch 7 Step 5000 Loss 2.3028
Epoch 7 Step 6000 Loss 2.3037
Epoch 7 Step 7000 Loss 2.3032
Epoch 7 Step 8000 Loss 2.3035
Epoch 7 Step 9000 Loss 2.3031
Epoch 8 Step 0 Loss 2.3037
Epoch 8 Step 1000 Loss 2.3023
Epoch 8 Step 2000 Loss 2.3006
Epoch 8 Step 3000 Loss 2.3040
Epoch 8 Step 4000 Loss 2.3038
Epoch 8 Step 5000 Loss 2.3026
Epoch 8 Step 6000 Loss 2.3016
Epoch 8 Step 7000 Loss 2.3024
Epoch 8 Step 8000 Loss 2.3033
Epoch 8 Step 9000 Loss 2.3025
Epoch 9 Step 0 Loss 2.3018
Epoch 9 Step 1000 Loss 2.3025
Epoch 9 Step 2000 Loss 2.3019
Epoch 9 Step 3000 Loss 2.3027
Epoch 9 Step 4000 Loss 2.3026
Epoch 9 Step 5000 Loss 2.3022
Epoch 9 Step 6000 Loss 2.3034
Epoch 9 Step 7000 Loss 2.3012
Epoch 9 Step 8000 Loss 2.3038
Epoch 9 Step 9000 Loss 2.3013
Epoch 10 Step 0 Loss 2.3023
Epoch 10 Step 1000 Loss 2.3035
Epoch 10 Step 2000 Loss 2.3029
Epoch 10 Step 3000 Loss 2.3030
Epoch 10 Step 4000 Loss 2.3023
Epoch 10 Step 5000 Loss 2.3042
Epoch 10 Step 6000 Loss 2.3022
Epoch 10 Step 7000 Loss 2.3028
Epoch 10 Step 8000 Loss 2.3025
Epoch 10 Step 9000 Loss 2.3029

Next Steps

Try out the new tf.distribute.Strategy API on your models.