# 计算图和 tf.function 简介

本指南介绍 TensorFlow 计算图以及 `tf.function` 的基本用法。

## 概述

### 计算图的优点

- 通过在计算中折叠常量节点来静态推断张量的值（“常量折叠”）。
- 分离独立的计算子部分，并在线程或设备之间进行拆分。
- 通过消除通用子表达式来简化算术运算。

## 安装

```python
import tensorflow as tf
import timeit
from datetime import datetime
```

```
2023-11-07 17:38:22.381032: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 17:38:22.381078: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 17:38:22.382629: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
```

## 利用计算图

```python
# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)
```

`tf.function` 适用于一个函数及其调用的所有其他函数

```python
def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
```
```
array([[12.]], dtype=float32)
```

### 将 Python 函数转换为计算图

```python
def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
```
```
First branch, with graph: 1
Second branch, with graph: 0
```

```python
# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
```
```
def tf__simple_relu(x):
with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
do_return = False
retval_ = ag__.UndefinedReturnValue()

def get_state():
return (do_return, retval_)

def set_state(vars_):
nonlocal retval_, do_return
(do_return, retval_) = vars_

def if_body():
nonlocal retval_, do_return
try:
do_return = True
retval_ = ag__.ld(x)
except:
do_return = False
raise

def else_body():
nonlocal retval_, do_return
try:
do_return = True
retval_ = 0
except:
do_return = False
raise
ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
return fscope.ret(retval_, do_return)
```
```python
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
```
```
node {
name: "x"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "Greater/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "Greater"
op: "Greater"
input: "x"
input: "Greater/y"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "cond"
op: "StatelessIf"
input: "Greater"
input: "x"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_BOOL
type: DT_INT32
}
}
}
attr {
key: "_lower_using_switch_merge"
value {
b: true
}
}
attr {
value {
list {
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false_31"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
shape {
}
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true_30"
}
}
}
}
node {
name: "cond/Identity"
op: "Identity"
input: "cond"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node {
name: "cond/Identity_1"
op: "Identity"
input: "cond:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "Identity"
op: "Identity"
input: "cond/Identity_1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
library {
function {
signature {
name: "cond_false_31"
input_arg {
name: "cond_placeholder"
type: DT_INT32
}
output_arg {
name: "cond_identity"
type: DT_BOOL
}
output_arg {
name: "cond_identity_1"
type: DT_INT32
}
}
node_def {
name: "cond/Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: true
}
}
}
}
node_def {
name: "cond/Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: true
}
}
}
}
node_def {
name: "cond/Const_2"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node_def {
name: "cond/Const_3"
op: "Const"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: true
}
}
}
}
node_def {
name: "cond/Identity"
op: "Identity"
input: "cond/Const_3:output:0"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node_def {
name: "cond/Const_4"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node_def {
name: "cond/Identity_1"
op: "Identity"
input: "cond/Const_4:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
ret {
key: "cond_identity"
value: "cond/Identity:output:0"
}
ret {
key: "cond_identity_1"
value: "cond/Identity_1:output:0"
}
attr {
key: "_construction_context"
value {
s: "kEagerRuntime"
}
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "cond_true_30"
input_arg {
name: "cond_identity_1_x"
type: DT_INT32
}
output_arg {
name: "cond_identity"
type: DT_BOOL
}
output_arg {
name: "cond_identity_1"
type: DT_INT32
}
}
node_def {
name: "cond/Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: true
}
}
}
}
node_def {
name: "cond/Identity"
op: "Identity"
input: "cond/Const:output:0"
attr {
key: "T"
value {
type: DT_BOOL
}
}
}
node_def {
name: "cond/Identity_1"
op: "Identity"
input: "cond_identity_1_x"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
ret {
key: "cond_identity"
value: "cond/Identity:output:0"
}
ret {
key: "cond_identity_1"
value: "cond/Identity_1:output:0"
}
attr {
key: "_construction_context"
value {
s: "kEagerRuntime"
}
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
}
}
}
}
versions {
producer: 1645
min_consumer: 12
}
```

### 多态性：一个 `Function`，多个计算图

`tf.Graph` 专门用于特定类型的输入（例如，具有特定 `dtype` 的张量或具有相同 `id()` 的对象）。

`Function` 在 `ConcreteFunction` 中存储与该签名对应的 `tf.Graph`。`ConcreteFunction` 是围绕 `tf.Graph` 的封装容器。

```python
@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
```
```
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)
```

```python
# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
```
```
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)
```

```python
# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
```
```
Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]
Output Type:
  TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(2,), dtype=tf.float32, name=None)
Captures:
  None
```

## 使用 `tf.function`

### 计算图执行与 Eager Execution

`Function` 函数中的代码既能以 Eager 模式执行，也可以作为计算图执行。默认情况下，`Function` 将其代码作为计算图执行：

```python
@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
```
```python
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
```
```
tf.Tensor([2 5 4 5 3], shape=(5,), dtype=int32)
tf.Tensor([2 4 9 9 4], shape=(5,), dtype=int32)
```
```python
get_MSE(y_true, y_pred)
```
```
<tf.Tensor: shape=(), dtype=int32, numpy=8>
```

```python
tf.config.run_functions_eagerly(True)
```
```python
get_MSE(y_true, y_pred)
```
```
<tf.Tensor: shape=(), dtype=int32, numpy=8>
```
```python
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)
```

```python
@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
```

```python
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
```
```
Calculating MSE!
```

```python
# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
```
```python
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
```
```
Calculating MSE!
Calculating MSE!
Calculating MSE!
```
```python
tf.config.run_functions_eagerly(False)
```

`print` 是 Python 的副作用；在将函数转换为 `Function` 时，您还应注意其他差异。请在使用 `tf.function` 提升性能指南中的限制部分中了解详情。

### 非严格执行

```python
def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
```
```
tf.Tensor([0.], shape=(1,), dtype=float32)
```
```python
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph execution. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
```
```
tf.Tensor([0.], shape=(1,), dtype=float32)
```

### `tf.function` 最佳做法

`tf.function` 设计可能是您编写与计算图兼容的 TensorFlow 程序的最佳选择。以下是一些提示：

## 见证加速

`tf.function` 通常可以提高代码的性能，但加速的程度取决于您运行的计算种类。小型计算可能以调用计算图的开销为主。您可以按如下方式衡量性能上的差异：

```python
x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
```
```python
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000), "seconds")
```
```
Eager execution: 4.110958347000178 seconds
```
```python
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), "seconds")
```
```
Graph execution: 0.8174858239999594 seconds
```

`tf.function` 通常用于加速训练循环，您可以在使用 Keras 从头开始编写训练循环指南的使用 `tf.function` 加速训练步骤部分中了解详情。

## `Function` 何时进行跟踪？

```python
@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
```
```
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
```
```python
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
```
```
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)
```

<!-- 原文档此处的两个链接/图片在转换时丢失 -->