Custom training with TPUs

This tutorial walks you through tf.distribute.experimental.TPUStrategy, a new strategy (part of tf.distribute.Strategy) that lets users easily switch their model to running on TPUs. As part of this tutorial, you will create a Keras model and train it with a custom training loop (instead of calling the fit method).

You should come away understanding what a strategy is and why it is necessary in TensorFlow. Once you understand the strategy framework, switching between CPUs, GPUs, and other device configurations becomes much easier. To keep the introduction simple, you will build a Keras model that implements a simple convolutional neural network. A Keras model is usually trained in one line of code (by calling its fit method), but because some users need additional customization, this tutorial shows how to use a custom training loop instead. Distribution Strategy was originally written by DeepMind -- you can read the story here.

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

assert float('.'.join(tf.__version__.split('.')[:2])) >= 1.14, 'Make sure that TensorFlow version is at least 1.14'
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
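The rest of this tutorial builds a TPU strategy from this address. To illustrate the point above about switching devices, here is a minimal sketch (not one of the tutorial's cells, and assuming tf.distribute.MirroredStrategy is available in your installation) of how only the strategy construction would change for GPU or CPU training:

# Sketch: the training loop later in this tutorial mostly talks to the
# `strategy` object, so switching hardware is largely a matter of
# constructing a different strategy here.
use_tpu = True  # set to False to run on the local CPU/GPUs instead

if use_tpu:
  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER)
  tf.contrib.distribute.initialize_tpu_system(resolver)
  strategy = tf.contrib.distribute.TPUStrategy(resolver)
else:
  # MirroredStrategy replicates the model across the local GPUs (or the CPU).
  strategy = tf.distribute.MirroredStrategy()

Note that the TPU-specific session target and cluster configuration used near the end of this tutorial would also need to change for non-TPU hardware.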

Create model

You will be working with the MNIST dataset, a collection of 70,000 greyscale images of handwritten digits, so a convolutional neural network is a natural choice for this labeled image data. You will build it with the Keras API.

def create_model(input_shape):
  """Creates a simple convolutional neural network model using the Keras API"""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation=tf.nn.relu),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax),
  ])
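If you want to inspect the architecture before wiring it into the distributed training loop, you can build a throwaway instance and print its summary. This is just an optional sanity check; the model actually used for training is created later, inside the strategy scope:

# Optional check of the layer stack; the training model is created later,
# inside strategy.scope().
create_model((28, 28, 1)).summary()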

Loss and gradient

Since you are writing a custom training loop, you need to define the loss and gradient functions explicitly.

def loss(model, x, y):
  """Calculates the loss given an example (x, y)"""
  logits = model(x)
  return logits, tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)

def grad(model, x, y):
  """Calculates the loss and the gradients given an example (x, y)"""
  logits, loss_value = loss(model, x, y)
  return logits, loss_value, tf.gradients(loss_value, model.trainable_variables)
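These functions build ordinary TensorFlow graph ops, so you can sanity-check them on a dummy batch in a plain local session before involving the TPU. The snippet below is only a sketch with made-up data (the names sanity_model, dummy_x, and dummy_y are illustrative), not part of the distributed training code:

# Sketch: verify loss() and grad() on a dummy batch in a local session.
sanity_model = create_model((28, 28, 1))
dummy_x = tf.zeros([8, 28, 28, 1], dtype=tf.float32)  # a batch of 8 blank images
dummy_y = tf.zeros([8], dtype=tf.int64)               # 8 labels, all class 0

_, dummy_loss, dummy_grads = grad(sanity_model, dummy_x, dummy_y)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(dummy_loss))  # a single scalar cross-entropy value
  print(len(dummy_grads))      # one gradient tensor per trainable variable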

Main function

Previous sections highlighted the most important parts of the tutorial. The following code block gives a complete and runnable example of using TPUStrategy with a Keras model and a custom training loop.

tf.keras.backend.clear_session()

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER)
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

# Load MNIST training and test data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# All MNIST examples are 28x28 pixel greyscale images (hence the 1
# for the number of channels).
input_shape = (28, 28, 1)

# Only specific data types are supported on the TPU, so it is important to
# pay attention to these.
# More information:
# https://cloud.google.com/tpu/docs/troubleshooting#unsupported_data_type
x_train = x_train.reshape(x_train.shape[0], *input_shape).astype(np.float32)
x_test = x_test.reshape(x_test.shape[0], *input_shape).astype(np.float32)
y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

# The batch size must be divisible by the number of replicas (a Cloud TPU
# exposes 8 cores), so batch sizes of 8, 16, 24, 32, ... are supported.
BATCH_SIZE = 32

NUM_EPOCHS = 5

train_steps_per_epoch = len(x_train) // BATCH_SIZE
test_steps_per_epoch = len(x_test) // BATCH_SIZE
WARNING: Logging before flag parsing goes to stderr.
W0717 22:49:47.853965 140473536648960 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Start by creating objects within the strategy's scope

Model creation, optimizer creation, etc. must be written in the context of strategy.scope() in order to use TPUStrategy.

Also initialize metrics for the train and test sets. More information: tf.keras.metrics.Mean and tf.keras.metrics.SparseCategoricalAccuracy

with strategy.scope():
  model = create_model(input_shape)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)
  test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'test_accuracy', dtype=tf.float32)
W0717 22:50:07.715898 140473536648960 deprecation.py:506] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor

Define custom train and test steps

with strategy.scope():
  def train_step(inputs):
    """Each training step runs this custom function which calculates
    gradients and updates weights.
    """
    x, y = inputs

    logits, loss_value, grads = grad(model, x, y)

    # Show that this is truly a custom training loop
    # by manually scaling every gradient by 2.
    grads = [g * 2 for g in grads]

    update_vars = optimizer.apply_gradients(
        zip(grads, model.trainable_variables))

    update_loss = training_loss.update_state(loss_value)
    update_accuracy = training_accuracy.update_state(y, logits)

    with tf.control_dependencies([update_vars, update_loss, update_accuracy]):
      return tf.identity(loss_value)

  def test_step(inputs):
    """Each training step runs this custom function"""
    x, y = inputs

    logits, loss_value = loss(model, x, y)

    update_loss = test_loss.update_state(loss_value)
    update_accuracy = test_accuracy.update_state(y, logits)

    with tf.control_dependencies([update_loss, update_accuracy]):
      return tf.identity(loss_value)

Do the training

To make the code a little easier to read, the full training loop calls two helper functions, run_train() and run_test().

def run_train():
  # Train
  session.run(train_iterator_init)
  while True:
    try:
      session.run(dist_train)
    except tf.errors.OutOfRangeError:
      break
  print('Train loss: {:0.4f}\t Train accuracy: {:0.4f}%'.format(
      session.run(training_loss_result),
      session.run(training_accuracy_result) * 100))
  training_loss.reset_states()
  training_accuracy.reset_states()

def run_test():
  # Test
  session.run(test_iterator_init)
  while True:
    try:
      session.run(dist_test)
    except tf.errors.OutOfRangeError:
      break
  print('Test loss: {:0.4f}\t Test accuracy: {:0.4f}%'.format(
      session.run(test_loss_result),
      session.run(test_accuracy_result) * 100))
  test_loss.reset_states()
  test_accuracy.reset_states()

with strategy.scope():
  training_loss_result = training_loss.result()
  training_accuracy_result = training_accuracy.result()
  test_loss_result = test_loss.result()
  test_accuracy_result = test_accuracy.result()
  
  config = tf.ConfigProto()
  config.allow_soft_placement = True
  cluster_spec = resolver.cluster_spec()
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

  print('Starting training...')

  # Do all the computations inside a Session (as opposed to eager execution)
  with tf.Session(target=resolver.master(), config=config) as session:
    all_variables = (
        tf.global_variables() + training_loss.variables +
        training_accuracy.variables + test_loss.variables +
        test_accuracy.variables)
    
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE, drop_remainder=True)
    train_iterator = strategy.make_dataset_iterator(train_dataset)

    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE, drop_remainder=True)
    test_iterator = strategy.make_dataset_iterator(test_dataset)
    
    train_iterator_init = train_iterator.initialize()
    test_iterator_init = test_iterator.initialize()

    session.run([v.initializer for v in all_variables])
    
    dist_train = strategy.experimental_run(train_step, train_iterator).values
    dist_test = strategy.experimental_run(test_step, test_iterator).values

    # Custom training loop
    for epoch in range(0, NUM_EPOCHS):
      print('Starting epoch {}'.format(epoch))

      run_train()

      run_test()
Starting training...

W0717 22:50:11.472496 140473536648960 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Starting epoch 0
Train loss: 1.6253   Train accuracy: 85.3817%
Test loss: 1.5494    Test accuracy: 91.5933%
Starting epoch 1
Train loss: 1.5316   Train accuracy: 93.3117%
Test loss: 1.5238    Test accuracy: 93.9933%
Starting epoch 2
Train loss: 1.5148   Train accuracy: 94.9167%
Test loss: 1.5060    Test accuracy: 95.7167%
Starting epoch 3
Train loss: 1.5058   Train accuracy: 95.7900%
Test loss: 1.5020    Test accuracy: 96.1300%
Starting epoch 4
Train loss: 1.5006   Train accuracy: 96.2333%
Test loss: 1.4979    Test accuracy: 96.5200%