在 TensorFlow 2 中,Eager Execution 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。
您可以使用 tf.function
将程序转换为计算图。这是一个转换工具,用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型,并且如果要使用 SavedModel
,则必须使用此工具。
本指南介绍 tf.function
的底层工作原理,让您形成概念化理解,从而有效地加以利用。
要点和建议包括:
- 先在 Eager 模式下调试,然后使用
@tf.function
进行装饰。 - 不依赖 Python 的副作用,如对象变异或列表追加。
tf.function
最适合处理 TensorFlow 运算;NumPy 和 Python 调用会转换为常量。
设置
# Update TensorFlow, as this notebook requires version 2.9 or later
!pip install -q -U tensorflow>=2.9.0
import tensorflow as tf
2022-12-14 22:33:48.348405: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:33:48.348501: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:33:48.348510: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
定义一个辅助函数来演示可能遇到的错误类型:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
基础知识
用法
您定义的 Function
(例如,通过应用 @tf.function
装饰器)就像核心 TensorFlow 运算:您可以在 Eager 模式下执行它,可以计算梯度,等等。
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
Function
中可以嵌套其他 Function
。
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
Function
的执行速度比 Eager 代码快,尤其是对于包含很多简单运算的计算图。但是,对于包含一些复杂运算(如卷积)的计算图,速度提升不会太明显。
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006609583000681596 Function conv: 0.006563433998962864 Note how there's not much difference in performance for convolutions
跟踪
本部分介绍了 Function
的幕后运作方式,包括未来可能会发生变化的实现细节。但是,当您了解跟踪的原因和时间后,就能够更轻松高效地使用 tf.function
!
什么是“跟踪”?
Function
在 TensorFlow 计算图中运行您的程序。但是,tf.Graph
不能代表您在 Eager TensorFlow 程序中编写的全部内容。例如,Python 支持多态,但是 tf.Graph
要求其输入具有指定的数据类型和维度。或者,您可能执行辅助任务,例如读取命令行参数、引发错误或使用更复杂的 Python 对象。这些内容均不能在 tf.Graph
中运行。
Function
通过将代码分为以下两个阶段填补了这一空缺:
第一阶段称为跟踪,在这一阶段中,
Function
会创建新的tf.Graph
。Python 代码可以正常运行,但是所有 TensorFlow 运算(例如添加两个张量)都会被推迟:它们会被tf.Graph
捕获而不运行。在第二阶段中,将运行包含第一阶段中推迟的全部内容的
tf.Graph
。此阶段比跟踪阶段快得多。
根据输入,Function
在调用时并非总会运行第一阶段。请参阅下方的跟踪规则以更好地了解其决定方式。跳过第一阶段并仅执行第二阶段,可以实现 TensorFlow 的高性能。
当 Function
决定跟踪时,在跟踪阶段完成后会立即运行第二阶段,因此调用 Function
会创建并运行 tf.Graph
。稍后,您将了解如何使用 get_concrete_function
来仅运行跟踪阶段。
当您将不同类型的参数传递给 Function
时,两个阶段都将运行:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
请注意,如果重复使用同一参数类型调用 Function
,TensorFlow 会跳过跟踪阶段并重用之前跟踪的计算图,因为后面的调用生成的计算图可能相同。
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
您可以使用 pretty_printed_concrete_signatures()
查看所有可用跟踪记录:
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
目前,您已经了解 tf.function
通过 TensorFlow 的计算图跟踪逻辑创建缓存的动态调度层。对于术语的含义,更具体的解释如下:
tf.Graph
与语言无关,是 TensorFlow 计算的原始可移植表示。ConcreteFunction
封装tf.Graph
。Function
管理ConcreteFunction
的缓存,并为输入选择正确的缓存。tf.function
包装 Python 函数,并返回一个Function
对象。- 跟踪会创建
tf.Graph
并将其封装在ConcreteFunction
中,也称为跟踪。
跟踪规则
被调用时,Function
使用每个参数的 tf.types.experimental.TraceType
将调用参数与现有的 ConcreteFunction
匹配。如果找到匹配的 ConcreteFunction
,则将调用分派给它。如果未找到匹配项,则跟踪新的 ConcreteFunction
。
如果找到多个匹配项,则会选择最具体的签名。匹配是通过子类型化完成的,就像 C++ 或 Java 中的普通函数调用一样。例如,TensorShape([1, 2])
是 TensorShape([None, None])
的子类型,因此可以将使用 TensorShape([1, 2])
对 tf.function 进行的调用分派到使用 TensorShape([None, None])
生成的 ConcreteFunction
。但是,如果具有 TensorShape([1, None])
的 ConcreteFunction
也存在,那么它将被优先考虑,因为它更具体。
TraceType
由输入参数确定,具体如下所示:
- 对于
Tensor
,类型由Tensor
的dtype
和shape
参数化;有秩形状是无秩形状的子类型;固定维度是未知维度的子类型 - 对于
Variable
,类型类似于Tensor
,但还包括变量的唯一资源 ID,这是正确连接控制依赖项所必需的 - 对于 Python 基元值,类型对应于值本身。例如,值为
3
的TraceType
是LiteralTraceType<3>
,而不是int
。 - 对于
list
和tuple
等 Python 有序容器,类型是通过其元素的类型来参数化的;例如,[1, 2]
的类型是ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>
,[2, 1]
的类型是ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>
,两者不同。 - 对于
dict
等 Python 映射,类型也是从相同的键到值类型而不是实际值的映射。例如,{1: 2, 3: 4}
的类型为MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>
。但是,与有序容器不同的是,{1: 2, 3: 4}
和{3: 4, 1: 2}
具有等价的类型。 - 对于实现
__tf_tracing_type__
方法的 Python 对象,类型为该方法返回的任何内容 - 对于任何其他 Python 对象,类型是通用的
TraceType
,它使用对象的 Python 相等性和散列进行匹配。(注:它依赖于对对象的弱引用,因此仅在对象处于范围内/未被删除时才有效。)
注:TraceType
基于 Function
输入参数,因此仅对全局变量和自由变量进行更改将不会创建新的跟踪记录。有关处理 Python 全局变量和自由变量的建议做法,请参阅本部分。
控制回溯
回溯即 Function
创建多个跟踪记录的过程,可以确保 TensorFlow 为每组输入生成正确的计算图。但是,跟踪非常消耗资源!如果 Function
为每一次调用都回溯新的计算图,您会发现代码的执行速度远不如不使用 tf.function
时快。
要控制跟踪行为,可以采用以下技巧:
将固定的 input_signature
传递给 tf.function
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
使用未知维度以获得灵活性
由于 TensorFlow 根据其形状匹配张量,因此,对于可变大小输入,使用 None
维度作为通配符可以让 Function
重复使用跟踪记录。对于每个批次,如果有不同长度的序列或不同大小的图像,则会出现可变大小输入(请参阅 Transformer 和 Deep Dream 教程了解示例)。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
传递张量而不是 Python 文字
通常,Python 参数用于控制超参数和计算图构造,例如 num_layers=10
、training=True
或 nonlinearity='relu'
。所以,如果 Python 参数改变,则有必要回溯计算图。
但是,Python 参数有可能并未用于控制计算图构造。在这些情况下,Python 值的改变可能触发非必要的回溯。例如,在此训练循环中,AutoGraph 会动态展开。尽管有多个跟踪,但生成的计算图实际上是相同的,所以没有必要进行回溯。
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
如果需要强制执行回溯,可以创建一个新的 Function
。单独的 Function
对象肯定不会共享跟踪记录。
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
使用跟踪协议
在可能的情况下,您应当首选将 Python 类型转换为 tf.experimental.ExtensionType
。此外,ExtensionType
的 TraceType
是与其关联的 tf.TypeSpec
。因此,如果需要,您只需重写默认的 tf.TypeSpec
即可控制 ExtensionType
的 Tracing Protocol
。请参阅扩展程序类型指南中的自定义 ExtensionType 的 TypeSpec部分以了解详情。
否则,要直接控制 Function
何时应针对特定 Python 类型进行重新跟踪,您可以自行为其实现 Tracing Protocol
。
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
return fruit_a.flavor + fruit_b.flavor
class Fruit:
flavor = tf.constant([0, 0])
class Apple(Fruit):
flavor = tf.constant([1, 2])
class Mango(Fruit):
flavor = tf.constant([3, 4])
# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out out of scope by then and deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.
class FruitTraceType(tf.types.experimental.TraceType):
def __init__(self, fruit_type):
self.fruit_type = fruit_type
def is_subtype_of(self, other):
return (type(other) is FruitTraceType and
self.fruit_type is other.fruit_type)
def most_specific_common_supertype(self, others):
return self if all(self == other for other in others) else None
def __eq__(self, other):
return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
def __hash__(self):
return hash(self.fruit_type)
class FruitWithTraceType:
def __tf_tracing_type__(self, context):
return FruitTraceType(type(self))
class AppleWithTraceType(FruitWithTraceType):
flavor = tf.constant([1, 2])
class MangoWithTraceType(FruitWithTraceType):
flavor = tf.constant([3, 4])
# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>
获取具体函数
每次跟踪函数时都会创建一个新的具体函数。您可以使用 get_concrete_function
直接获取具体函数。
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
打印 ConcreteFunction
会显示其输入参数(及类型)和输出类型的摘要。
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
您也可以直接检索具体函数的签名。
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
对不兼容的类型使用具体跟踪会引发错误
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_166 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_166]
您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中删除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1487, in _call_impl return self._call_with_flat_signature(args, kwargs, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1532, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
获取计算图
每个具体函数都是 tf.Graph
的可调用包装器。虽然一般不需要检索实际 tf.Graph
对象,不过,您可以从任何具体函数轻松获得实际对象。
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
调试
通常,在 Eager 模式下调试代码比在 tf.function
中简单。在使用 tf.function
进行装饰之前,进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 tf.config.run_functions_eagerly(True)
来全局停用和重新启用 tf.function
。
追溯仅在 tf.function
中出现的问题时,可参考下面的几点提示:
- 普通旧 Python
print
调用仅在跟踪期间执行,可以帮助您在(重新)跟踪函数时进行追溯。 tf.print
调用每次都会执行,可用于追溯执行过程中产生的中间值。- 利用
tf.debugging.enable_check_numerics
很容易追溯到 NaN 和 Inf 在何处创建。 pdb
(Python 调试器)可以帮助您理解跟踪的详细过程。(提醒:使用pdb
调试时,AutoGraph 会自动转换 Python 源代码。)
AutoGraph 转换
AutoGraph 是一个库,在 tf.function
中默认处于启用状态。它可以将 Python Eager 代码的子集转换为与计算图兼容的 TensorFlow 运算。这包括 if
、for
、while
等控制流。
tf.cond
和 tf.while_loop
等 TensorFlow 运算仍然可以运行,但是使用 Python 编写时,控制流通常更易于编写,代码也更易于理解。
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.682211161 0.396621943 0.451262951 0.643357158 0.87304759] [0.592955 0.37705484 0.422936589 0.567181051 0.702919185] [0.532017589 0.360147029 0.399401426 0.513286 0.606217444] [0.486921817 0.3453435 0.379436702 0.472501546 0.541458964] [0.451769888 0.332239449 0.362218171 0.4402183 0.49409157] [0.423352748 0.320531547 0.347166359 0.413825333 0.45745784] [0.399751157 0.309987456 0.333860159 0.391715884 0.428010017] [0.379735976 0.300425678 0.321985 0.372838497 0.40365687] [0.362478137 0.291702092 0.311300635 0.356472 0.383073539] [0.347394943 0.283700645 0.301619858 0.342102677 0.365373671] [0.334063202 0.276326627 0.292794317 0.329353303 0.349938482] [0.322167 0.269501895 0.284704626 0.31793955 0.336320966] [0.311465025 0.263161272 0.277253687 0.307642668 0.324188948] [0.301769286 0.257249981 0.270361394 0.298291 0.313289642] [0.292930901 0.251721531 0.263961077 0.289747804 0.303426832] [0.284830183 0.24653624 0.257996708 0.281902701 0.294445485] [0.277369589 0.241659909 0.252420813 0.274665147 0.286221236] [0.270468801 0.237062961 0.247192904 0.26796037 0.278653115] [0.264061 0.23271966 0.242278129 0.261725903 0.271658033] [0.25809 0.228607446 0.237646371 0.255909115 0.265166938] [0.252508163 0.224706501 0.23327139 0.250465214 0.259121925] [0.24727492 0.22099933 0.229130313 0.245355889 0.253474057] [0.242355347 0.217470333 0.225202918 0.240548164 0.248181522] [0.237719223 0.214105651 0.221471429 0.236013427 0.243208483] [0.233340293 0.210892901 0.21792005 0.231726721 0.23852399] [0.229195595 0.207821 0.214534715 0.227666199 0.234101087] [0.225264892 0.20487988 0.211302832 0.22381258 0.229916304] [0.221530348 0.202060521 0.208213195 0.220148876 0.225948915] [0.217976168 0.199354753 0.205255613 0.216659933 0.222180739] [0.214588255 0.196755111 0.20242089 0.213332266 0.218595564] [0.211354017 0.19425483 0.199700788 0.210153803 0.215179041] [0.208262146 0.191847727 0.197087735 0.207113713 0.211918324] [0.205302477 0.189528167 0.194574893 0.204202175 0.20880191] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.20246583, 0.18729094, 0.19215602, 0.2014104 , 0.20581943], dtype=float32)>
如果您有兴趣,可以检查 Autograph 生成的代码。
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1 ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
条件语句
AutoGraph 会将某些 if <condition>
语句转换为等效的 tf.cond
调用。如果 <condition>
是张量,则会执行这种替换,否则会将 if
语句作为 Python 条件语句执行。
Python 条件语句在跟踪时执行,因此会将该条件语句的一个分支添加到计算图。如果不使用 AutoGraph,当存在依赖于数据的控制流时,此跟踪计算图将无法选择替代分支。
tf.cond
跟踪并将条件的两个分支添加到计算图,在执行时动态选择分支。跟踪可能产生意外的副作用;请参阅 AutoGraph 跟踪作用以了解详情。
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
有关 AutoGraph 转换的 if 语句的其他限制,请参阅参考文档。
循环
AutoGraph 会将某些 for
和 while
语句转换为等效的 TensorFlow 循环运算,例如 tf.while_loop
。如果不转换,则会将 for
或 while
循环作为 Python 循环执行。
以下情形会执行这种替换:
for x in y
:如果y
是一个张量,则转换为tf.while_loop
。在特殊情况下,如果y
是tf.data.Dataset
,则会生成tf.data.Dataset
运算的组合。while <condition>
:如果<condition>
是张量,则转换为tf.while_loop
。
Python 循环在跟踪时执行,因而循环每迭代一次,都会将额外的运算添加到 tf.Graph
。
TensorFlow 循环会跟踪循环体,并在执行时动态选择迭代的运行次数。循环体仅在生成的 tf.Graph
中出现一次。
有关 AutoGraph 转换的 for
和 while
语句的其他限制,请参阅参考文档。
在 Python 数据上循环
一个常见陷阱是在 tf.function
中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行,因而循环每迭代一次,都会将模型的一个副本添加到 tf.Graph
。
如果要在 tf.function
中包装整个训练循环,最安全的方法是将数据包装为 tf.data.Dataset
,以便 AutoGraph 动态展开训练循环。
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
在数据集中封装 Python/Numpy 数据时,要注意 tf.data.Dataset.from_generator
与 tf.data.Dataset.from_tensors
。前者将数据保留在 Python 中,并通过 tf.py_function
获取,这可能会影响性能;后者将数据的副本捆绑成计算图中的一个大 tf.constant()
节点,这可能会消耗较多内存。
通过 TFRecordDataset
、CsvDataset
等从文件中读取数据是最高效的数据使用方式,因为这样 TensorFlow 就可以自行管理数据的异步加载和预提取,不必利用 Python。要了解详细信息,请参阅 tf.data
:构建 TensorFlow 输入流水线指南。
累加循环值
一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方法无法达到预期效果。要从动态展开循环累加结果,可以使用 tf.TensorArray
来实现。
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.02787352, 0.82700455, 0.9538677 , 0.9223 ], [0.9766793 , 1.3111272 , 0.9800459 , 1.8446014 ], [1.1777725 , 1.7788446 , 1.8246934 , 2.3526337 ]], [[0.07594228, 0.07718706, 0.08466768, 0.59425294], [0.47839463, 0.5704447 , 0.64917386, 0.84621584], [1.254829 , 1.3224411 , 1.6046455 , 1.6987513 ]]], dtype=float32)>
限制
TensorFlow Function
有意设计了一些限制,在将 Python 函数转换为 Function
时需加以注意。
执行 Python 副作用
副作用(如打印、附加到列表、改变全局变量)在 Function
内部可能会出现异常行为,有时会执行两次或完全无法执行。它们只会在您第一次使用一组输入调用 Function
时发生。之后,将重新执行跟踪的 tf.Graph
,而不执行 Python 代码。
一般经验法则是避免在逻辑中依赖 Python 副作用,而仅使用它们来调试跟踪记录。否则,TensorFlow API(例如 tf.data
、tf.print
、tf.summary
、tf.Variable.assign
和 tf.TensorArray
)是确保在每次调用时 TensorFlow 运行时都能执行您的代码的最佳方式。
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
如果希望在每次调用 Function
时都执行 Python 代码,tf.py_function
可以作为退出点。tf.py_function
的缺点是不可移植,性能不高,无法使用 SavedModel 保存并且在分布式(多 GPU、TPU)设置中效果不佳。另外,由于 tf.py_function
必须连接到计算图中,它会将所有输入/输出转换为张量。
更改 Python 全局变量和自由变量
更改 Python 全局变量和自由变量视为 Python 副作用,因此仅在跟踪期间发生。
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
有时很难注意到意外行为。在下面的示例中,counter
旨在保护变量的增量。然而,由于它是一个 Python 整数而不是 TensorFlow 对象,它的值在第一次跟踪期间被捕获。使用 tf.function
时,assign_add
将被无条件记录在底层计算图中。因此,每次调用 tf.function
时 v
都会增加 1。当使用 Python 副作用(示例中的 counter
)确定要运行的运算(示例中的 assign_add
)时,此问题在尝试使用 tf.function
装饰器将其计算图模式 Tensorflow 代码迁移到 Tensorflow 2 的用户中十分常见。通常,用户只有在看到可疑的数值结果或明显低于预期的性能(例如,如果受保护运算的开销非常大)后才会意识到这一点。
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
实现预期行为的一种解决方法是使用 tf.init_scope
将运算提升到函数计算图以外。这样可以确保变量增量在跟踪期间只执行一次。应当注意的是,init_scope
还有其他副作用,包括清除控制流和梯度带。有时 init_scope
的使用会变得过于复杂而无法实际管理。
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
总之,根据经验,您应避免改变整数或容器(如位于 Function
外部的列表)等 Python 对象,而应使用参数和 TF 对象。例如,在循环中累加值部分中提供了一个如何实现类列表运算的示例。
在某些情况下,如果为 tf.Variable
,则您可以捕获和处理状态。这是通过重复调用相同的 ConcreteFunction
来更新 Keras 模型权重的方式。
使用 Python 迭代器和生成器
很多 Python 功能(如生成器和迭代器)依赖 Python 运行时来跟踪状态。通常,虽然这些构造在 Eager 模式下可以正常工作,但它们是 Python 副作用的示例,因此仅在跟踪期间发生。
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
就像 TensorFlow 具有用于列表构造的专用 tf.TensorArray
一样,它也具有用于迭代构造的专用 tf.data.Iterator
。有关概述,请参阅 AutoGraph 转换部分。此外,tf.data
API 也可帮助实现生成器模式:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
tf.function 的所有输出都必须是返回值
除了 tf.Variable
外,一个 tf.function 必须返回其所有输出。尝试直接从函数访问任何张量而不遍历返回值会导致“泄漏”。
例如,下面的函数通过 Python 全局变量 x
“泄漏”张量 a
:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
即使同时返回泄漏的值时也是如此:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information. <tf.Tensor 'add:0' shape=() dtype=int32> was defined here: File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module> app.launch_new_instance() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 992, in launch_instance app.start() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start self.io_loop.start() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start self.asyncio_loop.run_forever() File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever self._run_once() File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once handle._run() File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run self._context.run(self._callback, *self._args) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue await self.process_one() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one await dispatch(*args) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell await result File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request reply_content = await reply_content File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute res = shell.run_cell( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 531, in run_cell return super().run_cell(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell result = self._run_cell( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell return runner(coro) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner coro.send(None) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async has_raised = await self.run_ast_nodes(code_ast.body, cell_name, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes if await self.run_code(code, result, async_=asy): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "/tmpfs/tmp/ipykernel_176735/566849597.py", line 7, in <module> correct_a = leaky_function(tf.constant(1)) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 880, in __call__ result = self._call(*args, **kwds) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 928, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize self._variable_creation_fn # pylint: disable=protected-access File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected concrete_function, _ = self._maybe_define_concrete_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function return self._maybe_define_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function concrete_function = self._create_concrete_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function func_graph_module.func_graph_from_py_func( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn out = weak_wrapped_fn().__wrapped__(*args, **kwds) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler return autograph.converted_call( File "/tmpfs/tmp/ipykernel_176735/566849597.py", line 4, in leaky_function x = a + 1 # Bad - leaks local tensor File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1407, in binary_op_wrapper return func(x, y, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1176, in op_dispatch_handler return dispatch_target(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1757, in _add_dispatch return gen_math_ops.add_v2(x, y, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 475, in add_v2 _, _, _op, _outputs = _op_def_library._apply_op_helper( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 795, in _apply_op_helper op = g._create_op_internal(op_type_name, inputs, dtypes=None, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 749, in _create_op_internal return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 3798, in _create_op_internal ret = Operation( The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139917558711056), which is out of scope.
通常,当您使用 Python 语句或数据结构时,会发生此类泄漏。除了泄漏不可访问的张量之外,此类语句也可能是错误的,因为它们被视为 Python 副作用,而且不能保证在每次函数调用时都执行。
泄漏局部张量的常见方法还包括改变外部 Python 集合或对象:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
不支持递归 tf.functions
不支持递归 Function
,它们可能导致无限循环。例如:
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_176735/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__ return _abc_instancecheck(cls, instance) File "/usr/lib/python3.9/abc.py", line 123, in __subclasscheck__ return _abc_subclasscheck(cls, subclass) RecursionError: maximum recursion depth exceeded while calling a Python object
即使递归 Function
看似有效,Python 函数也会被多次跟踪,并且可能会对性能产生影响。例如:
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
已知问题
如果您的 Function
评估不正确,则这些计划于将来得到修复的已知问题可能可以解释该问题。
取决于 Python 全局变量和自由变量
当使用 Python 参数的新值进行调用时,Function
会创建新的 ConcreteFunction
。但是,对于该 Function
的 Python 闭包、全局变量或非局部变量,则不会创建。如果它们的值在调用 Function
之间发生变化,则 Function
仍将使用其在跟踪时所具有的值。这与常规 Python 函数的工作方式不同。
因此,您应采用使用参数的函数式编程风格而非闭合外部名称。
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
更新全局值的另一种方法是使其成为 tf.Variable
并改用 Variable.assign
方法。
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
取决于 Python 对象
将 Python 对象作为参数传递给 tf.function
的建议存在许多已知问题,预计会在以后得到解决。通常,如果您使用 Python 基元或兼容 tf.nest
的结构作为参数,或将对象的不同实例传递给 Function
,则可以依赖稳定的跟踪。但是,如果您传递同一对象并仅更改其特性时,Function
将不会创建新的跟踪记录。
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
如果使用相同的 Function
评估模型的更新实例,那么更新后的模型与原始模型将具有相同的缓存键,所以这种做法并不合理。
因此,建议您编写 Function
以避免依赖于可变对象特性,或者创建新对象。
如果这不可行,则一种解决方法是,每次修改对象时都创建新的 Function
以强制回溯:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
回溯可能十分耗费资源,您可以使用 tf.Variable
作为对象特性,可以对其进行改变(但非更改,请注意!) 以在无需回溯的情况下实现相似效果。
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
创建 tf.Variables
Function
仅支持在第一次调用时创建一次,并且在后续函数调用中重复使用的单例 tf.Variable
。下面的代码段会在每个函数调用中创建一个新的 tf.Variable
,这会导致 ValueError
异常。
示例:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmpfs/tmp/ipykernel_176735/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
用于解决这种限制的常见模式是从 Python None 值开始,随后,在值为 None 时,有条件地创建 tf.Variable
:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
与多个 Keras 优化器一起使用
将多个 Keras 优化器与 tf.function
一起使用时,您可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.
。发生此错误的原因是优化器在首次应用梯度时会在内部创建 tf.Variables
。
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_176735/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_176735/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmpfs/tmp/ipykernel_176735/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients ** return super().apply_gradients(grads_and_vars, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 621, in apply_gradients self.build(trainable_variables) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py", line 139, in build self.add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1072, in add_variable_from_reference return super().add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 496, in add_variable_from_reference variable = tf.Variable( ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
如果您需要在训练期间更改优化器,一种解决方法是为每个优化器创建一个新的 Function
,直接调用 ConcreteFunction
。
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
与多个 Keras 模型一起使用
将不同的模型实例传递给同一 Function
时,您也可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.
。
发生此错误的原因是 Keras 模型(未定义其输入形状)和 Keras 层会在首次调用时创建 tf.Variables
。您可能正在尝试在已调用的 Function
中初始化这些变量。为避免此错误,请在训练模型之前尝试调用 model.build(input_shape)
以初始化所有权重。
延伸阅读
要了解如何导出和加载 Function
,请参阅 SavedModel 指南。要详细了解跟踪后执行的计算图优化,请参阅 Grappler 指南。要了解如何优化数据流水线和剖析模型性能,请参阅 Profiler 指南。