Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Mejor rendimiento con tf.function

Ver en TensorFlow.org Ver fuente en GitHub Descargar cuaderno

En TensorFlow 2, la ejecución ansiosa está activada de forma predeterminada. La interfaz de usuario es intuitiva y flexible (ejecutar operaciones puntuales es mucho más fácil y rápido), pero esto puede tener como resultado el rendimiento y la capacidad de implementación.

Puede usar tf.function para hacer gráficos a partir de sus programas. Es una herramienta de transformación que crea gráficos de flujo de datos independientes de Python a partir de su código Python. Esto le ayudará a crear modelos portátiles y de rendimiento, y es necesario utilizar SavedModel .

Esta guía lo ayudará a conceptualizar cómo funciona tf.function debajo del capó para que pueda usarlo de manera efectiva.

Las principales conclusiones y recomendaciones son:

  • Depura en modo ansioso, luego decora con @tf.function .
  • No confíe en los efectos secundarios de Python, como la mutación de objetos o los anexos de listas.
  • tf.function funciona mejor con las operaciones de TensorFlow; Las llamadas a NumPy y Python se convierten en constantes.

Preparar

import tensorflow as tf

Defina una función auxiliar para demostrar los tipos de errores que puede encontrar:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Lo esencial

Uso

Una Function que defina es como una operación central de TensorFlow: puede ejecutarla con entusiasmo; puedes calcular gradientes; y así.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Puede utilizar Function s dentro de otras Function s.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function pueden ser más rápidas que el código ávido, especialmente para gráficos con muchas operaciones pequeñas. Pero para gráficos con algunas operaciones costosas (como convoluciones), es posible que no vea mucha aceleración.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.002407395000091128
Function conv: 0.004000883000117028
Note how there's not much difference in performance for convolutions

Rastreo

La escritura dinámica de Python significa que puede llamar a funciones con una variedad de tipos de argumentos, y Python puede hacer algo diferente en cada escenario.

Sin embargo, para crear un gráfico de TensorFlow, se requieren dtypes estáticos y dimensiones de forma. tf.function cierra esta brecha al tf.function una función de Python para crear un objeto de Function . Según las entradas dadas, la Function selecciona el gráfico apropiado para las entradas dadas, rastreando la función de Python según sea necesario. Una vez que comprenda por qué y cuándo ocurre el rastreo, ¡es mucho más fácil usar tf.function efectiva!

Puede llamar a una Function con argumentos de diferentes tipos para ver este comportamiento polimórfico en acción.

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Tenga en cuenta que si llama repetidamente a una Function con el mismo tipo de argumento, TensorFlow reutilizará un gráfico trazado previamente, ya que el gráfico generado sería idéntico.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Puede usar pretty_printed_concrete_signatures() para ver todos los rastros disponibles:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Hasta ahora, has visto que tf.function crea una capa de distribución dinámica en caché sobre la lógica de seguimiento de gráficos de TensorFlow. Para ser más específico sobre la terminología:

  • Un tf.Graph es la representación sin procesar, tf.Graph del lenguaje y portátil de su computación.
  • Un ConcreteFunction es un contenedor que se ejecuta con entusiasmo alrededor de un tf.Graph .
  • Una Function gestiona un caché de ConcreteFunction sy elige la correcta para sus entradas.
  • tf.function envuelve una función de Python y devuelve un objeto de Function .

Obtener funciones concretas

Cada vez que se rastrea una función, se crea una nueva función concreta. Puede obtener directamente una función concreta, utilizando get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(El siguiente cambio está disponible en TensorFlow todas las noches y estará disponible en TensorFlow 2.3).

La impresión de una función ConcreteFunction muestra un resumen de sus argumentos de entrada (con tipos) y su tipo de salida.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

También puede recuperar directamente la firma de una función concreta.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

El uso de un rastro concreto con tipos incompatibles arrojará un error

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Puede notar que los argumentos de Python reciben un tratamiento especial en la firma de entrada de una función concreta. Antes de TensorFlow 2.3, los argumentos de Python simplemente se eliminaron de la firma de la función concreta. A partir de TensorFlow 2.3, los argumentos de Python permanecen en la firma, pero están limitados a tomar el valor establecido durante el seguimiento.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Obtención de gráficos

