Veja no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub | Baixar caderno |
No TensorFlow 2, a execução antecipada é ativada por padrão. A interface do usuário é intuitiva e flexível (executar operações pontuais é muito mais fácil e rápido), mas isso pode prejudicar o desempenho e a capacidade de implantação.
Você pode usar tf.function
para fazer gráficos de seus programas. É uma ferramenta de transformação que cria gráficos de fluxo de dados independentes de Python a partir de seu código Python. Isso ajudará você a criar modelos portáteis e de alto desempenho, e é necessário usar SavedModel
.
Este guia irá ajudá-lo a conceituar como o tf.function
funciona nos bastidores, para que você possa usá-lo de forma eficaz.
As principais dicas e recomendações são:
- Depure no modo ansioso e depois decore com
@tf.function
. - Não confie nos efeitos colaterais do Python, como mutação de objeto ou anexos de lista.
-
tf.function
funciona melhor com operações do TensorFlow; As chamadas NumPy e Python são convertidas em constantes.
Configurar
import tensorflow as tf
Defina uma função auxiliar para demonstrar os tipos de erros que você pode encontrar:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
Fundamentos
Uso
Uma Function
que você define (por exemplo, aplicando o decorador @tf.function
) é como uma operação principal do TensorFlow: você pode executá-la rapidamente; você pode calcular gradientes; e assim por diante.
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
Você pode usar Function
dentro de outras Function
.
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
Function
s pode ser mais rápida que o código ansioso, especialmente para gráficos com muitas operações pequenas. Mas para gráficos com algumas operações caras (como convoluções), você pode não ver muita aceleração.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735 Function conv: 0.005791576000774512 Note how there's not much difference in performance for convolutions
Rastreamento
Esta seção expõe como o Function
funciona nos bastidores, incluindo detalhes de implementação que podem mudar no futuro . No entanto, uma vez que você entenda por que e quando o rastreamento acontece, é muito mais fácil usar o tf.function
efetivamente!
O que é "rastreamento"?
Uma Function
executa seu programa em um gráfico do TensorFlow . No entanto, um tf.Graph
não pode representar todas as coisas que você escreveria em um programa TensorFlow ansioso. Por exemplo, Python suporta polimorfismo, mas tf.Graph
requer que suas entradas tenham um tipo de dados e uma dimensão especificados. Ou você pode realizar tarefas secundárias como ler argumentos de linha de comando, gerar um erro ou trabalhar com um objeto Python mais complexo; nenhuma dessas coisas pode ser executada em um tf.Graph
.
Function
preenche essa lacuna separando seu código em dois estágios:
1) Na primeira etapa, chamada de " tracing ", Function
cria um novo tf.Graph
. O código Python é executado normalmente, mas todas as operações do TensorFlow (como adicionar dois tensores) são adiadas : elas são capturadas pelo tf.Graph
e não são executadas.
2) Na segunda etapa, é executado um tf.Graph
que contém tudo o que foi adiado na primeira etapa. Este estágio é muito mais rápido que o estágio de rastreamento.
Dependendo de suas entradas, Function
nem sempre executará o primeiro estágio quando for chamada. Consulte "Regras de rastreamento" abaixo para ter uma noção melhor de como ele faz essa determinação. Ignorar o primeiro estágio e executar apenas o segundo estágio é o que oferece o alto desempenho do TensorFlow.
Quando Function
decide rastrear, o estágio de rastreamento é imediatamente seguido pelo segundo estágio, portanto, chamar a Function
cria e executa o tf.Graph
. Mais tarde, você verá como pode executar apenas o estágio de rastreamento com get_concrete_function
.
Quando você passa argumentos de tipos diferentes para um Function
, ambos os estágios são executados:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
Observe que, se você chamar repetidamente uma Function
com o mesmo tipo de argumento, o TensorFlow pulará o estágio de rastreamento e reutilizará um gráfico rastreado anteriormente, pois o gráfico gerado seria idêntico.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
Você pode usar pretty_printed_concrete_signatures()
para ver todos os traços disponíveis:
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Até agora, você viu que tf.function
cria uma camada de despacho dinâmica em cache sobre a lógica de rastreamento de gráfico do TensorFlow. Para ser mais específico sobre a terminologia:
- Um
tf.Graph
é a representação bruta, independente de linguagem e portátil de uma computação do TensorFlow. - Um
ConcreteFunction
envolve umtf.Graph
. - Uma
Function
gerencia um cache deConcreteFunction
se escolhe o caminho certo para suas entradas. -
tf.function
envolve uma função Python, retornando um objetoFunction
. - O rastreamento cria um
tf.Graph
e o envolve em umConcreteFunction
, também conhecido como rastreamento.
Regras de rastreamento
Uma Function
determina se deve ser reutilizada uma ConcreteFunction
rastreada calculando uma chave de cache dos argumentos e kwargs de uma entrada. Uma chave de cache é uma chave que identifica uma ConcreteFunction
com base nos argumentos e kwargs de entrada da chamada da Function
, de acordo com as seguintes regras (que podem mudar):
- A chave gerada para um
tf.Tensor
é sua forma e dtype. - A chave gerada para um
tf.Variable
é um id de variável único. - A chave gerada para uma primitiva Python (como
int
,float
,str
) é seu valor. - A chave gerada para
dict
s aninhados,list
s,tuple
s,namedtuple
s eattr
s é a tupla achatada de chaves-folha (vejanest.flatten
). (Como resultado desse achatamento, chamar uma função concreta com uma estrutura de aninhamento diferente daquela usada durante o rastreamento resultará em um TypeError). - Para todos os outros tipos de Python, a chave é exclusiva do objeto. Dessa forma, uma função ou método é rastreado independentemente para cada instância com a qual é chamado.
Controlando o retraçamento
O retraçamento, que é quando sua Function
cria mais de um rastreamento, ajuda a garantir que o TensorFlow gere gráficos corretos para cada conjunto de entradas. No entanto, o rastreamento é uma operação cara! Se sua Function
refazer um novo gráfico para cada chamada, você descobrirá que seu código executa mais lentamente do que se você não usasse tf.function
.
Para controlar o comportamento de rastreamento, você pode usar as seguintes técnicas:
- Especifique
input_signature
emtf.function
para limitar o rastreamento.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Especifique uma dimensão [None] em
tf.TensorSpec
para permitir flexibilidade na reutilização de rastreamento.Como o TensorFlow corresponde aos tensores com base em sua forma, usar uma dimensão
None
como curinga permitirá queFunction
s reutilize traços para entrada de tamanho variável. A entrada de tamanho variável pode ocorrer se você tiver sequências de comprimento diferente ou imagens de tamanhos diferentes para cada lote (consulte os tutoriais Transformer e Deep Dream , por exemplo).
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Transmita argumentos do Python para tensores para reduzir o retracing.
Frequentemente, os argumentos do Python são usados para controlar hiperparâmetros e construções de gráficos - por exemplo,
num_layers=10
outraining=True
ounonlinearity='relu'
. Portanto, se o argumento do Python mudar, faz sentido que você tenha que refazer o gráfico.No entanto, é possível que um argumento do Python não esteja sendo usado para controlar a construção do gráfico. Nesses casos, uma alteração no valor do Python pode desencadear um retraçamento desnecessário. Veja, por exemplo, este loop de treinamento, que o AutoGraph desenrolará dinamicamente. Apesar dos vários rastreamentos, o gráfico gerado é realmente idêntico, portanto, o retraçamento é desnecessário.
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
Se você precisar forçar o retracing, crie uma nova Function
. Objetos Function
separados são garantidos para não compartilhar rastreamentos.
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
Obtenção de funções concretas
Cada vez que uma função é traçada, uma nova função concreta é criada. Você pode obter diretamente uma função concreta, usando get_concrete_function
.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
A impressão de uma ConcreteFunction
exibe um resumo de seus argumentos de entrada (com tipos) e seu tipo de saída.
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Você também pode recuperar diretamente a assinatura de uma função concreta.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
Usar um rastreamento concreto com tipos incompatíveis gerará um erro
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
Você pode notar que os argumentos do Python recebem tratamento especial na assinatura de entrada de uma função concreta. Antes do TensorFlow 2.3, os argumentos do Python eram simplesmente removidos da assinatura da função concreta. A partir do TensorFlow 2.3, os argumentos do Python permanecem na assinatura, mas são restritos ao valor definido durante o rastreamento.
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl cancellation_manager) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
Obtendo gráficos
Cada função concreta é um wrapper que pode ser chamado em torno de um tf.Graph
. Embora recuperar o objeto tf.Graph
real não seja algo que você normalmente precise fazer, você pode obtê-lo facilmente de qualquer função concreta.
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
Depuração
Em geral, depurar código é mais fácil no modo ansioso do que dentro de tf.function
. Você deve garantir que seu código seja executado sem erros no modo ansioso antes de decorar com tf.function
. Para auxiliar no processo de depuração, você pode chamar tf.config.run_functions_eagerly(True)
para desabilitar e reativar globalmente tf.function
.
Ao rastrear problemas que aparecem apenas em tf.function
, aqui estão algumas dicas:
- Chamadas de
print
Python antigas simples são executadas apenas durante o rastreamento, ajudando você a rastrear quando sua função é (re)rastreada. - As chamadas
tf.print
serão executadas todas as vezes e podem ajudá-lo a rastrear valores intermediários durante a execução. -
tf.debugging.enable_check_numerics
é uma maneira fácil de rastrear onde NaNs e Inf são criados. -
pdb
(o depurador Python ) pode ajudá-lo a entender o que está acontecendo durante o rastreamento. (Aviso: opdb
o colocará no código-fonte transformado em AutoGraph.)
Transformações do AutoGraph
AutoGraph é uma biblioteca que está ativada por padrão em tf.function
e transforma um subconjunto de código Python em operações TensorFlow compatíveis com gráficos. Isso inclui fluxo de controle como if
, for
, while
.
Operações do TensorFlow como tf.cond
e tf.while_loop
continuam funcionando, mas o fluxo de controle geralmente é mais fácil de escrever e entender quando escrito em Python.
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753] [0.582645297 0.613145649 0.619306684 0.319202513 0.182036072] [0.524585426 0.546337605 0.550645113 0.308785647 0.18005164] [0.481231302 0.497770309 0.501003504 0.299331933 0.178130865] [0.447229207 0.460361809 0.462906033 0.290701121 0.176270396] [0.419618756 0.430379033 0.432449728 0.282779962 0.174467146] [0.396609187 0.405638 0.407366514 0.275476 0.172718227] [0.377043903 0.384762734 0.386234313 0.268712848 0.17102097] [0.360137492 0.366836458 0.368109286 0.262426734 0.169372901] [0.345335096 0.351221472 0.352336824 0.256563932 0.167771652] [0.332231969 0.337458342 0.338446289 0.251078814 0.166215062] [0.320524871 0.325206399 0.326089561 0.24593246 0.164701089] [0.309981436 0.314206958 0.31500268 0.241091311 0.163227797] [0.300420195 0.304259449 0.304981351 0.236526251 0.161793426] [0.291697085 0.295205742 0.295864582 0.232211992 0.160396278] [0.283696055 0.286919087 0.287523568 0.228126258 0.159034774] [0.276322395 0.279296666 0.27985391 0.224249557 0.157707423] [0.269497961 0.272254 0.272769839 0.220564634 0.15641281] [0.263157606 0.265720904 0.266200244 0.21705614 0.155149609] [0.257246554 0.259638608 0.260085613 0.213710397 0.153916568] [0.251718313 0.25395745 0.254375577 0.210515186 0.152712509] [0.246533215 0.248635098 0.249027327 0.207459539 0.151536316] [0.241657034 0.243635193 0.244004101 0.204533577 0.15038693] [0.237060249 0.238926381 0.239274174 0.201728329 0.149263337] [0.232717097 0.234481394 0.234810054 0.199035719 0.148164615] [0.228605017 0.230276451 0.230587661 0.196448416 0.147089839] [0.224704206 0.226290658 0.22658591 0.193959698 0.14603813] [0.220997125 0.222505584 0.222786173 0.191563457 0.145008713] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077], dtype=float32)>
Se você estiver curioso, pode inspecionar o código gerado pelo autógrafo.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1) ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
Condicionais
O AutoGraph converterá algumas instruções if <condition>
nas chamadas tf.cond
equivalentes. Esta substituição é feita se <condition>
for um tensor. Caso contrário, a instrução if
é executada como uma condicional do Python.
Uma condicional do Python é executada durante o rastreamento, portanto, exatamente uma ramificação da condicional será adicionada ao gráfico. Sem o AutoGraph, esse gráfico rastreado não conseguiria fazer a ramificação alternativa se houver fluxo de controle dependente de dados.
tf.cond
rastreia e adiciona ambas as ramificações da condicional ao gráfico, selecionando dinamicamente uma ramificação em tempo de execução. O rastreamento pode ter efeitos colaterais indesejados; confira os efeitos de rastreamento do AutoGraph para obter mais informações.
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
Consulte a documentação de referência para restrições adicionais sobre instruções if convertidas pelo AutoGraph.
rotações
O AutoGraph converterá algumas instruções for
e while
nas operações de loop do TensorFlow equivalentes, como tf.while_loop
. Se não for convertido, o loop for
ou while
será executado como um loop Python.
Essa substituição é feita nas seguintes situações:
-
for x in y
: sey
for um tensor, converta paratf.while_loop
. No caso especial em quey
é umtf.data.Dataset
, uma combinação de operaçõestf.data.Dataset
é gerada. -
while <condition>
: se<condition>
for um tensor, converta paratf.while_loop
.
Um loop Python é executado durante o rastreamento, adicionando operações adicionais ao tf.Graph
para cada iteração do loop.
Um loop do TensorFlow rastreia o corpo do loop e seleciona dinamicamente quantas iterações devem ser executadas em tempo de execução. O corpo do loop aparece apenas uma vez no tf.Graph
gerado.
Consulte a documentação de referência para restrições adicionais sobre instruções for
e while
convertidas pelo AutoGraph.
Fazendo um loop sobre os dados do Python
Uma armadilha comum é fazer um loop sobre os dados Python/NumPy em um tf.function
. Este loop será executado durante o processo de rastreamento, adicionando uma cópia do seu modelo ao tf.Graph
para cada iteração do loop.
Se você quiser agrupar todo o loop de treinamento em tf.function
, a maneira mais segura de fazer isso é agrupar seus dados como um tf.data.Dataset
para que o AutoGraph desenrole dinamicamente o loop de treinamento.
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
Ao agrupar dados Python/NumPy em um conjunto de dados, lembre-se de tf.data.Dataset.from_generator
versus tf.data.Dataset.from_tensors
. O primeiro irá manter os dados em Python e buscá-los via tf.py_function
que pode ter implicações de desempenho, enquanto o último irá empacotar uma cópia dos dados como um grande nó tf.constant()
no gráfico, o que pode ter implicações de memória.
Ler dados de arquivos por meio de TFRecordDataset
, CsvDataset
, etc. é a maneira mais eficaz de consumir dados, pois o próprio TensorFlow pode gerenciar o carregamento assíncrono e a pré-busca de dados, sem precisar envolver o Python. Para saber mais, consulte o tf.data
: Criar pipelines de entrada do TensorFlow .
Acumulando valores em um loop
Um padrão comum é acumular valores intermediários de um loop. Normalmente, isso é feito anexando a uma lista Python ou adicionando entradas a um dicionário Python. No entanto, como esses são efeitos colaterais do Python, eles não funcionarão conforme o esperado em um loop desenrolado dinamicamente. Use tf.TensorArray
para acumular resultados de um loop desenrolado dinamicamente.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216], [0.44997275, 1.9107027 , 1.0716251 , 0.717237 ], [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]], [[0.04946005, 0.69127274, 0.56848884, 0.22406638], [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ], [0.9178308 , 1.320889 , 0.989761 , 2.0120025 ]]], dtype=float32)>
Limitações
A Function
TensorFlow tem algumas limitações por design que você deve estar ciente ao converter uma função Python em uma Function
.
Executando efeitos colaterais do Python
Os efeitos colaterais, como imprimir, anexar a listas e alterar globais, podem se comportar inesperadamente dentro de uma Function
, às vezes executando duas vezes ou não todas. Eles só acontecem na primeira vez que você chama uma Function
com um conjunto de entradas. Em seguida, o tf.Graph
rastreado é executado novamente, sem executar o código Python.
A regra geral é evitar depender dos efeitos colaterais do Python em sua lógica e usá-los apenas para depurar seus rastreamentos. Caso contrário, as APIs do TensorFlow como tf.data
, tf.print
, tf.summary
, tf.Variable.assign
e tf.TensorArray
são a melhor maneira de garantir que seu código seja executado pelo tempo de execução do TensorFlow a cada chamada.
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
Se você deseja executar o código Python durante cada chamada de uma Function
, tf.py_function
é uma escotilha de saída. A desvantagem do tf.py_function
é que ele não é portátil ou particularmente performático, não pode ser salvo com SavedModel e não funciona bem em configurações distribuídas (multi-GPU, TPU). Além disso, como tf.py_function
precisa ser conectado ao gráfico, ele converte todas as entradas/saídas em tensores.
Alterando variáveis globais e livres do Python
Alterar as variáveis globais e livres do Python conta como um efeito colateral do Python, portanto, isso só acontece durante o rastreamento.
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
Às vezes, comportamentos inesperados são muito difíceis de perceber. No exemplo abaixo, o counter
destina-se a salvaguardar o incremento de uma variável. No entanto, por ser um inteiro python e não um objeto TensorFlow, seu valor é capturado durante o primeiro rastreamento. Quando a tf.function
é usada, o assign_add
será registrado incondicionalmente no gráfico subjacente. Portanto, v
aumentará em 1, toda vez que a tf.function
. for chamada. Esse problema é comum entre os usuários que tentam migrar o código do Tensorflow no modo Grpah para o Tensorflow 2 usando decoradores tf.function
, quando os efeitos colaterais do python (o counter
no exemplo) são usados para determinar quais operações executar ( assign_add
no exemplo ). Normalmente, os usuários percebem isso somente depois de ver resultados numéricos suspeitos ou desempenho significativamente inferior ao esperado (por exemplo, se a operação protegida for muito cara).
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
Uma solução alternativa para atingir o comportamento esperado é usar tf.init_scope
para levantar as operações fora do gráfico da função. Isso garante que o incremento de variável seja feito apenas uma vez durante o tempo de rastreamento. Deve-se notar que init_scope
tem outros efeitos colaterais, incluindo fluxo de controle limpo e fita de gradiente. Às vezes, o uso de init_scope
pode se tornar muito complexo para ser gerenciado de forma realista.
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
Em resumo, como regra geral, você deve evitar a mutação de objetos python, como inteiros ou contêineres, como listas que vivem fora da Function
. Em vez disso, use argumentos e objetos TF. Por exemplo, a seção "Acumulando valores em um loop" tem um exemplo de como as operações do tipo lista podem ser implementadas.
Você pode, em alguns casos, capturar e manipular o estado se for um tf.Variable
. É assim que os pesos dos modelos Keras são atualizados com chamadas repetidas para o mesmo ConcreteFunction
.
Usando iteradores e geradores Python
Muitos recursos do Python, como geradores e iteradores, dependem do tempo de execução do Python para acompanhar o estado. Em geral, embora essas construções funcionem conforme o esperado no modo ansioso, elas são exemplos de efeitos colaterais do Python e, portanto, só acontecem durante o rastreamento.
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
Assim como o TensorFlow tem um tf.TensorArray
especializado para construções de lista, ele tem um tf.data.Iterator
especializado para construções de iteração. Consulte a seção sobre transformações AutoGraph para obter uma visão geral. Além disso, a API tf.data
pode ajudar a implementar padrões geradores:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
Todas as saídas de uma função tf devem ser valores de retorno
Com exceção de tf.Variable
s, uma tf.function deve retornar todas as suas saídas. A tentativa de acessar diretamente qualquer tensor de uma função sem passar por valores de retorno causa "vazamentos".
Por exemplo, a função abaixo "vaza" o tensor a
através do x
global do Python:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
Isso é verdade mesmo se o valor vazado também for retornado:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: Originated from a graph execution error. The graph execution error is detected at a node built at (most recent call last): >>> File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main >>> File /usr/lib/python3.7/runpy.py, line 85, in _run_code >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start >>> File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever >>> File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once >>> File /usr/lib/python3.7/asyncio/events.py, line 88, in _run >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code >>> File /tmp/ipykernel_26244/566849597.py, line 7, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__ >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler >>> File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2 >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__ Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
Normalmente, vazamentos como esses ocorrem quando você usa instruções Python ou estruturas de dados. Além de vazar tensores inacessíveis, essas instruções também provavelmente estão erradas porque contam como efeitos colaterais do Python e não têm garantia de execução em todas as chamadas de função.
Maneiras comuns de vazar tensores locais também incluem a mutação de uma coleção externa do Python ou de um objeto:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
tf.functions recursivas não são suportadas
Function
recursivas s não são suportadas e podem causar loops infinitos. Por exemplo,
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn * if n > 0: File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
Mesmo que uma Function
recursiva pareça funcionar, a função python será rastreada várias vezes e pode ter implicações no desempenho. Por exemplo,
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
Problemas conhecidos
Se sua Function
não estiver avaliando corretamente, o erro pode ser explicado por esses problemas conhecidos que estão planejados para serem corrigidos no futuro.
Dependendo das variáveis globais e livres do Python
Function
cria uma nova ConcreteFunction
quando chamada com um novo valor de um argumento Python. No entanto, ele não faz isso para o encerramento do Python, globais ou não locais dessa Function
. Se o valor deles mudar entre as chamadas para a Function
, a Function
ainda usará os valores que tinha quando foi rastreada. Isso é diferente de como as funções regulares do Python funcionam.
Por esse motivo, você deve seguir um estilo de programação funcional que usa argumentos em vez de fechar sobre nomes externos.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
Outra maneira de atualizar um valor global é torná-lo um tf.Variable
e usar o método Variable.assign
em vez disso.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
Dependendo de objetos Python
A recomendação de passar objetos Python como argumentos para tf.function
tem vários problemas conhecidos, que devem ser corrigidos no futuro. Em geral, você pode confiar no rastreamento consistente se usar uma estrutura primitiva do Python ou compatível com tf.nest
como um argumento ou passar uma instância diferente de um objeto para uma Function
. No entanto, Function
não criará um novo trace quando você passar o mesmo objeto e apenas alterará seus atributos .
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
Usar a mesma Function
para avaliar a instância atualizada do modelo será problemático, pois o modelo atualizado tem a mesma chave de cache que o modelo original.
Por esse motivo, é recomendável escrever sua Function
para evitar depender de atributos de objetos mutáveis ou criar novos objetos.
Se isso não for possível, uma solução alternativa é criar novas Function
cada vez que você modificar seu objeto para forçar o retracing:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Como o retracing pode ser caro , você pode usar tf.Variable
s como atributos do objeto, que podem ser modificados (mas não alterados, cuidado!) para um efeito semelhante sem precisar de retrace.
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Criando tf.Variables
Function
suporta apenas tf.Variable
singleton criadas uma vez na primeira chamada e reutilizadas em chamadas de função subsequentes. O trecho de código abaixo criaria um novo tf.Variable
em cada chamada de função, o que resulta em uma exceção ValueError
.
Exemplo:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmp/ipykernel_26244/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Um padrão comum usado para contornar essa limitação é começar com um valor Python None e, em seguida, criar condicionalmente a tf.Variable
se o valor for None:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
Usando com vários otimizadores Keras
Você pode encontrar ValueError: tf.function only supports singleton tf.Variables created on the first call.
ao usar mais de um otimizador Keras com um tf.function
. Esse erro ocorre porque os otimizadores criam internamente tf.Variables
quando aplicam gradientes pela primeira vez.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients ** self._create_all_weights(var_list) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights _ = self.iterations File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__ return super(OptimizerV2, self).__getattribute__(name) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight aggregation=aggregation) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable shape=variable_shape if variable_shape else None) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Se você precisar alterar o otimizador durante o treinamento, uma solução alternativa é criar uma nova Function
para cada otimizador, chamando a ConcreteFunction
diretamente.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
Usando com vários modelos Keras
Você também pode encontrar ValueError: tf.function only supports singleton tf.Variables created on the first call.
ao passar diferentes instâncias de modelo para a mesma Function
.
Esse erro ocorre porque os modelos Keras (que não têm sua forma de entrada definida ) e as camadas Keras criam tf.Variables
s quando são chamadas pela primeira vez. Você pode estar tentando inicializar essas variáveis dentro de uma Function
, que já foi chamada. Para evitar esse erro, tente chamar model.build(input_shape)
para inicializar todos os pesos antes de treinar o modelo.
Leitura adicional
Para saber como exportar e carregar uma Function
, consulte o guia SavedModel . Para saber mais sobre otimizações de gráficos que são executadas após o rastreamento, consulte o guia do Grappler . Para saber como otimizar seu pipeline de dados e criar o perfil de seu modelo, consulte o guia do Profiler .