AutoGraph: Easy control flow for graphs

AutoGraph helps you write complicated graph code using normal Python. Behind the scenes, AutoGraph automatically transforms your code into the equivalent TensorFlow graph code. AutoGraph already supports much of the Python language, and that coverage continues to grow. For a list of supported Python language features, see the AutoGraph capabilities and limitations.


Import TensorFlow, AutoGraph, and any supporting modules:

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
layers = tf.keras.layers
from tensorflow import contrib
autograph = contrib.autograph

import numpy as np
import matplotlib.pyplot as plt
WARNING: Logging before flag parsing goes to stderr.
W0603 18:06:39.959300 140070722848576 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

We'll enable eager execution for demonstration purposes, but AutoGraph works in both eager and graph execution environments:

tf.enable_eager_execution()


Automatically convert Python control flow

AutoGraph will convert much of the Python language into the equivalent TensorFlow graph building code.

AutoGraph converts a function like:

def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

To a function that uses graph building:

def tf__square_if_positive(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  cond = x > 0

  def get_state():
    return ()

  def set_state(_):
    pass

  def if_true():
    x_1, = x,
    x_1 = x_1 * x_1
    return x_1

  def if_false():
    x = 0.0
    return x
  x = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
  do_return = True
  retval_ = x
  cond_1 = ag__.is_undefined_return(retval_)

  def get_state_1():
    return ()

  def set_state_1(_):
    pass

  def if_true_1():
    retval_ = None
    return retval_

  def if_false_1():
    return retval_
  retval_ = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1)
  return retval_
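
The generated code above replaces the `if` statement with two branch functions that are handed to `ag__.if_stmt`. The following plain-Python sketch illustrates that functional form. The `if_stmt` helper here is a hypothetical stand-in written for this illustration, not AutoGraph's implementation; the real operator builds a `tf.cond` when the condition is a tensor.

```python
def if_stmt(cond, if_true, if_false):
  """Hypothetical stand-in: pick a branch function and call it."""
  return if_true() if cond else if_false()

def square_if_positive(x):
  # Each branch is a closure that returns the new value of `x`.
  def if_true():
    return x * x

  def if_false():
    return 0.0

  return if_stmt(x > 0, if_true, if_false)

print(square_if_positive(9.0))   # 81.0
print(square_if_positive(-9.0))  # 0.0
```

Because both branches are ordinary functions, the same structure can be executed directly with Python values or traced into a graph with tensors.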

Code written for eager execution can run in a tf.Graph with the same results, but with the benefits of graph execution:

print('Eager results: %2.2f, %2.2f' % (square_if_positive(tf.constant(9.0)),
                                         square_if_positive(tf.constant(-9.0))))
Eager results: 81.00, 0.00

Generate a graph-version and call it:

tf_square_if_positive = autograph.to_graph(square_if_positive)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_out1 = tf_square_if_positive(tf.constant( 9.0))
  g_out2 = tf_square_if_positive(tf.constant(-9.0))
  with tf.Session() as sess:
    print('Graph results: %2.2f, %2.2f\n' % (sess.run(g_out1), sess.run(g_out2)))
Graph results: 81.00, 0.00

AutoGraph supports common Python statements like while, for, if, break, and return, with support for nesting. Compare this function with the complicated graph version displayed in the following code blocks:

# Continue in a loop
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s

print('Eager result: %d' % sum_even(tf.constant([10,12,15,20])))

tf_sum_even = autograph.to_graph(sum_even)

with tf.Graph().as_default(), tf.Session() as sess:
    print('Graph result: %d\n\n' % sess.run(tf_sum_even(tf.constant([10,12,15,20]))))
Eager result: 42
Graph result: 42

def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  s = 0

  def loop_body(loop_vars, s_2):
    c = loop_vars
    continue_ = False
    cond = c % 2 > 0

    def get_state():
      return ()

    def set_state(_):
      pass

    def if_true():
      continue_ = True
      return continue_

    def if_false():
      return continue_
    continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
    cond_1 = ag__.not_(continue_)

    def get_state_1():
      return ()

    def set_state_1(_):
      pass

    def if_true_1():
      s_1, = s_2,
      s_1 += c
      return s_1

    def if_false_1():
      return s_2
    s_2 = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1)
    return s_2,
  s, = ag__.for_stmt(items, None, loop_body, (s,))
  do_return = True
  retval_ = s
  cond_2 = ag__.is_undefined_return(retval_)

  def get_state_2():
    return ()

  def set_state_2(_):
    pass

  def if_true_2():
    retval_ = None
    return retval_

  def if_false_2():
    return retval_
  retval_ = ag__.if_stmt(cond_2, if_true_2, if_false_2, get_state_2, set_state_2)
  return retval_
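
The generated `tf__sum_even` shows two ideas: the loop body becomes a function that receives and returns the loop state, and `continue` becomes a boolean flag guarding the rest of the body. A minimal plain-Python sketch of both follows. The `for_stmt` helper is a hypothetical simplification for illustration and does not match the signature of `ag__.for_stmt`.

```python
def for_stmt(items, body, init_state):
  """Hypothetical stand-in: thread a state tuple through each iteration."""
  state = init_state
  for item in items:
    state = body(item, *state)
  return state

def sum_even(items):
  def loop_body(c, s):
    # `continue` is lowered to a flag that guards the remaining statements.
    continue_ = c % 2 > 0
    if not continue_:
      s += c
    return (s,)

  s, = for_stmt(items, loop_body, (0,))
  return s

print(sum_even([10, 12, 15, 20]))  # 42
```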


If you don't need easy access to the original Python function, use the convert decorator:

@autograph.convert()
def fizzbuzz(i, n):
  while i < n:
    msg = ''
    if i % 3 == 0:
      msg += 'Fizz'
    if i % 5 == 0:
      msg += 'Buzz'
    if msg == '':
      msg = tf.as_string(i)
    print(msg)
    i += 1
  return i

with tf.Graph().as_default():
  final_i = fizzbuzz(tf.constant(10), tf.constant(16))
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  with tf.Session() as sess:
    sess.run(final_i)

W0603 18:06:41.537476 140053481240320 backprop.py:899] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.string



Let's demonstrate some useful Python language features.


AutoGraph automatically converts the Python assert statement into the equivalent tf.Assert code:

@autograph.convert()
def inverse(x):
  assert x != 0.0, 'Do not pass zero!'
  return 1.0 / x

with tf.Graph().as_default(), tf.Session() as sess:
  try:
    print(sess.run(inverse(tf.constant(0.0))))
  except tf.errors.InvalidArgumentError as e:
    print('Got error message:\n    %s' % e.message)
Got error message:
    assertion failed: [Do not pass zero!]
     [[node inverse/Assert/Assert (defined at <ipython-input-10-a599015a46ad>:3) ]]


Optionally, you may use the Python print function in-graph, when combined with the automatic control dependency management of tf.function:

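@autograph.convert(optional_features=autograph.Feature.ALL)
def count(n):
  i = 0
  while i < n:
    print(i)
    i += 1
  return n

with tf.Graph().as_default(), tf.Session() as sess:
  # The argument 5 is an illustrative value.
  sess.run(count(tf.constant(5)))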


Append to lists in loops (tensor list ops are automatically created):

@autograph.convert()
def arange(n):
  z = []
  # We ask you to tell us the element dtype of the list
  autograph.set_element_type(z, tf.int32)

  for i in tf.range(n):
    z.append(i)
  # when you're done with the list, stack it
  # (this is just like np.stack)
  return autograph.stack(z)

with tf.Graph().as_default(), tf.Session() as sess:
  print(sess.run(arange(tf.constant(10))))
[0 1 2 3 4 5 6 7 8 9]

Nested control flow

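@autograph.convert()
def nearest_odd_square(x):
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    # The input 4 is an illustrative value.
    print(sess.run(nearest_odd_square(tf.constant(4))))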

While loop

@autograph.convert()
def square_until_stop(x, y):
  while x < y:
    x = x * x
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))
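
A `while` loop is converted along the same lines as `if`: the condition and the body become functions over the loop state. The sketch below shows that shape in plain Python. The `while_stmt` helper is a hypothetical simplification written for this illustration, not AutoGraph's operator; for tensor conditions the real conversion builds a `tf.while_loop`.

```python
def while_stmt(test, body, init_state):
  """Hypothetical stand-in: loop while test(*state) holds."""
  state = init_state
  while test(*state):
    state = body(*state)
  return state

def square_until_stop(x, y):
  def loop_test(x):
    return x < y

  def loop_body(x):
    # The body returns the updated loop state as a tuple.
    return (x * x,)

  x, = while_stmt(loop_test, loop_body, (x,))
  return x

print(square_until_stop(4, 100))  # 256
```

Starting from 4, the value is squared to 16 and then to 256, at which point the condition `x < y` fails and the loop stops.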

For loop

@autograph.convert()
def squares(nums):

  result = []
  autograph.set_element_type(result, tf.int64)

  for num in nums:
    result.append(num * num)

  return autograph.stack(result)

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(squares(tf.constant(np.arange(10)))))
[ 0  1  4  9 16 25 36 49 64 81]


Break

@autograph.convert()
def argwhere_cumsum(x, threshold):
  current_sum = 0.0
  idx = 0
  for i in tf.range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))
    print(sess.run(idx))
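
As a cross-check on what the graph version should compute, here is a plain-Python rendering of the same loop, using a list in place of a tensor. It is only a reference sketch: it assumes the loop exits with `break` once the running sum reaches the threshold, and it does not use TensorFlow at all.

```python
def argwhere_cumsum(x, threshold):
  current_sum = 0.0
  idx = 0
  for i in range(len(x)):
    idx = i
    # Stop at the first index where the running sum has reached the threshold.
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
print(argwhere_cumsum([1.0] * N, float(N / 2)))  # 5
```

With ten ones and a threshold of 5.0, the running sum reaches 5.0 after five additions, so the loop breaks at index 5.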

Interoperation with tf.Keras

Now that you've seen the basics, let's build some model components with AutoGraph.

It's relatively simple to integrate AutoGraph with tf.keras.

Stateless functions

For stateless functions, like collatz shown below, the easiest way to include them in a Keras model is to wrap them up as a layer using tf.keras.layers.Lambda.

import numpy as np

@autograph.convert()
def collatz(x):
  x = tf.reshape(x,())
  assert x > 0
  n = tf.convert_to_tensor((0,))
  while x != 1:
    n += 1
    if x % 2 == 0:
      x = x // 2
    else:
      x = 3 * x + 1

  return n

with tf.Graph().as_default():
  model = tf.keras.Sequential([
    tf.keras.layers.Lambda(collatz, input_shape=(1,), output_shape=())
  ])

  result = model.predict(np.array([6171]))
  print(result)
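
To sanity-check the layer's output, the same step count can be computed in plain Python. This is a reference sketch that assumes the odd case applies `3 * x + 1`; the name `collatz_steps` is introduced here for illustration only.

```python
def collatz_steps(x):
  """Count the Collatz steps needed to reach 1 from a positive integer."""
  assert x > 0
  n = 0
  while x != 1:
    n += 1
    if x % 2 == 0:
      x = x // 2
    else:
      x = 3 * x + 1
  return n

print(collatz_steps(6))   # 8
print(collatz_steps(27))  # 111
```

Calling `collatz_steps(6171)` gives a value to compare against the result of `model.predict` above.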

Custom Layers and Models

The easiest way to use AutoGraph with Keras layers and models is to @autograph.convert() the call method. See the TensorFlow Keras guide for details on how to build on these classes.

Here is a simple example of the stochastic network depth technique:

# `K` is used to check if we're in train or test mode.
K = tf.keras.backend

class StochasticNetworkDepth(tf.keras.Sequential):
  def __init__(self, layers, pfirst=1.0, plast=0.5,**kwargs):
    self.pfirst = pfirst
    self.plast = plast
    super(StochasticNetworkDepth, self).__init__(layers,**kwargs)

  def build(self, input_shape):
    self.depth = len(self.layers)
    self.plims = np.linspace(self.pfirst, self.plast, self.depth + 1)[:-1]
    super(StochasticNetworkDepth, self).build(input_shape.as_list())

  @autograph.convert()
  def call(self, inputs):
    training = tf.cast(K.learning_phase(), dtype=bool)
    if not training:
      count = self.depth
      return super(StochasticNetworkDepth, self).call(inputs), count

    p = tf.random_uniform((self.depth,))

    keeps = (p <= self.plims)
    x = inputs

    count = tf.reduce_sum(tf.cast(keeps, tf.int32))
    for i in range(self.depth):
      if keeps[i]:
        x = self.layers[i](x)

    # return both the final-layer output and the number of layers executed.
    return x, count

Let's try it on mnist-shaped data:

train_batch = np.random.randn(64, 28, 28, 1).astype(np.float32)

Build a simple stack of conv layers, in the stochastic depth model:

with tf.Graph().as_default() as g:
  model = StochasticNetworkDepth([
        layers.Conv2D(filters=16, activation=tf.nn.relu,
                      kernel_size=(3, 3), padding='same')
        for n in range(20)],
      pfirst=1.0, plast=0.5)

  model.build(tf.TensorShape((None, None, None, 1)))

  init = tf.global_variables_initializer()
W0603 18:06:43.573172 140070722848576 deprecation_wrapper.py:118] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

Now test it to ensure it behaves as expected in train and test modes:

# Use an explicit session here so we can set the train/test switch, and
# inspect the layer count returned by `call`
with tf.Session(graph=g) as sess:
  sess.run(init)

  for phase, name in enumerate(['test','train']):
    K.set_learning_phase(phase)
    result, count = model(tf.convert_to_tensor(train_batch, dtype=tf.float32))

    result1, count1 = sess.run((result, count))
    result2, count2 = sess.run((result, count))

    delta = (result1 - result2)
    print(name, "sum abs delta: ", abs(delta).mean())
    print("    layers 1st call: ", count1)
    print("    layers 2nd call: ", count2)
test sum abs delta:  0.0
    layers 1st call:  20
    layers 2nd call:  20

train sum abs delta:  0.00088969397
    layers 1st call:  16
    layers 2nd call:  16
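
The layer counts in train mode follow from the keep-probability schedule built in `build`. The sketch below reproduces that schedule and one random draw using only the standard library. The helper names `keep_probabilities` and `sample_keeps` and the fixed seed are assumptions made for this illustration; they are not part of the model above.

```python
import random

def keep_probabilities(depth, pfirst=1.0, plast=0.5):
  # Same schedule as np.linspace(pfirst, plast, depth + 1)[:-1], up to rounding.
  step = (plast - pfirst) / depth
  return [pfirst + step * i for i in range(depth)]

def sample_keeps(plims, rng):
  # Layer i runs when a uniform draw falls at or below its keep probability.
  return [rng.random() <= p for p in plims]

plims = keep_probabilities(20)
rng = random.Random(0)  # fixed seed, chosen only for repeatability
keeps = sample_keeps(plims, rng)

print("first and last keep probability:", plims[0], plims[-1])
print("expected layers per call:", sum(plims))
print("layers kept in this draw:", sum(keeps), "of", len(keeps))
```

For 20 layers with probabilities falling linearly from 1.0 toward 0.5, the expected number of executed layers is about 15.25, which is in line with the 16 layers reported in train mode.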

Advanced example: An in-graph training loop

The previous section showed that AutoGraph can be used inside Keras layers and models. Keras models can also be used in AutoGraph code.

Since writing control flow in AutoGraph is easy, running a training loop in a TensorFlow graph should also be easy.

This example shows how to train a simple Keras model on MNIST with the entire training process (loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence) performed in-graph.

Download data

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Define the model

def mlp_model(input_shape):
  model = tf.keras.Sequential((
      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')))
  return model

def predict(m, x, y):
  y_p = m(tf.reshape(x, (-1, 28 * 28)))
  losses = tf.keras.losses.categorical_crossentropy(y, y_p)
  l = tf.reduce_mean(losses)
  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)
  accuracy = tf.reduce_mean(accuracies)
  return l, accuracy

def fit(m, x, y, opt):
  l, accuracy = predict(m, x, y)
  # AutoGraph automatically adds the necessary tf.control_dependencies here.
  # (Without them nothing depends on `opt.minimize`, so it doesn't run.)
  # This makes it much more like eager-code.
  opt.minimize(l)
  return l, accuracy

def setup_mnist_data(is_training, batch_size):
  if is_training:
    ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds

def get_next_batch(ds):
  itr = ds.make_one_shot_iterator()
  image, label = itr.get_next()
  x = tf.to_float(image) / 255.0
  y = tf.one_hot(tf.squeeze(label), 10)
  return x, y

Define the training loop

# Use `recursive = True` to recursively convert functions called by this one.
@autograph.convert(recursive=True, optional_features=autograph.Feature.ALL)
def train(train_ds, test_ds, hp):
  m = mlp_model((28 * 28,))
  opt = tf.train.AdamOptimizer(hp.learning_rate)

  # We'd like to save our losses to a list. In order for AutoGraph
  # to convert these lists into their graph equivalent,
  # we need to specify the element type of the lists.
  train_losses = []
  autograph.set_element_type(train_losses, tf.float32)
  test_losses = []
  autograph.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  autograph.set_element_type(train_accuracies, tf.float32)
  test_accuracies = []
  autograph.set_element_type(test_accuracies, tf.float32)

  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)

    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % 50 == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
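    # Record the per-step metrics so they can be stacked and returned.
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1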

  # We've recorded our loss values and accuracies
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor,
  # we need to stack them before returning them.
  return (autograph.stack(train_losses), autograph.stack(test_losses),
          autograph.stack(train_accuracies), autograph.stack(test_accuracies))

Now build the graph and run the training loop:

with tf.Graph().as_default() as g:
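  # NOTE: the hyperparameter values are elided in the source; these are assumed.
  hp = tf.contrib.training.HParams(
      learning_rate=0.005,
      max_steps=500)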
  train_ds = setup_mnist_data(True, 50)
  test_ds = setup_mnist_data(False, 1000)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = train(train_ds, test_ds, hp)

  init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
  sess.run(init)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                test_accuracies])
W0603 18:06:47.558233 140070722848576 deprecation_wrapper.py:118] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

W0603 18:06:48.710157 140070722848576 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:281: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
W0603 18:06:49.090992 140070722848576 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py:281: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.

Step 0 train loss: 2.3129258 test loss: 2.3331568 train accuracy: 0.12 test accuracy: 0.098
Step 50 train loss: 0.2643618 test loss: 0.4459182 train accuracy: 0.92 test accuracy: 0.857
Step 100 train loss: 0.23427086 test loss: 0.3657646 train accuracy: 0.92 test accuracy: 0.885
Step 150 train loss: 0.30767873 test loss: 0.30344555 train accuracy: 0.9 test accuracy: 0.903
Step 200 train loss: 0.10802757 test loss: 0.2970772 train accuracy: 0.96 test accuracy: 0.899
Step 250 train loss: 0.047591947 test loss: 0.29321504 train accuracy: 1.0 test accuracy: 0.905
Step 300 train loss: 0.16977023 test loss: 0.21266425 train accuracy: 0.92 test accuracy: 0.927
Step 350 train loss: 0.12944406 test loss: 0.18370679 train accuracy: 0.96 test accuracy: 0.942
Step 400 train loss: 0.022496855 test loss: 0.20348097 train accuracy: 1.0 test accuracy: 0.942
Step 450 train loss: 0.23636939 test loss: 0.19298728 train accuracy: 0.9 test accuracy: 0.943
plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.show()

plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()