Cada función concreta es una envoltura invocable alrededor de un tf.Graph . Aunque recuperar el objeto tf.Graph real no es algo que normalmente necesitará hacer, puede obtenerlo fácilmente desde cualquier función concreta.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

[] -> a
['a', 'a'] -> add
['add'] -> Identity

Depuración

En general, depurar el código es más fácil en modo ansioso que dentro de tf.function . Debe asegurarse de que su código se ejecute sin errores en modo ansioso antes de decorar con tf.function . Para ayudar en el proceso de depuración, puede llamar a tf.config.run_functions_eagerly(True) para deshabilitar y volver a habilitar tf.function .

Al rastrear problemas que solo aparecen dentro de tf.function , aquí hay algunos consejos:

  • Las llamadas de print Python simples y antiguas solo se ejecutan durante el seguimiento, lo que le ayuda a rastrear cuándo se (re) rastrea su función.
  • tf.print llamadas a tf.print se ejecutarán cada vez y pueden ayudarlo a rastrear valores intermedios durante la ejecución.
  • tf.debugging.enable_check_numerics es una forma fácil de rastrear dónde se crean NaN e Inf.
  • pdb puede ayudarlo a comprender lo que sucede durante el rastreo. (Advertencia: PDB lo colocará en el código fuente transformado por AutoGraph).

Rastreo de semántica

Reglas de la clave de caché

Una Function determina si se reutiliza una función concreta rastreada calculando una clave de caché a partir de los argumentos y kwargs de una entrada.

  • La clave generada para un argumento tf.Tensor es su forma y tipo d.
  • A partir de TensorFlow 2.3, la clave generada para un argumento tf.Variable es su id() .
  • La clave generada para una primitiva de Python es su valor. La clave generada para dict s anidados, list s, tuple s, namedtuple s y attr s es la tupla aplanada. (Como resultado de este aplanamiento, llamar a una función concreta con una estructura de anidamiento diferente a la utilizada durante el rastreo dará como resultado un TypeError).
  • Para todos los demás tipos de Python, las claves se basan en el id() objeto id() modo que los métodos se rastrean de forma independiente para cada instancia de una clase.

Controlar el retroceso

Retrazar ayuda a garantizar que TensorFlow genere gráficos correctos para cada conjunto de entradas. Sin embargo, el rastreo es una operación costosa. Si su Function traza un nuevo gráfico para cada llamada, encontrará que su código se ejecuta más lentamente que si no usara tf.function .

Para controlar el comportamiento de seguimiento, puede utilizar las siguientes técnicas:

  • Especifique input_signature en tf.function para limitar el seguimiento.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Especifique una dimensión [Ninguno] en tf.TensorSpec para permitir flexibilidad en la reutilización de trazas.

    Dado que TensorFlow hace coincidir los tensores en función de su forma, el uso de una dimensión None como comodín permitirá que Function s reutilice trazas para entradas de tamaño variable. La entrada de tamaño variable puede ocurrir si tiene secuencias de diferente longitud o imágenes de diferentes tamaños para cada lote (consulte los tutoriales de Transformer y Deep Dream, por ejemplo).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Transmita argumentos de Python a tensores para reducir el retroceso.

    A menudo, los argumentos de Python se utilizan para controlar hiperparámetros y construcciones de gráficos, por ejemplo, num_layers=10 o training=True o nonlinearity='relu' . Entonces, si el argumento de Python cambia, tiene sentido que tenga que volver sobre el gráfico.

    Sin embargo, es posible que no se esté utilizando un argumento de Python para controlar la construcción del gráfico. En estos casos, un cambio en el valor de Python puede desencadenar un retroceso innecesario. Tomemos, por ejemplo, este ciclo de entrenamiento, que AutoGraph desenrollará dinámicamente. A pesar de las múltiples trazas, el gráfico generado es realmente idéntico, por lo que no es necesario volver a rastrear.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Si necesita forzar el retroceso, cree una nueva Function . Se garantiza que los objetos de Function separados no comparten trazas.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Efectos secundarios de Python

