Concrete functions


In the guide to AutoGraph and tf.function you saw how to use tf.function. This guide dives into the details of:

  • tf.function Tracing
  • tf.function Signatures
  • The Concrete functions generated by tracing:
    • How to access them
    • How to use them

These details only become important:

  • If you're experiencing performance issues due to undesired tracing of a tf.function.
  • When you need precise control over the TensorFlow Graphs generated by tf.function. For example for exporting the model to TensorFlow Lite using tf.lite.Converter.from_concrete_functions.

Background

In TensorFlow 2, eager execution is on by default. TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs. Operations return values instead of constructing a computational graph to run later. Here is a detailed guide on eager execution.

Running imperatively makes development and debugging more interactive, but doesn't allow for easy exporting.

The tf.function API makes it possible to save models as graphs.

Terminology

The following terminology is used in this document:

  • Signature - A description of the inputs and outputs for a set of operations.
  • Polymorphic function - Python callable that encapsulates several concrete function graphs behind one API.
  • Concrete function - Graph with a single signature.

Setup

from __future__ import absolute_import, division, print_function, unicode_literals
import traceback
import textwrap

try:
  !pip install -q tf-nightly
except Exception:
  pass
import tensorflow as tf

Create a tf.function

Annotating a function with tf.function creates a polymorphic function that encapsulates the function's operations. Operations that are not wrapped in a tf.function are evaluated eagerly. The example below shows basic tf.function usage.

@tf.function
def square(x):
  return x*x
square(2).numpy()
4

Remember that the python decorator syntax just calls the decorator with the decorated object as input:

def pow(x,y):
  return x ** y

pow = tf.function(pow)
pow(3,4).numpy()
81

Attach a tf.function method to a tf.Module

The tf.function can be optionally stored as part of a tf.Module object. The tf.Module class provides features for tracking variables and saving checkpoints and models.

Classes like keras.layers.Layer and keras.Model are subclasses of tf.Module.

class Pow(tf.Module):
  def __init__(self, exponent):
    self.exponent = tf.Variable(exponent, dtype = tf.float32, name='Pow/exponent')

  @tf.function
  def __call__(self, x):
    return x ** self.exponent
pow = Pow(3)
pow.variables
(<tf.Variable 'Pow/exponent:0' shape=() dtype=float32, numpy=3.0>,)
pow(tf.constant(2.0)).numpy()
8.0
pow.exponent.assign(4)
pow(tf.constant(2.0)).numpy()
16.0
tf.saved_model.save(pow, 'pow')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1809: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: pow/assets
reloaded_pow = tf.saved_model.load('pow')
reloaded_pow(tf.constant(3.0)).numpy()
81.0

Assign a tf.function as an attribute

If you assign a tf.Module or a tf.function as an attribute of a module it will be serialized as well:

mod = tf.Module()
mod.increment_by = tf.Variable(2.0)

@tf.function
def increment(x):
  return x+mod.increment_by

mod.inc = increment
mod.inc(tf.constant(1.0)).numpy()
3.0
mod.cube = Pow(3)
mod.cube(tf.constant(2.0)).numpy()
8.0
mod.variables
(<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
 <tf.Variable 'Pow/exponent:0' shape=() dtype=float32, numpy=3.0>)
tf.saved_model.save(mod, 'mod')
reloaded_mod = tf.saved_model.load('mod')
INFO:tensorflow:Assets written to: mod/assets
reloaded_mod.inc(4.0).numpy()
6.0
reloaded_mod.cube(4.0).numpy()
64.0

Interoperability with tf.keras

Keras classes like keras.Model and keras.layers.Layer are fully compatible with tf.function and tf.Module.

For example, build a simple model:

linear = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
linear.compile(optimizer='adam', loss='mean_squared_error')
linear.fit(x=[-1, 0, 1, 2, 3, 4], y=[-3, -1, 1, 3, 5, 7], epochs=50, verbose=0)
<tensorflow.python.keras.callbacks.History at 0x7f27680569e8>
linear(tf.constant([[1],[2]]))
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[1.566087 ],
       [3.1825545]], dtype=float32)>

Inspect its variables:

linear.variables
[<tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.6164674]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.0503803], dtype=float32)>]

Now attach it to a tf.Module:

module = tf.Module()
module.linear = linear

The tf.Module also tracks the tf.Variables:

module.variables
(<tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.6164674]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([-0.0503803], dtype=float32)>)

The tf.Module will export the contents of the keras.Model as well:

tf.saved_model.save(module,'module')
INFO:tensorflow:Assets written to: module/assets
reloaded = tf.saved_model.load('module')
reloaded.linear([[1.0]])
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.566087]], dtype=float32)>

Tracing

The objects returned from tf.function are polymorphic functions. They will accept python objects, or tf.Tensors with any shape or tf.dtype as input.

In the background TensorFlow builds tf.Graphs representing the calculation. This graph is wrapped in a python callable: a concrete function. Each concrete function can only handle a single input signature.

tf.function traces the python function each time it needs to create a concrete function. The easiest way to see when a function is traced is to add a call to print:

@tf.function
def mul(a, b):
  print('Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b))
  return a*b

Dtypes and shapes

If you call the polymorphic function with two different types of input, it will trace once for each:

# Trace with ints
mul(tf.constant(2), tf.constant(3)).numpy()
Tracing:
    Tensor("a:0", shape=(), dtype=int32)
    Tensor("b:0", shape=(), dtype=int32)


6
# Trace with floats
mul(tf.constant(2.0), tf.constant(3.0)).numpy()
Tracing:
    Tensor("a:0", shape=(), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


6.0

When you call it again with the same input types, it dispatches to an existing function instead of tracing:

# Call with ints again => no trace
mul(tf.constant(10), tf.constant(10))
<tf.Tensor: shape=(), dtype=int32, numpy=100>
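This trace-then-dispatch behavior can be sketched in plain Python. The names below (make_polymorphic, signature) are illustrative only, not TensorFlow APIs: a cache of "concrete functions" is keyed by an input signature, and a cache miss stands in for a trace.

```python
# Illustrative sketch only -- this is NOT TensorFlow's implementation.
def make_polymorphic(python_fn):
    concrete_fns = {}            # signature -> concrete function
    traces = []                  # record of signatures that were traced

    def signature(args):
        # Stand-in for tf.TensorSpec: key on each argument's Python type.
        return tuple(type(a).__name__ for a in args)

    def polymorphic(*args):
        key = signature(args)
        if key not in concrete_fns:
            traces.append(key)   # a new signature means a new trace
            concrete_fns[key] = python_fn
        return concrete_fns[key](*args)

    polymorphic.traces = traces
    return polymorphic

square = make_polymorphic(lambda x: x * x)
square(2)       # traces for ('int',)
square(3)       # signature matches -> dispatch, no new trace
square(2.0)     # new signature -> traces again for ('float',)
```

Calling square with another int dispatches to the cached entry, just as mul above stops printing "Tracing" once a matching concrete function exists.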

Changing the sizes of the inputs also triggers a trace (setting tf.function(experimental_relax_shapes=True) may reduce this):

# Trace with vectors
mul(tf.constant([1.0,3.0]), tf.constant(3.0)).numpy()
Tracing:
    Tensor("a:0", shape=(2,), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


array([3., 9.], dtype=float32)
# Trace with different-sized vectors
mul(tf.constant([1.0,2.0,3.0, 4.0]), tf.constant(3.0))
Tracing:
    Tensor("a:0", shape=(4,), dtype=float32)
    Tensor("b:0", shape=(), dtype=float32)


<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 3.,  6.,  9., 12.], dtype=float32)>

Immutable python objects

If you pass an immutable python object, like an int, str, or tuple, to a tf.function, it runs a new trace for each distinct value of those python objects.

This is useful to control what gets included in the tf.Graph (See: The Autograph Guide for more details).

@tf.function
def mul(a, b):
  print('Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b))
  return a*b
# Trace for a=3.0
mul(3.0, tf.constant(3.0)).numpy()
Tracing:
    3.0
    Tensor("b:0", shape=(), dtype=float32)


9.0
# Don't trace for a=3.0 the second time:
mul(3.0, tf.constant(3.0)).numpy()
9.0

This loop traces the function for each unique int:

@tf.function
def power(a,b):
  print('Tracing "power": a={}'.format(a))
  return a**b
p = tf.constant(2)
for n in range(12):
  power(n,p)
Tracing "power": a=0
Tracing "power": a=1
Tracing "power": a=2
Tracing "power": a=3
Tracing "power": a=4
WARNING:tensorflow:5 out of the last 5 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=5
WARNING:tensorflow:6 out of the last 6 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=6
WARNING:tensorflow:7 out of the last 7 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=7
WARNING:tensorflow:8 out of the last 8 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=8
WARNING:tensorflow:9 out of the last 9 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=9
WARNING:tensorflow:10 out of the last 10 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=10
WARNING:tensorflow:11 out of the last 11 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Tracing "power": a=11
WARNING:tensorflow:11 out of the last 11 calls to <function power at 0x7f2710320bf8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

On the second run each int has been traced, so there's no tracing to do:

p = tf.constant(2)
for n in range(12):
  power(n,p)

To avoid excess retracing be sure to pass a tf.Tensor instead of python numbers or strings:

p = tf.constant(2)
for n in tf.range(12):
  power(n,p)
Tracing "power": a=Tensor("a:0", shape=(), dtype=int32)

To restrict tracing to a single signature, pass an input_signature to the tf.function decorator:

@tf.function(input_signature=(
    tf.TensorSpec(shape=[], dtype=tf.float32),
    tf.TensorSpec(shape=[], dtype=tf.float32),)
)
def power_with_sig(a,b):
  print('Tracing "power_with_sig"')
  return a**b
power_with_sig(3.0, 3.0).numpy()
Tracing "power_with_sig"

27.0
try:
  power_with_sig(tf.constant([1.0,2.0,3.0]),tf.constant(3.0))
  assert False
except ValueError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-46-344551274fb0>", line 2, in <module>
    power_with_sig(tf.constant([1.0,2.0,3.0]),tf.constant(3.0))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32),
    tf.Tensor(3.0, shape=(), dtype=float32))
  input_signature: (
    TensorSpec(shape=(), dtype=tf.float32, name=None),
    TensorSpec(shape=(), dtype=tf.float32, name=None))

Example: Dropout

Retracing for specific values gives you control over what code gets generated by the tf.function.

class Dropout(tf.Module):
  def __init__(self, rate, name=None):
    super(Dropout, self).__init__(name)
    self.rate = tf.Variable(rate, dtype = tf.float32, trainable=False)

  @tf.function
  def __call__(self, x, training=True):
    print(textwrap.dedent("""
                          Tracing "Dropout":
                              training = {}
                              x = {}
                              name = {:s}
                          """.format(training, x, self.name)))
    if training:
      print('    - Train branch\n')
      mask = tf.random.uniform(x.shape) > self.rate
      return x * tf.cast(mask, tf.float32)/self.rate
    else:
      print('    - Test branch\n')
      return x

Create an instance of this simple Dropout layer:

dropout = Dropout(0.5)

The first time you call it with a python training=True as input, it traces the training branch:

dropout(tf.range(10, dtype=tf.float32), training=True).numpy()

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch


array([ 0.,  2.,  4.,  0.,  8., 10., 12.,  0.,  0.,  0.], dtype=float32)

The second time, it doesn't need to re-trace the branch:

dropout(tf.range(10, dtype=tf.float32), training=True).numpy()
array([ 0.,  2.,  0.,  0.,  8.,  0., 12.,  0., 16.,  0.], dtype=float32)

Passing training=False triggers a trace on the first run since this is a different python value:

dropout(tf.range(10, dtype=tf.float32), training=False).numpy()

Tracing "Dropout":
    training = False
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Test branch


array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
dropout(tf.range(10, dtype=tf.float32), training=False).numpy()
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

If you pass a bool tensor, AutoGraph rewrites the if into a tf.cond, and the trace includes both branches:

dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

This captures the control flow in a single concrete function.

dropout(tf.range(10, dtype=tf.float32), training=tf.constant(True)).numpy()
array([ 0.,  0.,  4.,  6.,  0.,  0., 12., 14., 16.,  0.], dtype=float32)
dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

Other python objects

Since the generated tf.Graphs cannot contain arbitrary python objects, such objects are captured at trace time instead.

The tf.function runs a separate trace for each instance, so each trace includes its own variables and can specialize its behavior based on the instance.

The most common usage is on methods of tf.Module, keras.layers.Layer, or keras.Model:

dropout_a = Dropout(0.5, name='dropout_a')
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout_a

    - Train branch

[ 0.  2.  4.  6.  0. 10. 12.  0.  0. 18.]
[ 0.  2.  4.  6.  8. 10. 12.  0.  0. 18.]
dropout_b = Dropout(0.5, name='dropout_b')
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())

Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout_b

    - Train branch

[ 0.  0.  4.  6.  8. 10. 12.  0.  0.  0.]
[ 0.  2.  4.  0.  0. 10.  0.  0.  0. 18.]

But the behavior is the same on a stand-alone tf.function.

@tf.function
def run(callable, x):
  print('Tracing "run":\n    callable = {}\n    x = {}\n'.format(callable, x))
  return callable(x)
def plus_1(x):
  return x+1

print(run(plus_1, tf.constant(2.0)).numpy())
print(run(plus_1, tf.constant(5.0)).numpy())
Tracing "run":
    callable = <function plus_1 at 0x7f2710243268>
    x = Tensor("x:0", shape=(), dtype=float32)

3.0
6.0

Tracing one tf.function can trigger tracing in another:

print(run(dropout, tf.range(10.0)).numpy())
print(run(dropout, tf.range(10.0)).numpy())
Tracing "run":
    callable = <__main__.Dropout object at 0x7f271023dc88>
    x = Tensor("x:0", shape=(10,), dtype=float32)


Tracing "Dropout":
    training = True
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

[ 0.  0.  4.  6.  8. 10. 12. 14. 16.  0.]
[ 0.  2.  0.  0.  8. 10.  0. 14. 16.  0.]
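The per-instance behavior above can be sketched in plain Python (TracingCache and Scale are hypothetical names, not TensorFlow APIs): because a graph cannot hold an arbitrary Python object, the object itself becomes part of the cache key, so each instance gets its own trace.

```python
# Illustrative sketch only: python objects are keyed by identity, so each
# instance gets its own "concrete function" with that instance bound in.
class TracingCache:
    def __init__(self, python_fn):
        self.python_fn = python_fn
        self.traces = {}         # id(obj) -> bound concrete function

    def __call__(self, obj, *args):
        key = id(obj)
        if key not in self.traces:
            # "Trace": bind this particular object into the cached function.
            self.traces[key] = lambda *a, _obj=obj: self.python_fn(_obj, *a)
        return self.traces[key](*args)

class Scale:
    def __init__(self, factor):
        self.factor = factor

run_scaled = TracingCache(lambda obj, x: x * obj.factor)

a, b = Scale(2), Scale(10)
run_scaled(a, 3)    # traces for instance a
run_scaled(b, 3)    # a different instance -> a separate trace
run_scaled(a, 5)    # reuses instance a's trace
```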

Weak references

A tf.function only keeps weak references to the variables it captures from the enclosing scope, so you must keep a reference to any captured variables yourself. For example, here's a tf.function that refers to var from the enclosing scope:

@tf.function
def plus_var(x):
  print('Tracing "plus_var":\n    x = {}\n    var = {}\n\n'.format(x, var.name))
  return x + var

Trace the function with one variable:

var = tf.Variable(1, name="IntVar")
plus_var(tf.constant([1,2])).numpy()
Tracing "plus_var":
    x = Tensor("x:0", shape=(2,), dtype=int32)
    var = IntVar:0



array([2, 3], dtype=int32)

And with another variable:

var = tf.Variable(2.0, name="FloatVar")
plus_var(tf.constant([2.0, 10.0])).numpy()
Tracing "plus_var":
    x = Tensor("x:0", shape=(2,), dtype=float32)
    var = FloatVar:0



array([ 4., 12.], dtype=float32)

That worked, but because you no longer have a reference to "IntVar", that first trace is broken:

try:
  plus_var(tf.constant([1,2])).numpy()
  assert False
except tf.errors.FailedPreconditionError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-66-91a1d48c472f>", line 2, in <module>
    plus_var(tf.constant([1,2])).numpy()
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar12 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar12/N10tensorflow3VarE does not exist.
     [[node add/ReadVariableOp (defined at <ipython-input-63-2bdc80b2c0ef>:4) ]]
     [[add/ReadVariableOp/_2]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar12 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar12/N10tensorflow3VarE does not exist.
     [[node add/ReadVariableOp (defined at <ipython-input-63-2bdc80b2c0ef>:4) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_plus_var_2179]

Function call stack:
plus_var -> plus_var
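The failure above can be reproduced with nothing but Python's weakref module (a sketch under the assumption that the trace holds only a weak reference to the captured variable; the Var class is a stand-in for tf.Variable):

```python
import gc
import weakref

class Var:
    """Stand-in for tf.Variable."""
    def __init__(self, name):
        self.name = name

var = Var('IntVar')
captured = weakref.ref(var)   # the trace keeps only a weak reference

var = Var('FloatVar')         # rebinding drops the last strong reference
gc.collect()                  # CPython frees it on rebind; collect() for safety

captured() is None            # True: the first variable is gone, so a
                              # concrete function that captured it is broken
```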

Accessing concrete functions

In the previous section you saw the conditions for triggering a new trace of a polymorphic tf.function. Each trace generates a new concrete function.

When you save a tf.Module as a SavedModel, it's those concrete functions that define the tf.Graphs that are exported. You don't save the tf.function itself; you save the concrete functions that are created by tracing.

To get a concrete function from the polymorphic tf.function you need to define the signature. Either:

  • Pass an input_signature to tf.function, and call the get_concrete_function() method.
  • Pass a list of tf.TensorSpecs to get_concrete_function: tf.TensorSpec(shape=[1], dtype=tf.float32).
  • Pass an example tensor of the correct shape and type to get_concrete_function: tf.constant(1., shape=[1]).

The following example shows how to define the input_signature parameter for tf.function.

Using input_signature

Specify the input tensors in the call to tf.function as shown below. This tf.function can only execute on tensors that match the specified signature.

A None in the shape acts as a wildcard, so these tf.TensorSpecs say "a float32 vector of any length".

This pattern is important when your tf.function must handle sequences of different lengths, or images of different sizes, in each batch (see the Transformer and Deep Dream tutorials for examples).

@tf.function(input_signature=(
    tf.TensorSpec(shape=[None], dtype=tf.float32),
    tf.TensorSpec(shape=[None], dtype=tf.float32),)
)
def power_with_sig(a,b):
  print('Tracing "power_with_sig"\n')
  return a**b

Calling get_concrete_function will execute the trace (if necessary), and return a concrete function.

p = power_with_sig.get_concrete_function()
type(p)
Tracing "power_with_sig"


tensorflow.python.eager.function.ConcreteFunction
p(tf.constant([2.0,3.0,4.0]), tf.constant([5.0,4.0,3.0])).numpy()
array([32., 81., 64.], dtype=float32)

Using get_concrete_function

@tf.function
def power(a,b):
  print('Tracing "power"\n')
  return a**b
float_power = power.get_concrete_function(
  a = tf.TensorSpec(shape=[], dtype=tf.float32),
  b = tf.TensorSpec(shape=[], dtype=tf.float32))
Tracing "power"

float_power(tf.constant(3.0),tf.constant(3.0))
<tf.Tensor: shape=(), dtype=float32, numpy=27.0>

Remember that you can also pass tensors to get_concrete_function; in that case it returns the concrete function that would run for those inputs:

row = tf.range(10)
col = tf.constant([[1],[2],[3]])

concrete_power = power.get_concrete_function(a = row, b = col)
concrete_power(row, col).numpy()
Tracing "power"


array([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
       [  0,   1,   4,   9,  16,  25,  36,  49,  64,  81],
       [  0,   1,   8,  27,  64, 125, 216, 343, 512, 729]], dtype=int32)

Using a concrete function

A concrete function only accepts tensors as input:

float_power(tf.constant(2.0), tf.constant(3.0)).numpy()
8.0
try:
  float_power(2.0,3.0)
  assert False
except ValueError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-75-c045ce959e36>", line 2, in <module>
    float_power(2.0,3.0)
ValueError: All inputs to `ConcreteFunction`s must be Tensors; on invocation of power, the 0-th input (2.0) was not a Tensor.

It also only accepts inputs of the correct dtype:

try:
  float_power(tf.constant(1),tf.constant(3))
  assert False
except tf.errors.InvalidArgumentError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-76-9b54dd5e8642>", line 2, in <module>
    float_power(tf.constant(1),tf.constant(3))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_power_2212 as input #0(zero-based) was expected to be a float tensor but is a int32 tensor [Op:__inference_power_2212]

But it will try to execute even if the input tensors do not match the expected shape:

float_power(tf.constant([1.,2.,3.,4.,5.]),tf.constant(3.)).numpy()
array([  1.,   8.,  27.,  64., 125.], dtype=float32)
try:
  float_power(tf.constant([1.,2.,3.]),tf.constant([4., 5.])).numpy()
  assert False
except tf.errors.InvalidArgumentError:  
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-78-c60087f7e9d3>", line 2, in <module>
    float_power(tf.constant([1.,2.,3.]),tf.constant([4., 5.])).numpy()
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [3] vs. [2]
     [[node pow (defined at <ipython-input-70-7cca90b55f0d>:4) ]] [Op:__inference_power_2212]

Errors may have originated from an input operation.
Input Source operations connected to node pow:
 a (defined at <ipython-input-71-63ea774dd467>:3)

Function call stack:
power

By inspecting the concrete function you can see its inputs and outputs:

print(float_power.structured_input_signature)
print(float_power.structured_outputs)
((TensorSpec(shape=(), dtype=tf.float32, name='a'), TensorSpec(shape=(), dtype=tf.float32, name='b')), {})
Tensor("Identity:0", shape=(), dtype=float32)

Python Objects in signatures

As you saw when tracing, each python object generates a new trace. A concrete function represents a single tf.Graph and does no retracing. When you call get_concrete_function with a python object as one of the arguments, the object is bound to the function.

cube = power.get_concrete_function(
    a = tf.TensorSpec([], dtype=tf.float32),
    b = 3.0)
Tracing "power"

This cube function no longer has a b argument:

print(cube.structured_input_signature)
((TensorSpec(shape=(), dtype=tf.float32, name='a'), 3.0), {})
cube(tf.constant(10.0)).numpy()
999.99994

This is very similar to the way that standard python classes bind methods, and applies equally when you run get_concrete_function from a method:

class Greeter(object):
  def __init__(self, greeting):
    self.greeting = greeting

  def greet(self, who):
    return " ".join([self.greeting, who])

p = Greeter("Hello")
m = p.greet
print(m)
<bound method Greeter.greet of <__main__.Greeter object at 0x7f271021e898>>
print(m("TensorFlow!"))
Hello TensorFlow!

When you have a tf.function decorating a method, similar rules apply:

class MyModel(tf.Module):
  def __init__(self, ins, outs):
    initializer = tf.initializers.GlorotNormal()
    self.W = tf.Variable(initializer([ins, outs]))
    self.B = tf.Variable(tf.zeros([outs], dtype = tf.float32))

  @tf.function
  def run(self, x):
    print('Tracing "MyModule":\n    x={}\n'.format(x))
    return tf.matmul(x, self.W)+self.B
mod = MyModel(ins=5, outs=3)
mod.run([[1.0,1.0,1.0, 1.0, 1.0]]).numpy()
Tracing "MyModule":
    x=[[1.0, 1.0, 1.0, 1.0, 1.0]]


array([[0.00368813, 0.07533873, 0.84313065]], dtype=float32)

If you call the method's .get_concrete_function, the self is automatically bound as the first argument:

concrete_run = mod.run.get_concrete_function(x = tf.TensorSpec([None, None]))
Tracing "MyModule":
    x=Tensor("x:0", shape=(None, None), dtype=float32)

concrete_run(tf.constant([[1.0,1.0,1.0, 1.0, 1.0],
                          [2.0,2.0,2.0, 2.0, 2.0]])).numpy()
array([[0.00368813, 0.07533872, 0.84313065],
       [0.00737625, 0.15067744, 1.6862613 ]], dtype=float32)

See how self is no longer part of the input signature:

print(concrete_run.structured_input_signature)
print(concrete_run.structured_outputs)
((TensorSpec(shape=(None, None), dtype=tf.float32, name='x'),), {})
Tensor("Identity:0", shape=(None, 3), dtype=float32)

Accessing concrete functions from a SavedModel

When you save a SavedModel you're really saving the tf.function's cache of concrete functions.

Because concrete functions are generated by tracing, you need to execute at least one trace before you can save a SavedModel.

dropout = Dropout(0.5)

_ = dropout(tf.range(10, dtype=tf.float32), tf.constant(True))
_ = dropout(tf.random.normal([2, 3]), tf.constant(True))

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(2, 3), dtype=float32)
    name = dropout

    - Train branch

    - Test branch

export_dir = 'dropout'
tf.saved_model.save(dropout, export_dir)

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(2, 3), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(10,), dtype=float32)
    name = dropout

    - Train branch

    - Test branch

INFO:tensorflow:Assets written to: dropout/assets

Direct access

When you load a tf.saved_model your methods are restored as polymorphic functions:

reloaded_dropout = tf.saved_model.load(export_dir)
print(reloaded_dropout(tf.range(10, dtype=tf.float32), tf.constant(False)).numpy())
print(reloaded_dropout(tf.random.normal([2,3]), tf.constant(True)).numpy())
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[[ 0.        -5.476597  -0.       ]
 [-0.6954242  0.        -0.8201217]]

But since the SavedModel only contains the cache of concrete functions (and not the python source), it cannot handle signatures that don't match:

try:
  reloaded_dropout(tf.range(12, dtype=tf.float32), tf.constant(True))
  assert False
except ValueError:
  traceback.print_exc(limit=1)
Traceback (most recent call last):
  File "<ipython-input-95-1a02de4a1d46>", line 2, in <module>
    reloaded_dropout(tf.range(12, dtype=tf.float32), tf.constant(True))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (2 total):
    * Tensor("x:0", shape=(12,), dtype=float32)
    * Tensor("training:0", shape=(), dtype=bool)
  Keyword arguments: {}

Expected these arguments to match one of the following 2 option(s):

Option 1:
  Positional arguments (2 total):
    * TensorSpec(shape=(2, 3), dtype=tf.float32, name='x')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

Option 2:
  Positional arguments (2 total):
    * TensorSpec(shape=(10,), dtype=tf.float32, name='x')
    * TensorSpec(shape=(), dtype=tf.bool, name='training')
  Keyword arguments: {}

From the reloaded module you can select a specific concrete function, instead of relying on dispatch, by again using the get_concrete_function method:

cf = reloaded_dropout.__call__.get_concrete_function(
    x = tf.TensorSpec([10]), 
    training = tf.TensorSpec([], tf.bool))
result = cf(tf.range(10, dtype=tf.float32), tf.constant(True)).numpy()
print(result)
[ 0.  0.  0.  0.  0.  0.  0. 14.  0. 18.]

Named signatures: Exporting for C++

C++ consumers of SavedModels do not use the "Direct access" method above, or its dynamic dispatch, to get and run concrete functions from the SavedModel.

They use a more explicit interface called "exported signatures", where you specify exactly which concrete functions to export.

You specify the concrete functions to export by passing a signatures argument to tf.saved_model.save.

It takes either a single concrete function, or a dictionary mapping signature names to concrete functions. Both forms are demonstrated below.

These signatures are required when using TensorFlow Serving.
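As a minimal standalone sketch of the two forms, here is a hypothetical Scale module (not the Dropout class used elsewhere in this guide) saved once with a single concrete function and once with a name-to-function dictionary:

```python
import tempfile
import tensorflow as tf

# Hypothetical module for illustration only.
class Scale(tf.Module):
  @tf.function
  def __call__(self, x):
    return 2.0 * x

scale = Scale()
cf = scale.__call__.get_concrete_function(tf.TensorSpec([None], tf.float32))

# Form 1: a single concrete function, stored under the default serving key.
path1 = tempfile.mkdtemp()
tf.saved_model.save(scale, path1, signatures=cf)
print(list(tf.saved_model.load(path1).signatures.keys()))

# Form 2: a dictionary mapping signature names to concrete functions.
path2 = tempfile.mkdtemp()
tf.saved_model.save(scale, path2, signatures={'double': cf})
print(list(tf.saved_model.load(path2).signatures.keys()))
```

With Form 1 the signature is registered under tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY ("serving_default"); with Form 2 you choose the names yourself.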

Simple example

dropout = Dropout(0.5)
cf = dropout.__call__.get_concrete_function(tf.zeros((2,3), dtype=tf.float32), tf.constant(False))

import time
export_dir = "./saved/"+str(time.time())

tf.saved_model.save(dropout, export_dir, signatures = cf)

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(2, 3), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(2, 3), dtype=float32)
    name = dropout

    - Train branch

    - Test branch

INFO:tensorflow:Assets written to: ./saved/1582252076.061085/assets

This SavedModel contains only the one signature, which can be recovered by name from the signatures dictionary:

reloaded = tf.saved_model.load(export_dir)

print(reloaded.signatures)
_SignatureMap({'serving_default': <tensorflow.python.saved_model.load._WrapperFunction object at 0x7f2710172630>})

When using "exported signatures", these concrete functions always return a dictionary of outputs:

cf = reloaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
result = cf(x=tf.random.normal([2,3]), training=tf.constant(True))

print(result)
{'output_0': <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.       ,  3.876915 , -0.       ],
       [ 0.       , -1.0137572, -0.4941006]], dtype=float32)>}

In the example above, the output names auto-generated by the signature are fairly generic. You can check the expected output-tensor names using the structured_outputs attribute:

cf.structured_outputs
{'output_0': TensorSpec(shape=(2, 3), dtype=tf.float32, name='output_0')}

Typically you want to set the output names yourself.

Example: Setting the output names

To control the names of the outputs, modify your tf.function to return a dictionary that maps names to output tensors:

@tf.function
def named_result(x, training=True):
  return {'dropout': dropout(x, training)}

dropout.named_result = named_result

cf = dropout.named_result.get_concrete_function(tf.zeros((2,3), dtype=tf.float32),
                                                tf.constant(False))

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(2, 3), dtype=float32)
    name = dropout

    - Train branch

    - Test branch

Example: Setting the signature names

To set the name of the signature, pass a dictionary mapping names to concrete functions:

export_dir = "./saved/"+str(time.time())
tf.saved_model.save(dropout, export_dir, signatures = {'simple':cf})
INFO:tensorflow:Assets written to: ./saved/1582252076.3178887/assets
reloaded = tf.saved_model.load(export_dir)
cf = reloaded.signatures['simple']
result = cf(x=tf.random.normal([2,3]), training=tf.constant(True))

print({key:value.numpy() for key,value in result.items()})
{'dropout': array([[ 0.,  0., -0.],
       [-0.,  0., -0.]], dtype=float32)}

To specify multiple signatures, pass a dictionary of (name, concrete_function) pairs to tf.saved_model.save:

vector = dropout.__call__.get_concrete_function(tf.TensorSpec([None], dtype=tf.float32), tf.constant(False))
matrix = dropout.__call__.get_concrete_function(tf.TensorSpec([None, None], dtype=tf.float32), tf.constant(False))
cube = dropout.__call__.get_concrete_function(tf.TensorSpec([None, None, None], dtype=tf.float32), tf.constant(False))

export_dir = "./saved/"+str(time.time())

tf.saved_model.save(dropout, export_dir, 
                    signatures = {
                        "vector": vector,
                        "matrix": matrix,
                        "cube": cube
                    })

Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(None,), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(None, None), dtype=float32)
    name = dropout

    - Train branch

    - Test branch


Tracing "Dropout":
    training = Tensor("training:0", shape=(), dtype=bool)
    x = Tensor("x:0", shape=(None, None, None), dtype=float32)
    name = dropout

    - Train branch

    - Test branch

INFO:tensorflow:Assets written to: ./saved/1582252076.454253/assets

Now reload that model and inspect the signature listing:

reloaded = tf.saved_model.load(export_dir)
print('{}'.format(reloaded.signatures).replace("{","{\n    ").replace(">, ", ">,\n    "))
_SignatureMap({
    'vector': <tensorflow.python.saved_model.load._WrapperFunction object at 0x7f26e84d16a0>,
    'matrix': <tensorflow.python.saved_model.load._WrapperFunction object at 0x7f26e847c6d8>,
    'cube': <tensorflow.python.saved_model.load._WrapperFunction object at 0x7f26e84eac18>})
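Each entry in that map is callable like any other exported signature: pass keyword tensor arguments and receive a dictionary of named outputs. A self-contained sketch of the round trip, again using a hypothetical Scale module rather than the Dropout class above:

```python
import tempfile
import tensorflow as tf

# Hypothetical module for illustration; returns a named output so the
# exported signature's output key is under our control.
class Scale(tf.Module):
  @tf.function
  def __call__(self, x):
    return {'scaled': 2.0 * x}

scale = Scale()
cf = scale.__call__.get_concrete_function(
    tf.TensorSpec([None], tf.float32, name='x'))

path = tempfile.mkdtemp()
tf.saved_model.save(scale, path, signatures={'vector': cf})

# Select the named signature from the reloaded model and call it.
sig = tf.saved_model.load(path).signatures['vector']
result = sig(x=tf.constant([1.0, 2.0]))
print(result['scaled'].numpy())
```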