Introduction aux graphes et tf.function

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Aperçu

Ce guide va sous la surface de TensorFlow et Keras pour montrer comment fonctionne TensorFlow. Si vous souhaitez plutôt commencer immédiatement avec Keras, consultez la collection de guides Keras .

Dans ce guide, vous apprendrez comment TensorFlow vous permet d'apporter des modifications simples à votre code pour obtenir des graphiques, comment les graphiques sont stockés et représentés, et comment vous pouvez les utiliser pour accélérer vos modèles.

Il s'agit d'un aperçu général qui explique comment tf.function vous permet de passer d'une exécution rapide à une exécution graphique. Pour une spécification plus complète de tf.function , consultez le guide tf.function .

Que sont les graphiques ?

Dans les trois guides précédents, vous avez exécuté TensorFlow avec impatience . Cela signifie que les opérations TensorFlow sont exécutées par Python, opération par opération, et renvoient les résultats à Python.

Alors que l'exécution hâtive présente plusieurs avantages uniques, l'exécution de graphes permet la portabilité en dehors de Python et tend à offrir de meilleures performances. L'exécution de graphe signifie que les calculs de tenseur sont exécutés sous la forme d'un graphe TensorFlow , parfois appelé tf.Graph ou simplement "graphe".

Les graphes sont des structures de données qui contiennent un ensemble d'objets tf.Operation , qui représentent des unités de calcul ; et les objets tf.Tensor , qui représentent les unités de données qui circulent entre les opérations. Ils sont définis dans un contexte tf.Graph . Étant donné que ces graphiques sont des structures de données, ils peuvent être enregistrés, exécutés et restaurés sans le code Python d'origine.

Voici à quoi ressemble un graphique TensorFlow représentant un réseau de neurones à deux couches lorsqu'il est visualisé dans TensorBoard.

Un graphique TensorFlow simple

Les avantages des graphiques

Avec un graphique, vous avez une grande flexibilité. Vous pouvez utiliser votre graphique TensorFlow dans des environnements qui ne disposent pas d'un interpréteur Python, comme les applications mobiles, les appareils intégrés et les serveurs backend. TensorFlow utilise des graphiques comme format pour les modèles enregistrés lorsqu'il les exporte depuis Python.

Les graphiques sont également facilement optimisés, permettant au compilateur d'effectuer des transformations telles que :

  • Déduisez statiquement la valeur des tenseurs en repliant les nœuds constants dans votre calcul ("pliage constant") .
  • Séparez les sous-parties d'un calcul qui sont indépendantes et répartissez-les entre les threads ou les appareils.
  • Simplifiez les opérations arithmétiques en éliminant les sous-expressions courantes.

Il existe un système d'optimisation complet, Grappler , pour effectuer cela et d'autres accélérations.

En bref, les graphiques sont extrêmement utiles et permettent à votre TensorFlow de fonctionner rapidement , de fonctionner en parallèle et de fonctionner efficacement sur plusieurs appareils .

Cependant, vous souhaitez toujours définir vos modèles d'apprentissage automatique (ou d'autres calculs) en Python pour plus de commodité, puis construire automatiquement des graphiques lorsque vous en avez besoin.

Installer

import tensorflow as tf
import timeit
from datetime import datetime

Profiter des graphiques

Vous créez et exécutez un graphique dans TensorFlow en utilisant tf.function , soit en tant qu'appel direct, soit en tant que décorateur. tf.function prend une fonction régulière en entrée et renvoie une Function . Une Function est un appelable Python qui crée des graphiques TensorFlow à partir de la fonction Python. Vous utilisez une Function de la même manière que son équivalent Python.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

À l'extérieur, une Function ressemble à une fonction normale que vous écrivez à l'aide d'opérations TensorFlow. En dessous , en revanche, c'est très différent . Une Function encapsule plusieurs tf.Graph derrière une API . C'est ainsi que Function est en mesure de vous offrir les avantages de l'exécution de graphes , comme la vitesse et la déployabilité.

tf.function s'applique à une fonction et à toutes les autres fonctions qu'elle appelle :

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