Los efectos secundarios de Python, como imprimir, agregar a listas y mutaciones globales, solo ocurren la primera vez que llama a una Function con un conjunto de entradas. Luego, el tf.Graph trazado se vuelve a ejecutar, sin ejecutar el código Python.

La regla general es usar solo los efectos secundarios de Python para depurar sus rastros. De lo contrario, las operaciones de TensorFlow como tf.Variable.assign , tf.print y tf.summary son la mejor manera de garantizar que el tiempo de ejecución de TensorFlow tf.print y tf.summary su código con cada llamada.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Muchas características de Python, como generadores e iteradores, dependen del tiempo de ejecución de Python para realizar un seguimiento del estado. En general, si bien estas construcciones funcionan como se esperaba en el modo ansioso, pueden suceder muchas cosas inesperadas dentro de una Function .

Para dar un ejemplo, avanzar en el estado del iterador es un efecto secundario de Python y, por lo tanto, solo ocurre durante el seguimiento.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Algunas construcciones de iteración son compatibles con AutoGraph. Consulte la sección sobre Transformaciones de AutoGraph para obtener una descripción general.

Si desea ejecutar código Python durante cada invocación de una Function , tf.py_function es una trampilla de salida. El inconveniente de tf.py_function es que no es portátil ni tiene un rendimiento particular, ni funciona bien en configuraciones distribuidas (multi-GPU, TPU). Además, dado que tf.py_function debe conectarse al gráfico, envía todas las entradas / salidas a los tensores.

Las API como tf.gather , tf.stack y tf.TensorArray pueden ayudarlo a implementar patrones de bucle comunes en TensorFlow nativo.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Variables

Puede encontrar un error al crear una nueva tf.Variable en una función. Este error protege contra la divergencia de comportamiento en llamadas repetidas: en el modo ansioso, una función crea una nueva variable con cada llamada, pero en una Function , es posible que no se cree una nueva variable debido a la reutilización del seguimiento.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Puede crear variables dentro de una Function siempre que esas variables solo se creen la primera vez que se ejecuta la función.

class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Otro error que puede encontrar es una variable de recolección de basura. A diferencia de las funciones normales de Python, las funciones concretas solo retienen WeakRefs para las variables sobre las que se cierran, por lo que debe conservar una referencia a cualquier variable.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Transformaciones de AutoGraph

AutoGraph es una biblioteca que está tf.function forma predeterminada en tf.function y transforma un subconjunto del código ansioso de Python en operaciones de TensorFlow compatibles con gráficos. Esto incluye el flujo de control como if , for , while .

Las operaciones de TensorFlow como tf.cond y tf.while_loop continúan funcionando, pero el flujo de control suele ser más fácil de escribir y comprender cuando se escribe en Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.224704742 0.895507693 0.0398198366 0.98112452 0.278468847]
[0.220997646 0.71410346 0.0397988036 0.753552318 0.271487355]
[0.217468739 0.61324358 0.0397778042 0.637263417 0.265008271]
[0.214104146 0.546406269 0.0397568382 0.563033342 0.258973926]
[0.210891485 0.497821957 0.0397359058 0.510224521 0.253335565]
[0.207819641 0.460402519 0.0397150069 0.470120102 0.248051569]
[0.204878598 0.430412233 0.0396941416 0.438296348 0.243086234]
[0.202059314 0.405665785 0.039673306 0.412231296 0.2384087]
[0.199353606 0.384786367 0.039652504 0.39036563 0.23399213]
[0.196754038 0.366856933 0.0396317355 0.371675402 0.229813099]
[0.194253832 0.351239443 0.039611 0.355456293 0.225851]
[0.191846803 0.337474287 0.0395902954 0.341205537 0.222087651]
[0.189527303 0.325220674 0.0395696238 0.3285532 0.218506947]
[0.187290132 0.314219803 0.0395489857 0.317220151 0.215094551]
[0.185130537 0.304271102 0.0395283774 0.30699119 0.211837649]
[0.183044136 0.295216352 0.0395078026 0.297697395 0.208724767]
[0.181026861 0.286928833 0.0394872613 0.289204 0.205745578]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.17907499, 0.27930567, 0.03946675, 0.281402  , 0.20289075],
      dtype=float32)>

