{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "headers" }, "source": [ "Project: /overview/_project.yaml\n", "Book: /overview/_book.yaml\n", "\n", "\n", "\n", "\n", "\n", "\n", "{% comment %}\n", "The source of truth file can be found [here]: http://google3/zz\n", "{% endcomment %}" ] }, { "cell_type": "markdown", "metadata": { "id": "metadata" }, "source": [ "
View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook\n",
"\n",
"`tf.function` converts a program into a graph: it is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This helps you create performant and portable models, and it is required in order to use `SavedModel`.\n",
"\n",
"This guide explains how `tf.function` works under the hood, so you can form a conceptual understanding and use it effectively.\n",
"\n",
"The main takeaways and recommendations are:\n",
"\n",
"- Debug in eager mode first, then decorate with `@tf.function`.\n",
"- Don't rely on Python side effects like object mutation or list appends.\n",
"- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SjvqpgepHJPd"
},
"source": [
"## 设置"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:45.600636Z",
"iopub.status.busy": "2022-12-14T22:33:45.600427Z",
"iopub.status.idle": "2022-12-14T22:33:49.321477Z",
"shell.execute_reply": "2022-12-14T22:33:49.320762Z"
},
"id": "otIdN1TS8N7S"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-14 22:33:48.348405: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
"2022-12-14 22:33:48.348501: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
"2022-12-14 22:33:48.348510: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
]
}
],
"source": [
"# Update TensorFlow, as this notebook requires version 2.9 or later\n",
"!pip install -q -U \"tensorflow>=2.9.0\"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I0xDjO4SHLUD"
},
"source": [
"Define a helper function to demonstrate the kinds of errors you might encounter:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:49.325894Z",
"iopub.status.busy": "2022-12-14T22:33:49.325486Z",
"iopub.status.idle": "2022-12-14T22:33:49.330651Z",
"shell.execute_reply": "2022-12-14T22:33:49.329892Z"
},
"id": "D25apou9IOXa"
},
"outputs": [],
"source": [
"import traceback\n",
"import contextlib\n",
"\n",
"# Some helper code to demonstrate the kinds of errors you might encounter.\n",
"@contextlib.contextmanager\n",
"def assert_raises(error_class):\n",
" try:\n",
" yield\n",
" except error_class as e:\n",
" print('Caught expected exception \\n {}:'.format(error_class))\n",
" traceback.print_exc(limit=2)\n",
" except Exception as e:\n",
" raise e\n",
" else:\n",
" raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
" error_class))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WPSfepzTHThq"
},
"source": [
"## Basics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CNwYTIJ8r56W"
},
"source": [
"### Usage\n",
"\n",
"A `Function` you define (for example, by applying the `@tf.function` decorator) is just like a core TensorFlow operation: you can execute it eagerly, you can compute gradients, and so on."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:49.334339Z",
"iopub.status.busy": "2022-12-14T22:33:49.333916Z",
"iopub.status.idle": "2022-12-14T22:33:52.718791Z",
"shell.execute_reply": "2022-12-14T22:33:52.717889Z"
},
"id": "SbtT1-Wm70F2"
},
"outputs": [
{
"data": {
"text/plain": [
"tf.function
!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nhpUtRqsXoyM"
},
"source": [
"#### What is \"tracing\"?\n",
"\n",
"A `Function` runs your program in a [TensorFlow graph](https://tensorflow.google.cn/guide/intro_to_graphs#what_are_graphs). However, a `tf.Graph` cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but a `tf.Graph` requires its inputs to have a specified data type and dimensions. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a `tf.Graph`.\n",
"\n",
"`Function` bridges this gap by separating your code into two stages:\n",
"\n",
"1. In the first stage, referred to as **tracing**, `Function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two tensors) are *deferred*: they are captured by the `tf.Graph` and not run.\n",
"\n",
"2. In the second stage, a `tf.Graph` which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.\n",
"\n",
"Depending on its inputs, `Function` will not always run the first stage when it is called. See [Rules of tracing](#rules_of_tracing) below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.\n",
"\n",
"When `Function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `Function` both creates and runs the `tf.Graph`. Later you will see how you can run only the tracing stage with [`get_concrete_function`](#obtaining_concrete_functions)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K7scSzLx662f"
},
"source": [
"When you pass arguments of different types into a `Function`, both stages are run:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.102822Z",
"iopub.status.busy": "2022-12-14T22:33:54.102181Z",
"iopub.status.idle": "2022-12-14T22:33:54.155667Z",
"shell.execute_reply": "2022-12-14T22:33:54.154939Z"
},
"id": "kojmJrgq8U9v"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n",
"tf.Tensor(2, shape=(), dtype=int32)\n",
"\n",
"Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n",
"tf.Tensor(2.2, shape=(), dtype=float32)\n",
"\n",
"Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n",
"tf.Tensor(b'aa', shape=(), dtype=string)\n",
"\n"
]
}
],
"source": [
"@tf.function\n",
"def double(a):\n",
" print(\"Tracing with\", a)\n",
" return a + a\n",
"\n",
"print(double(tf.constant(1)))\n",
"print()\n",
"print(double(tf.constant(1.1)))\n",
"print()\n",
"print(double(tf.constant(\"a\")))\n",
"print()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QPfouGUQrcNb"
},
"source": [
"Notice that if you repeatedly call a `Function` with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.158987Z",
"iopub.status.busy": "2022-12-14T22:33:54.158389Z",
"iopub.status.idle": "2022-12-14T22:33:54.162749Z",
"shell.execute_reply": "2022-12-14T22:33:54.162093Z"
},
"id": "hFccbWFRrsBp"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'bb', shape=(), dtype=string)\n"
]
}
],
"source": [
"# This doesn't print 'Tracing with ...'\n",
"print(double(tf.constant(\"b\")))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fgIO_XEzcB9o"
},
"source": [
"You can use `pretty_printed_concrete_signatures()` to see all of the available traces:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.165957Z",
"iopub.status.busy": "2022-12-14T22:33:54.165489Z",
"iopub.status.idle": "2022-12-14T22:33:54.169328Z",
"shell.execute_reply": "2022-12-14T22:33:54.168699Z"
},
"id": "IiQc4IKAb-NX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"double(a)\n",
" Args:\n",
" a: int32 Tensor, shape=()\n",
" Returns:\n",
" int32 Tensor, shape=()\n",
"\n",
"double(a)\n",
" Args:\n",
" a: float32 Tensor, shape=()\n",
" Returns:\n",
" float32 Tensor, shape=()\n",
"\n",
"double(a)\n",
" Args:\n",
" a: string Tensor, shape=()\n",
" Returns:\n",
" string Tensor, shape=()\n"
]
}
],
"source": [
"print(double.pretty_printed_concrete_signatures())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rKQ92VEWI7n8"
},
"source": [
"So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:\n",
"\n",
"- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.\n",
"- A `ConcreteFunction` wraps a `tf.Graph`.\n",
"- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.\n",
"- `tf.function` wraps a Python function, returning a `Function` object.\n",
"- **Tracing** creates a `tf.Graph` and wraps it in a `ConcreteFunction`, also known as a **trace**."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "129-iRsPS-gY"
},
"source": [
"#### Rules of tracing\n",
"\n",
"When called, a `Function` matches the call arguments to existing `ConcreteFunction`s using the `tf.types.experimental.TraceType` of each argument. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced.\n",
"\n",
"If multiple matches are found, the most specific signature is chosen. Matching is done by [subtyping](https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java. For example, `TensorShape([1, 2])` is a subtype of `TensorShape([None, None])`, so a call to the tf.function with `TensorShape([1, 2])` can be dispatched to the `ConcreteFunction` produced with `TensorShape([None, None])`. But if a `ConcreteFunction` with `TensorShape([1, None])` also exists, it will be prioritized since it is more specific.\n",
"\n",
"The `TraceType` is determined from input arguments as follows:\n",
"\n",
"- For `Tensor`, the type is parameterized by the `Tensor`'s `dtype` and `shape`; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions\n",
"- For `Variable`, the type is similar to `Tensor`, but also includes a unique resource ID of the variable, necessary to correctly wire control dependencies\n",
"- For Python primitive values, the type corresponds to the **value** itself. For example, the `TraceType` of the value `3` is `LiteralTraceType<3>`, not `int`.\n",
"- For Python ordered containers such as `list` and `tuple`, the type is parameterized by the types of their elements; for example, the type of `[1, 2]` is `ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>`.\n",
"\n",
"### Controlling retracing\n",
"\n",
"Retracing, which is when your `Function` creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `Function` retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use `tf.function`.\n",
"\n",
"To control the tracing behavior, you can use the following techniques:"
]
},
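These rules can be observed directly by counting traces with a Python side effect. A minimal sketch (the function `f` and the counter are illustrative, not from this guide): because Python primitives are typed by *value* while tensors are typed by `dtype` and `shape`, the five calls below produce only three traces.

```python
import tensorflow as tf

trace_count = 0

@tf.function
def f(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while tracing.
  return x + 1

f(1)               # New trace: LiteralTraceType<1>
f(2)               # New trace: LiteralTraceType<2>
f(1)               # Reuses the first trace.
f(tf.constant(1))  # New trace: scalar int32 Tensor
f(tf.constant(2))  # Same TraceType as above, so no new trace.

print(trace_count)  # 3
```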
{
"cell_type": "markdown",
"metadata": {
"id": "EUtycWJa34TT"
},
"source": [
"#### Pass a fixed `input_signature` to `tf.function`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.172975Z",
"iopub.status.busy": "2022-12-14T22:33:54.172455Z",
"iopub.status.idle": "2022-12-14T22:33:54.221083Z",
"shell.execute_reply": "2022-12-14T22:33:54.220474Z"
},
"id": "_BDMIRmu1RGB"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n",
"tf.Tensor([4 1], shape=(2,), dtype=int32)\n",
"Caught expected exception \n",
" tf.experimental.ExtensionType
. Furthermore, the `TraceType` of an `ExtensionType` is the `tf.TypeSpec` associated with it. Therefore, if needed, you can simply override the default `tf.TypeSpec` to take control of an `ExtensionType`'s `Tracing Protocol`. Refer to the *Customizing the ExtensionType's TypeSpec* section in the [Extension types](extension_type.ipynb) guide for details.\n",
"\n",
"Otherwise, for direct control over when `Function` should retrace in regards to a particular Python type, you can implement the `Tracing Protocol` for it yourself."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.519255Z",
"iopub.status.busy": "2022-12-14T22:33:54.518673Z",
"iopub.status.idle": "2022-12-14T22:33:54.575404Z",
"shell.execute_reply": "2022-12-14T22:33:54.574786Z"
},
"id": "gZkIh7UaIKc6"
},
"outputs": [
{
"data": {
"text/plain": [
"Each concrete function is a callable wrapper around a `tf.Graph`. Although retrieving the actual `tf.Graph` object is not something you'll normally need to do, you can obtain it easily from any concrete function."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.654558Z",
"iopub.status.busy": "2022-12-14T22:33:54.654355Z",
"iopub.status.idle": "2022-12-14T22:33:54.658037Z",
"shell.execute_reply": "2022-12-14T22:33:54.657493Z"
},
"id": "5UENeGHfaX8g"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[] -> a\n",
"['a', 'a'] -> add\n",
"['add'] -> Identity\n"
]
}
],
"source": [
"graph = double_strings.graph\n",
"for node in graph.as_graph_def().node:\n",
" print(f'{node.input} -> {node.name}')\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aIKkgr6qdtp4"
},
"source": [
"### Debugging\n",
"\n",
"In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable and re-enable `tf.function`.\n",
"\n",
"When tracking down issues that only appear within `tf.function`, here are some tips:\n",
"\n",
"- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.\n",
"- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.\n",
"- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Inf are created.\n",
"- `pdb` (the [Python debugger](https://docs.python.org/3/library/pdb.html)) can help you understand what's going on during tracing. (Caveat: `pdb` will drop you into AutoGraph-transformed source code.)"
]
},
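As a minimal sketch of the global eager toggle described above (the function `square` is illustrative):

```python
import tensorflow as tf

@tf.function
def square(x):
  # Runs as a traced graph by default.
  return x * x

print(square(tf.constant(3)))  # tf.Tensor(9, shape=(), dtype=int32)

# Globally force every Function to run eagerly, so `pdb`, plain
# `print`, and other Python-side tools work on every call.
tf.config.run_functions_eagerly(True)
print(square(tf.constant(3)))  # Now executes as plain Python.

# Remember to switch graph execution back on afterwards.
tf.config.run_functions_eagerly(False)
```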
{
"cell_type": "markdown",
"metadata": {
"id": "5f05Vr_YBUCz"
},
"source": [
"## AutoGraph transformations\n",
"\n",
"AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, and `while`.\n",
"\n",
"TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.661694Z",
"iopub.status.busy": "2022-12-14T22:33:54.661130Z",
"iopub.status.idle": "2022-12-14T22:33:54.767768Z",
"shell.execute_reply": "2022-12-14T22:33:54.767099Z"
},
"id": "yCQTtTPTW3WF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.682211161 0.396621943 0.451262951 0.643357158 0.87304759]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.592955 0.37705484 0.422936589 0.567181051 0.702919185]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.532017589 0.360147029 0.399401426 0.513286 0.606217444]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.486921817 0.3453435 0.379436702 0.472501546 0.541458964]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.451769888 0.332239449 0.362218171 0.4402183 0.49409157]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.423352748 0.320531547 0.347166359 0.413825333 0.45745784]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.399751157 0.309987456 0.333860159 0.391715884 0.428010017]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.379735976 0.300425678 0.321985 0.372838497 0.40365687]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.362478137 0.291702092 0.311300635 0.356472 0.383073539]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.347394943 0.283700645 0.301619858 0.342102677 0.365373671]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.334063202 0.276326627 0.292794317 0.329353303 0.349938482]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.322167 0.269501895 0.284704626 0.31793955 0.336320966]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.311465025 0.263161272 0.277253687 0.307642668 0.324188948]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.301769286 0.257249981 0.270361394 0.298291 0.313289642]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.292930901 0.251721531 0.263961077 0.289747804 0.303426832]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.284830183 0.24653624 0.257996708 0.281902701 0.294445485]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.277369589 0.241659909 0.252420813 0.274665147 0.286221236]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.270468801 0.237062961 0.247192904 0.26796037 0.278653115]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.264061 0.23271966 0.242278129 0.261725903 0.271658033]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.25809 0.228607446 0.237646371 0.255909115 0.265166938]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.252508163 0.224706501 0.23327139 0.250465214 0.259121925]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.24727492 0.22099933 0.229130313 0.245355889 0.253474057]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.242355347 0.217470333 0.225202918 0.240548164 0.248181522]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.237719223 0.214105651 0.221471429 0.236013427 0.243208483]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.233340293 0.210892901 0.21792005 0.231726721 0.23852399]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.229195595 0.207821 0.214534715 0.227666199 0.234101087]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.225264892 0.20487988 0.211302832 0.22381258 0.229916304]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.221530348 0.202060521 0.208213195 0.220148876 0.225948915]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.217976168 0.199354753 0.205255613 0.216659933 0.222180739]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.214588255 0.196755111 0.20242089 0.213332266 0.218595564]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.211354017 0.19425483 0.199700788 0.210153803 0.215179041]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.208262146 0.191847727 0.197087735 0.207113713 0.211918324]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.205302477 0.189528167 0.194574893 0.204202175 0.20880191]\n"
]
},
{
"data": {
"text/plain": [
"tf.cond
 traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more information."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.778451Z",
"iopub.status.busy": "2022-12-14T22:33:54.777908Z",
"iopub.status.idle": "2022-12-14T22:33:54.977784Z",
"shell.execute_reply": "2022-12-14T22:33:54.976938Z"
},
"id": "BOQl8PMq2Sf3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing for loop\n",
"Tracing fizzbuzz branch\n",
"Tracing fizz branch\n",
"Tracing buzz branch\n",
"Tracing default branch\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"8\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"13\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"14\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizzbuzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"16\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"17\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"19\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
}
],
"source": [
"@tf.function\n",
"def fizzbuzz(n):\n",
" for i in tf.range(1, n + 1):\n",
" print('Tracing for loop')\n",
" if i % 15 == 0:\n",
" print('Tracing fizzbuzz branch')\n",
" tf.print('fizzbuzz')\n",
" elif i % 3 == 0:\n",
" print('Tracing fizz branch')\n",
" tf.print('fizz')\n",
" elif i % 5 == 0:\n",
" print('Tracing buzz branch')\n",
" tf.print('buzz')\n",
" else:\n",
" print('Tracing default branch')\n",
" tf.print(i)\n",
"\n",
"fizzbuzz(tf.constant(5))\n",
"fizzbuzz(tf.constant(20))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4rBO5AQ15HVC"
},
"source": [
"For additional restrictions on AutoGraph-converted if statements, refer to the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yho4J0a0ZkQS"
},
"source": [
"### Loops\n",
"\n",
"AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow loop ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.\n",
"\n",
"This substitution is made in the following situations:\n",
"\n",
"- `for x in y`: if `y` is a Tensor, the loop is converted to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops are generated.\n",
"- `while <condition>`: if `<condition>` is a Tensor, the loop is converted to `tf.while_loop`.\n",
"\n",
"A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.\n",
"\n",
"A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated `tf.Graph`.\n",
"\n",
"For additional restrictions on AutoGraph-converted `for` and `while` statements, refer to the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements)."
]
},
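The conversion rules above can be seen side by side in a small sketch (function names are illustrative): iterating over a Python `range` is unrolled at trace time, while iterating over a `tf.range` tensor becomes a single `tf.while_loop` in the graph.

```python
import tensorflow as tf

@tf.function
def python_loop():
  s = tf.constant(0)
  # Python loop: unrolled while tracing, one `add` op per iteration.
  for i in range(3):
    s += i
  return s

@tf.function
def tensor_loop(n):
  s = tf.constant(0)
  # `tf.range(n)` is a Tensor, so AutoGraph emits one tf.while_loop.
  for i in tf.range(n):
    s += i
  return s

print(python_loop())                # tf.Tensor(3, shape=(), dtype=int32)
print(tensor_loop(tf.constant(3)))  # tf.Tensor(3, shape=(), dtype=int32)
```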
{
"cell_type": "markdown",
"metadata": {
"id": "sp4rbIdfbM6s"
},
"source": [
"#### Looping over Python data\n",
"\n",
"A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop executes during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.\n",
"\n",
"If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:54.981886Z",
"iopub.status.busy": "2022-12-14T22:33:54.981234Z",
"iopub.status.idle": "2022-12-14T22:33:55.116667Z",
"shell.execute_reply": "2022-12-14T22:33:55.115867Z"
},
"id": "WGZ19LspbZ27"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph\n",
"train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train(tf.data.Dataset.from_generator
 versus `tf.data.Dataset.from_tensors`. The former keeps the data in Python and fetches it via `tf.py_function`, which can have performance implications, whereas the latter bundles a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.\n",
"\n",
"Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of the data, without having to involve Python. To learn more, refer to the [`tf.data`: Build TensorFlow input pipelines](../../guide/data) guide."
]
},
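As a small sketch of the `from_tensor_slices` approach (the data and function names here are hypothetical), wrapping Python lists as a `tf.data.Dataset` lets AutoGraph unroll the loop dynamically instead of baking one copy of the loop body per element into the graph:

```python
import tensorflow as tf

features = [1.0, 2.0, 3.0]
labels = [0.0, 1.0, 0.0]

# Copies the data into the graph as constant nodes (memory cost),
# instead of fetching it from Python via tf.py_function (speed cost).
ds = tf.data.Dataset.from_tensor_slices((features, labels))

@tf.function
def sum_features(dataset):
  total = tf.constant(0.0)
  # AutoGraph converts this into tf.data ops; the body is traced once.
  for x, _ in dataset:
    total += x
  return total

print(sum_features(ds))  # tf.Tensor(6.0, shape=(), dtype=float32)
```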
{
"cell_type": "markdown",
"metadata": {
"id": "hyksHW9TCukR"
},
"source": [
"#### Accumulating values in a loop\n",
"\n",
"A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.120229Z",
"iopub.status.busy": "2022-12-14T22:33:55.119943Z",
"iopub.status.idle": "2022-12-14T22:33:55.261054Z",
"shell.execute_reply": "2022-12-14T22:33:55.260194Z"
},
"id": "HJ3Vb3dXfefN"
},
"outputs": [
{
"data": {
"text/plain": [
"tf.Graph
 is reexecuted, without executing the Python code.\n",
"\n",
"The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like `tf.data`, `tf.print`, `tf.summary`, `tf.Variable.assign`, and `tf.TensorArray` are the best way to ensure your code will be executed by the TensorFlow runtime with each call."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.264407Z",
"iopub.status.busy": "2022-12-14T22:33:55.263946Z",
"iopub.status.idle": "2022-12-14T22:33:55.298465Z",
"shell.execute_reply": "2022-12-14T22:33:55.297686Z"
},
"id": "w2sACuZ9TTRk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced with 1\n",
"Executed with 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executed with 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced with 2\n",
"Executed with 2\n"
]
}
],
"source": [
"@tf.function\n",
"def f(x):\n",
" print(\"Traced with\", x)\n",
" tf.print(\"Executed with\", x)\n",
"\n",
"f(1)\n",
"f(1)\n",
"f(2)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e1I0dPiqTV8H"
},
"source": [
"If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawbacks of `tf.py_function` are that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors."
]
},
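A minimal sketch of the `tf.py_function` escape hatch (the NumPy helper is illustrative); note that the inputs and outputs are cast to tensors:

```python
import numpy as np
import tensorflow as tf

def numpy_log(x):
  # Arbitrary Python/NumPy code; runs on every call, not just tracing.
  return np.log(x)

@tf.function
def log_op(x):
  # Wire the Python call into the graph; Tout declares the output dtype.
  return tf.py_function(numpy_log, inp=[x], Tout=tf.float32)

print(log_op(tf.constant(1.0)))  # tf.Tensor(0.0, shape=(), dtype=float32)
```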
{
"cell_type": "markdown",
"metadata": {
"id": "bOW1v9WVKGgH"
},
"source": [
"#### Changing Python global and free variables\n",
"\n",
"Changing Python global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) counts as a Python side effect, so it only happens during tracing."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.301987Z",
"iopub.status.busy": "2022-12-14T22:33:55.301675Z",
"iopub.status.idle": "2022-12-14T22:33:55.325428Z",
"shell.execute_reply": "2022-12-14T22:33:55.324655Z"
},
"id": "7aJD--9qTWmg"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python side effect\n"
]
}
],
"source": [
"external_list = []\n",
"\n",
"@tf.function\n",
"def side_effect(x):\n",
" print('Python side effect')\n",
" external_list.append(x)\n",
"\n",
"side_effect(1)\n",
"side_effect(1)\n",
"side_effect(1)\n",
"# The list append only happened once!\n",
"assert len(external_list) == 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5eZTFRv_k_nR"
},
"source": [
"Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the `tf.function` is used, the `assign_add` is recorded unconditionally in the underlying graph. Therefore `v` will increase by 1 every time the `tf.function` is called. This issue is common among users that try to migrate their graph-mode TensorFlow code to TensorFlow 2 using `tf.function` decorators, when Python side effects (the `counter` in the example) are used to determine what ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.328676Z",
"iopub.status.busy": "2022-12-14T22:33:55.328084Z",
"iopub.status.idle": "2022-12-14T22:33:55.374956Z",
"shell.execute_reply": "2022-12-14T22:33:55.374125Z"
},
"id": "5r6p7-9jk_3L"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"class Model(tf.Module):\n",
" def __init__(self):\n",
" self.v = tf.Variable(0)\n",
" self.counter = 0\n",
"\n",
" @tf.function\n",
" def __call__(self):\n",
" if self.counter == 0:\n",
" # A python side-effect\n",
" self.counter += 1\n",
" self.v.assign_add(1)\n",
"\n",
" return self.v\n",
"\n",
"m = Model()\n",
"for n in range(3):\n",
" print(m().numpy()) # prints 1, 2, 3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tXCTcHoVcxhX"
},
"source": [
"A workaround to achieve the expected behavior is using [`tf.init_scope`](https://tensorflow.google.cn/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment is only done once during tracing time. It should be noted `init_scope` has other side effects, including cleared control flow and gradient tape. Sometimes the usage of `init_scope` can become too complex to manage realistically."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.378434Z",
"iopub.status.busy": "2022-12-14T22:33:55.377841Z",
"iopub.status.idle": "2022-12-14T22:33:55.425119Z",
"shell.execute_reply": "2022-12-14T22:33:55.424358Z"
},
"id": "An4MrIbrcvi8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"1\n",
"1\n"
]
}
],
"source": [
"class Model(tf.Module):\n",
" def __init__(self):\n",
" self.v = tf.Variable(0)\n",
" self.counter = 0\n",
"\n",
" @tf.function\n",
" def __call__(self):\n",
" if self.counter == 0:\n",
" # Lifts ops out of function-building graphs\n",
" with tf.init_scope():\n",
" self.counter += 1\n",
" self.v.assign_add(1)\n",
"\n",
" return self.v\n",
"\n",
"m = Model()\n",
"for n in range(3):\n",
" print(m().numpy()) # prints 1, 1, 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pbFG5CX4LwQA"
},
"source": [
"In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the `Function`. Instead, use arguments and TF objects. For example, the section [Accumulating values in a loop](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n",
"\n",
"You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://tensorflow.google.cn/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`."
]
},
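A minimal sketch of the `tf.Variable` pattern just described (names are illustrative): because `assign_add` is a TensorFlow op, the state update runs on every call, unlike a Python-level mutation which would run only during tracing.

```python
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def add_to_state(x):
  # A TensorFlow op, recorded in the graph and executed on each call.
  v.assign_add(x)
  return v.read_value()

print(add_to_state(tf.constant(2.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)
print(add_to_state(tf.constant(2.0)))  # tf.Tensor(5.0, shape=(), dtype=float32)
```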
{
"cell_type": "markdown",
"metadata": {
"id": "X_oNNGrAqPJ1"
},
"source": [
"#### Using Python iterators and generators"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "msTmv-oyUNaf"
},
"source": [
"Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.428858Z",
"iopub.status.busy": "2022-12-14T22:33:55.428245Z",
"iopub.status.idle": "2022-12-14T22:33:55.456458Z",
"shell.execute_reply": "2022-12-14T22:33:55.455847Z"
},
"id": "FNPD4unZUedH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
}
],
"source": [
"@tf.function\n",
"def buggy_consume_next(iterator):\n",
" tf.print(\"Value:\", next(iterator))\n",
"\n",
"iterator = iter([1, 2, 3])\n",
"buggy_consume_next(iterator)\n",
"# This reuses the first value from the iterator, rather than consuming the next value.\n",
"buggy_consume_next(iterator)\n",
"buggy_consume_next(iterator)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wcS3TAgCjTWR"
},
"source": [
"Just like TensorFlow has a specialized `tf.TensorArray` for list constructs, it has a specialized `tf.data.Iterator` for iteration constructs. Refer to the section on [AutoGraph transformations](#autograph_transformations) for an overview. Also, the [`tf.data`](https://tensorflow.google.cn/guide/data) API can help implement generator patterns:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.459883Z",
"iopub.status.busy": "2022-12-14T22:33:55.459339Z",
"iopub.status.idle": "2022-12-14T22:33:55.498504Z",
"shell.execute_reply": "2022-12-14T22:33:55.497863Z"
},
"id": "8D_iKetXW6VE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 3\n"
]
}
],
"source": [
"@tf.function\n",
"def good_consume_next(iterator):\n",
" # This is ok, iterator is a tf.data.Iterator\n",
" tf.print(\"Value:\", next(iterator))\n",
"\n",
"ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n",
"iterator = iter(ds)\n",
"good_consume_next(iterator)\n",
"good_consume_next(iterator)\n",
"good_consume_next(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i8YAMYb6KEh4"
},
"source": [
"### All outputs of a tf.function must be return values\n",
"\n",
"With the exception of `tf.Variable`s, a tf.function must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes \"leaks\".\n",
"\n",
"For example, the function below \"leaks\" the tensor `a` through the Python global `x`:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.502054Z",
"iopub.status.busy": "2022-12-14T22:33:55.501792Z",
"iopub.status.idle": "2022-12-14T22:33:55.533576Z",
"shell.execute_reply": "2022-12-14T22:33:55.533005Z"
},
"id": "zrdp4rjxg6jo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n",
"'Tensor' object has no attribute 'numpy'\n"
]
}
],
"source": [
"x = None\n",
"\n",
"@tf.function\n",
"def leaky_function(a):\n",
" global x\n",
" x = a + 1 # Bad - leaks local tensor\n",
" return a + 2\n",
"\n",
"correct_a = leaky_function(tf.constant(1))\n",
"\n",
"print(correct_a.numpy()) # Good - value obtained from function's returns\n",
"try:\n",
" x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n",
"except AttributeError as expected:\n",
" print(expected)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-d4_J_DC5rxX"
},
"source": [
"This is the case even if the leaked value is also returned:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:55.536863Z",
"iopub.status.busy": "2022-12-14T22:33:55.536299Z",
"iopub.status.idle": "2022-12-14T22:33:55.815261Z",
"shell.execute_reply": "2022-12-14T22:33:55.814571Z"
},
"id": "PrcpPB8C5s9T"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n",
"'Tensor' object has no attribute 'numpy'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception \n",
 If you need to change a wrapped value as the program runs, it's recommended to avoid relying on Python globals and instead wrap the value in a `tf.Variable`, using the `Variable.assign` method to update it."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.775080Z",
"iopub.status.busy": "2022-12-14T22:33:56.774520Z",
"iopub.status.idle": "2022-12-14T22:33:56.803620Z",
"shell.execute_reply": "2022-12-14T22:33:56.802964Z"
},
"id": "oeJMdXd3M0cc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable: tf.Tensor(2, shape=(), dtype=int32)\n"
]
}
],
"source": [
"@tf.function\n",
"def variable_add():\n",
" return 1 + foo\n",
"\n",
"foo = tf.Variable(1)\n",
"print(\"Variable:\", variable_add())\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.806561Z",
"iopub.status.busy": "2022-12-14T22:33:56.806337Z",
"iopub.status.idle": "2022-12-14T22:33:56.811617Z",
"shell.execute_reply": "2022-12-14T22:33:56.811015Z"
},
"id": "L3q7sUJWZOSd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updating the value of `foo` to 100!\n",
"Variable: tf.Tensor(101, shape=(), dtype=int32)\n"
]
}
],
"source": [
"print(\"Updating the value of `foo` to 100!\")\n",
"foo.assign(100)\n",
"print(\"Variable:\", variable_add())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hvwe9gTIWfx6"
},
"source": [
"#### Depending on Python objects"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BJkZS-SwPvOQ"
},
"source": [
    "Passing Python objects as arguments to `tf.function` is supported but has a number of known issues that are expected to be resolved in the future. In general, you can rely on consistent tracing if you use Python primitives or `tf.nest`-compatible structures as arguments, or if you pass a *different* instance of an object into the `Function`. However, `Function` will *not* create a new trace when you pass **the same object and only change its attributes**."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.815130Z",
"iopub.status.busy": "2022-12-14T22:33:56.814534Z",
"iopub.status.idle": "2022-12-14T22:33:56.847798Z",
"shell.execute_reply": "2022-12-14T22:33:56.846985Z"
},
"id": "ux8KJESVWDxX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"class SimpleModel(tf.Module):\n",
" def __init__(self):\n",
" # These values are *not* tf.Variables.\n",
" self.bias = 0.\n",
" self.weight = 2.\n",
"\n",
"@tf.function\n",
"def evaluate(model, x):\n",
" return model.weight * x + model.bias\n",
"\n",
"simple_model = SimpleModel()\n",
"x = tf.constant(10.)\n",
"print(evaluate(simple_model, x))"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.851366Z",
"iopub.status.busy": "2022-12-14T22:33:56.850747Z",
"iopub.status.idle": "2022-12-14T22:33:56.855158Z",
"shell.execute_reply": "2022-12-14T22:33:56.854588Z"
},
"id": "mUxRF4ghZZvX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n",
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"simple_model.bias += 5.0\n",
"print(evaluate(simple_model, x)) # Didn't change :("
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ytcgg2qFWaBF"
},
"source": [
    "Using the same `Function` to evaluate the modified instance of the model is problematic, since the modified model has the [same cache key](#rules_of_tracing) as the original model.\n",
    "\n",
    "For that reason, you're recommended to write your `Function` to avoid depending on mutable object attributes, or to create new objects instead.\n",
    "\n",
    "If that is not possible, one workaround is to make a new `Function` each time you modify your object, to force retracing:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.858164Z",
"iopub.status.busy": "2022-12-14T22:33:56.857940Z",
"iopub.status.idle": "2022-12-14T22:33:56.890062Z",
"shell.execute_reply": "2022-12-14T22:33:56.889389Z"
},
"id": "pFvWmWAAQjrv"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"def evaluate(model, x):\n",
" return model.weight * x + model.bias\n",
"\n",
"new_model = SimpleModel()\n",
"evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
"# Don't pass in `new_model`, `Function` already captured its state during tracing.\n",
"print(evaluate_no_bias(x)) "
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.893315Z",
"iopub.status.busy": "2022-12-14T22:33:56.892824Z",
"iopub.status.idle": "2022-12-14T22:33:56.908076Z",
"shell.execute_reply": "2022-12-14T22:33:56.907402Z"
},
"id": "bdU2-jF4ZH0B"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n",
"tf.Tensor(25.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"new_model.bias += 5.0\n",
"# Create new Function and ConcreteFunction since you modified new_model.\n",
"evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
"print(evaluate_with_bias(x)) # Don't pass in `new_model`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uFgEZClsZrEi"
},
"source": [
    "As [retracing can be expensive](https://tensorflow.google.cn/guide/intro_to_graphs#tracing_and_performance), you can use `tf.Variable`s as object attributes, which can be mutated (but not changed, careful!) for a similar effect without needing a retrace."
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.911373Z",
"iopub.status.busy": "2022-12-14T22:33:56.910806Z",
"iopub.status.idle": "2022-12-14T22:33:56.948046Z",
"shell.execute_reply": "2022-12-14T22:33:56.947370Z"
},
"id": "daAP_lucwS6w"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"class BetterModel:\n",
"\n",
" def __init__(self):\n",
" self.bias = tf.Variable(0.)\n",
" self.weight = tf.Variable(2.)\n",
"\n",
"@tf.function\n",
"def evaluate(model, x):\n",
" return model.weight * x + model.bias\n",
"\n",
"better_model = BetterModel()\n",
"print(evaluate(better_model, x))\n"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.951182Z",
"iopub.status.busy": "2022-12-14T22:33:56.950614Z",
"iopub.status.idle": "2022-12-14T22:33:56.956431Z",
"shell.execute_reply": "2022-12-14T22:33:56.955850Z"
},
"id": "ktqwMJBqwTFj"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n",
"tf.Tensor(25.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5\n",
"print(evaluate(better_model, x)) # This works!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lPr_6mK_AQWL"
},
"source": [
    "### Creating tf.Variables\n",
    "\n",
    "`Function` only supports singleton `tf.Variable`s that are created once on the first call and reused across subsequent calls. The code snippet below would create a new `tf.Variable` in every call, which results in a `ValueError` exception.\n",
    "\n",
    "Example:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:56.959466Z",
"iopub.status.busy": "2022-12-14T22:33:56.959187Z",
"iopub.status.idle": "2022-12-14T22:33:57.004957Z",
"shell.execute_reply": "2022-12-14T22:33:57.004319Z"
},
"id": "Tx0Vvnb_9OB-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
      "Caught expected exception \n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def f(x):\n",
    "  v = tf.Variable(1.0)\n",
    "  return v\n",
    "\n",
    "with assert_raises(ValueError):\n",
    "  f(1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common pattern used to work around this limitation is to start with a Python `None` value, then conditionally create the `tf.Variable` if the value is `None`:"
   ]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:57.008463Z",
"iopub.status.busy": "2022-12-14T22:33:57.007868Z",
"iopub.status.idle": "2022-12-14T22:33:57.073908Z",
"shell.execute_reply": "2022-12-14T22:33:57.073272Z"
},
"id": "HQrG5_kOiKl_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(1, shape=(), dtype=int32)\n",
"tf.Tensor(2, shape=(), dtype=int32)\n"
]
}
],
"source": [
"class Count(tf.Module):\n",
" def __init__(self):\n",
" self.count = None\n",
"\n",
" @tf.function\n",
" def __call__(self):\n",
" if self.count is None:\n",
" self.count = tf.Variable(0)\n",
" return self.count.assign_add(1)\n",
"\n",
"c = Count()\n",
"print(c())\n",
"print(c())"
]
},
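  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of a common alternative to the conditional-creation pattern above: create the variable eagerly in the module's `__init__`, outside any `tf.function`, so the singleton requirement is satisfied by construction (the `Counter` name here is illustrative):\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "class Counter(tf.Module):\n",
    "  def __init__(self):\n",
    "    # Created once, eagerly, before any tracing happens.\n",
    "    self.count = tf.Variable(0)\n",
    "\n",
    "  @tf.function\n",
    "  def __call__(self):\n",
    "    # Only mutate the existing variable inside the Function.\n",
    "    return self.count.assign_add(1)\n",
    "\n",
    "c = Counter()\n",
    "print(c())  # tf.Tensor(1, shape=(), dtype=int32)\n",
    "print(c())  # tf.Tensor(2, shape=(), dtype=int32)\n",
    "```"
   ]
  },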
{
"cell_type": "markdown",
"metadata": {
"id": "7uD6qI7aJwbR"
},
"source": [
    "#### Using with multiple Keras optimizers\n",
    "\n",
    "You may encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when using more than one Keras optimizer with a `tf.function`. This error occurs because optimizers internally create `tf.Variables` when they apply gradients for the first time."
]
},
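  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The usual remedy is to wrap the training step in a separate `tf.function` per optimizer, so each `Function` traces only once and each optimizer creates its variables during a single trace. The sketch below assumes an illustrative `train_step` helper (not from this guide):\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "def train_step(w, x, y, optimizer):\n",
    "  # A minimal squared-error step; purely illustrative.\n",
    "  with tf.GradientTape() as tape:\n",
    "    loss = tf.reduce_mean((w * x - y) ** 2)\n",
    "  grads = tape.gradient(loss, [w])\n",
    "  optimizer.apply_gradients(zip(grads, [w]))\n",
    "  return loss\n",
    "\n",
    "w = tf.Variable(2.0)\n",
    "x, y = tf.constant(1.0), tf.constant(4.0)\n",
    "\n",
    "opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)\n",
    "opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "\n",
    "# One Function per optimizer: each traces once, so each optimizer\n",
    "# creates its slot variables during its own first trace only.\n",
    "step_opt1 = tf.function(train_step)\n",
    "step_opt2 = tf.function(train_step)\n",
    "\n",
    "step_opt1(w, x, y, opt1)\n",
    "step_opt2(w, x, y, opt2)\n",
    "```"
   ]
  },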
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:33:57.077415Z",
"iopub.status.busy": "2022-12-14T22:33:57.076905Z",
"iopub.status.idle": "2022-12-14T22:33:57.382371Z",
"shell.execute_reply": "2022-12-14T22:33:57.381647Z"
},
"id": "yWQ3-r99Jvze"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling `train_step` with different optimizer...\n",
"Caught expected exception \n",
"