Si vous avez utilisé TensorFlow 1.x, vous remarquerez qu'à aucun moment vous n'avez eu besoin de définir un Placeholder ou tf.Session .

Conversion de fonctions Python en graphes

Toute fonction que vous écrivez avec TensorFlow contiendra un mélange d'opérations TF intégrées et de logique Python, telles que des clauses if-then , des boucles, break , return , continue , etc. Alors que les opérations TensorFlow sont facilement capturées par un tf.Graph , la logique spécifique à Python doit subir une étape supplémentaire pour faire partie du graphique. tf.function utilise une bibliothèque appelée AutoGraph ( tf.autograph ) pour convertir le code Python en code générateur de graphes.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

Bien qu'il soit peu probable que vous ayez besoin d'afficher directement les graphiques, vous pouvez inspecter les sorties pour vérifier les résultats exacts. Celles-ci ne sont pas faciles à lire, donc pas besoin de regarder trop attentivement !

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 898
  min_consumer: 12
}

La plupart du temps, tf.function fonctionnera sans considération particulière. Cependant, il y a quelques mises en garde, et le guide tf.function peut vous aider ici, ainsi que la référence complète d'AutoGraph

Polymorphisme : une Function , plusieurs graphes

Un tf.Graph est spécialisé dans un type spécifique d'entrées (par exemple, des tenseurs avec un dtype spécifique ou des objets avec le même id() ).

Chaque fois que vous appelez une Function avec de nouveaux dtypes et formes dans ses arguments, Function crée un nouveau tf.Graph pour les nouveaux arguments. Les dtypes et les formes des entrées d'un tf.Graph sont connus comme une signature d'entrée ou simplement une signature .

La Function stocke le tf.Graph correspondant à cette signature dans une ConcreteFunction . Une ConcreteFunction est un wrapper autour d'un tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

Si la Function a déjà été appelée avec cette signature, Function ne crée pas de nouveau tf.Graph .

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

Parce qu'elle est soutenue par plusieurs graphiques, une Function est polymorphe . Cela lui permet de prendre en charge plus de types d'entrée qu'un seul tf.Graph pourrait représenter, ainsi que d'optimiser chaque tf.Graph pour de meilleures performances.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

Utilisation tf.function

Jusqu'à présent, vous avez appris à convertir une fonction Python en un graphique simplement en utilisant tf.function comme décorateur ou wrapper. Mais en pratique, faire fonctionner correctement tf.function peut être délicat ! Dans les sections suivantes, vous apprendrez comment faire fonctionner votre code comme prévu avec tf.function .

Exécution de graphe vs exécution impatiente