Si tiene curiosidad, puede inspeccionar el código que genera el autógrafo.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Condicionales

AutoGraph convertirá algunas declaraciones if <condition> en las llamadas tf.cond equivalentes. Esta sustitución se realiza si <condition> es un tensor. De lo contrario, la instrucción if se ejecuta como un condicional de Python.

Un condicional de Python se ejecuta durante el seguimiento, por lo que se agregará al gráfico exactamente una rama del condicional. Sin AutoGraph, este gráfico trazado no podría tomar la rama alternativa si hay un flujo de control dependiente de los datos.

tf.cond rastrea y agrega ambas ramas del condicional al gráfico, seleccionando dinámicamente una rama en el momento de la ejecución. El rastreo puede tener efectos secundarios no deseados; consulte Efectos de trazado de AutoGraph para obtener más información.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Consulte la documentación de referencia para conocer las restricciones adicionales sobre las declaraciones if convertidas en AutoGraph.

Bucles

AutoGraph convertirá algunas declaraciones for y while en las operaciones de bucle de TensorFlow equivalentes, como tf.while_loop . tf.while_loop . Si no convertido, el for o while bucle se ejecuta como un bucle Python.

Esta sustitución se realiza en las siguientes situaciones:

Un bucle de Python se ejecuta durante el seguimiento, agregando operaciones adicionales al tf.Graph para cada iteración del bucle.

Un bucle de TensorFlow rastrea el cuerpo del bucle y selecciona dinámicamente cuántas iteraciones se ejecutarán en el momento de la ejecución. El cuerpo del bucle solo aparece una vez en el tf.Graph generado.

Consulte la documentación de referencia para conocer las restricciones adicionales sobre las declaraciones for y while convertidas en AutoGraph.

Bucle sobre datos de Python

Un error común es recorrer los datos de Python / Numpy dentro de una función tf.function . Este ciclo se ejecutará durante el proceso de rastreo, agregando una copia de su modelo al tf.Graph para cada iteración del ciclo.

Si desea envolver todo el ciclo de entrenamiento en tf.function , la forma más segura de hacerlo es envolver sus datos como un tf.data.Dataset para que AutoGraph desenrolle dinámicamente el ciclo de entrenamiento.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Al envolver datos de Python / Numpy en un conjunto de datos, tenga en cuenta tf.data.Dataset.from_generator frente a tf.data.Dataset.from_tensors . El primero mantendrá los datos en Python y los tf.py_function través de tf.py_function que puede tener implicaciones de rendimiento, mientras que el segundo tf.constant() una copia de los datos como un nodo tf.constant() grande en el gráfico, lo que puede tener implicaciones de memoria.

Lectura de datos de archivos a través de TFRecordDataset / CsvDataset / etc. es la forma más efectiva de consumir datos, ya que entonces TensorFlow mismo puede administrar la carga asincrónica y la captura previa de datos, sin tener que involucrar a Python. Para obtener más información, consulte la guía tf.data .

Acumulando valores en un bucle

Un patrón común es acumular valores intermedios de un bucle. Normalmente, esto se logra agregando a una lista de Python o agregando entradas a un diccionario de Python. Sin embargo, como estos son efectos secundarios de Python, no funcionarán como se esperaba en un ciclo desenrollado dinámicamente. Utilice tf.TensorArray para acumular resultados de un bucle desenrollado dinámicamente.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.9854791 , 0.5162524 , 0.14062047, 0.04950547],
        [1.8820469 , 0.67421603, 0.40786874, 0.7679055 ],
        [2.8815444 , 1.1567757 , 1.0627073 , 0.8880433 ]],

       [[0.94119024, 0.19776726, 0.24890792, 0.4663092 ],
        [1.4591933 , 1.123581  , 0.35438073, 1.4392309 ],
        [2.0026946 , 1.9165647 , 0.37988353, 1.8128917 ]]], dtype=float32)>

Otras lecturas

Para obtener información sobre cómo exportar y cargar una Function , consulte la guía de modelo guardado . Para obtener más información sobre las optimizaciones de gráficos que se realizan después del seguimiento, consulte la guía de Grappler . Para saber cómo optimizar su canalización de datos y perfilar su modelo, consulte la guía Profiler .