Le code d'une Function peut être exécuté à la fois avec impatience et sous forme de graphique. Par défaut, Function exécute son code sous forme de graphe :

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([1 0 4 4 7], shape=(5,), dtype=int32)
tf.Tensor([3 6 3 0 6], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>

Pour vérifier que le graphique de votre Function effectue le même calcul que sa fonction Python équivalente, vous pouvez le faire exécuter avec empressement avec tf.config.run_functions_eagerly(True) . Il s'agit d'un commutateur qui désactive la capacité de Function à créer et exécuter des graphiques , au lieu d'exécuter le code normalement.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

Cependant, Function peut se comporter différemment sous une exécution graphique et impatiente. La fonction print Python est un exemple de la différence entre ces deux modes. Voyons ce qui se passe lorsque vous insérez une instruction print dans votre fonction et que vous l'appelez à plusieurs reprises.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

Observez ce qui est imprimé :

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

Le rendu est-il surprenant ? get_MSE n'a été imprimé qu'une seule fois même s'il a été appelé trois fois.

Pour expliquer, l'instruction print est exécutée lorsque Function exécute le code d'origine afin de créer le graphique dans un processus appelé "tracing" . Le traçage capture les opérations TensorFlow dans un graphique et l' print n'est pas capturée dans le graphique. Ce graphique est ensuite exécuté pour les trois appels sans jamais exécuter à nouveau le code Python .

Pour vérifier l'intégrité, désactivons l'exécution du graphique pour comparer :

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print est un effet secondaire de Python et il existe d'autres différences dont vous devez être conscient lors de la conversion d'une fonction en Function . Pour en savoir plus, consultez la section Limitations du guide Meilleures performances avec tf.function .

Exécution non stricte

L'exécution du graphe n'exécute que les opérations nécessaires pour produire les effets observables, notamment :

  • La valeur de retour de la fonction
  • Effets secondaires bien connus documentés tels que :

Ce comportement est généralement connu sous le nom d '«exécution non stricte» et diffère de l'exécution hâtive, qui parcourt toutes les opérations du programme, nécessaires ou non.

En particulier, la vérification des erreurs d'exécution ne compte pas comme un effet observable. Si une opération est ignorée parce qu'elle n'est pas nécessaire, elle ne peut générer aucune erreur d'exécution.

Dans l'exemple suivant, l'opération "inutile" tf.gather est ignorée lors de l'exécution du graphe, de sorte que l'erreur d'exécution InvalidArgumentError n'est pas déclenchée comme elle le serait dans une exécution hâtive. Ne vous fiez pas à une erreur générée lors de l'exécution d'un graphique.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

meilleures pratiques tf.function

Cela peut prendre un certain temps pour s'habituer au comportement de Function . Pour démarrer rapidement, les utilisateurs novices doivent s'amuser à décorer des fonctions de jouet avec @tf.function pour acquérir de l'expérience en passant de l'exécution impatiente à l'exécution graphique.

Concevoir pour tf.function peut être votre meilleur pari pour écrire des programmes TensorFlow compatibles avec les graphes. Voici quelques conseils:

  • Basculez entre l'exécution rapide et l'exécution graphique tôt et souvent avec tf.config.run_functions_eagerly pour déterminer si/quand les deux modes divergent.
  • Créez tf.Variable en dehors de la fonction Python et modifiez-les à l'intérieur. Il en va de même pour les objets qui utilisent tf.Variable , comme keras.layers , keras.Model s et tf.optimizers .
  • Évitez d'écrire des fonctions qui dépendent de variables Python externes , à l'exclusion tf.Variable s et Keras.
  • Préférez écrire des fonctions qui prennent des tenseurs et d'autres types de TensorFlow en entrée. Vous pouvez passer d'autres types d'objets mais attention !
  • Incluez autant de calculs que possible sous une tf.function pour maximiser le gain de performances. Par exemple, décorez toute une étape d'entraînement ou toute la boucle d'entraînement.

Voir l'accélération

tf.function améliore généralement les performances de votre code, mais l'accélération dépend du type de calcul que vous exécutez. Les petits calculs peuvent être dominés par la surcharge liée à l'appel d'un graphe. Vous pouvez mesurer la différence de performances comme suit :

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.5637862179974036
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6832536700021592

tf.function est couramment utilisé pour accélérer les boucles de formation, et vous pouvez en savoir plus à ce sujet dans Rédaction d'une boucle de formation à partir de zéro avec Keras.

Performances et compromis

Les graphes peuvent accélérer votre code, mais leur processus de création entraîne des frais généraux. Pour certaines fonctions, la création du graphe prend plus de temps que l'exécution du graphe. Cet investissement est généralement rapidement amorti grâce à l'amélioration des performances des exécutions ultérieures, mais il est important de savoir que les premières étapes de toute formation de modèle volumineux peuvent être plus lentes en raison du traçage.

Quelle que soit la taille de votre modèle, vous souhaitez éviter de tracer fréquemment. Le guide tf.function explique comment définir des spécifications d'entrée et utiliser des arguments de tenseur pour éviter de retracer. Si vous constatez que vous obtenez des performances anormalement médiocres, c'est une bonne idée de vérifier si vous retracez accidentellement.

A quand un suivi de Function ?

Pour savoir quand votre Function trace, ajoutez une instruction d' print à son code. En règle générale, Function exécute l'instruction print à chaque fois qu'il trace.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

Les nouveaux arguments Python déclenchent toujours la création d'un nouveau graphique, d'où le traçage supplémentaire.

Prochaines étapes

Vous pouvez en savoir plus sur tf.function sur la page de référence de l'API et en suivant le guide Meilleures performances avec tf.function .