使用 Grappler 优化 TensorFlow 计算图

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

概述

TensorFlow 同时使用计算图和 Eager Execution 来执行计算。一个 tf.Graph 包含一组代表计算单元的 tf.Operation 对象(运算)和一组代表在运算之间流动的数据单元的 tf.Tensor 对象。

Grappler 是 TensorFlow 运行时中的默认计算图优化系统。Grappler 通过计算图简化和其他高级优化(例如利用内嵌函数体实现程序间优化),在计算图模式(在 tf.function 内)下应用优化以提高 TensorFlow 计算的性能。优化 tf.Graph 还可以通过优化计算图节点到计算资源的映射来减少设备峰值内存使用量并提高硬件利用率。

使用 tf.config.optimizer.set_experimental_options() 可以更好地控制 tf.Graph 优化。

可用的计算图优化器

Grappler 通过称为 MetaOptimizer 的顶级驱动程序执行计算图优化。TensorFlow 提供以下计算图优化器:

  • 常量折叠优化器 - 通过折叠计算图中的常量节点来静态推断张量的值(如可能),并使用常量使结果具体化。
  • 算术优化器 - 通过消除常见的子表达式并简化算术语句来简化算术运算。
  • 布局优化器 - 优化张量布局以更高效地执行依赖于数据格式的运算,例如卷积。
  • 重新映射优化器 - 通过将常见的子计算图替换为经过优化的融合一体化内核,将子计算图重新映射到更高效的实现上。
  • 内存优化器 - 分析计算图以检查每个运算的峰值内存使用量,并插入 CPU-GPU 内存复制操作以将 GPU 内存交换到 CPU,从而减少峰值内存使用量。
  • 依赖项优化器 - 移除或重新排列控制依赖项,以缩短模型步骤的关键路径或实现其他优化。另外,还移除了实际上是无运算的节点,例如 Identity。
  • 剪枝优化器 - 修剪对计算图的输出没有影响的节点。通常会首先运行剪枝来减小计算图的大小并加快其他 Grappler 传递中的处理速度。
  • 函数优化器 - 优化 TensorFlow 程序的函数库,并内嵌函数体以实现其他程序间优化。
  • 形状优化器 - 优化对形状和形状相关信息进行运算的子计算图。
  • 自动并行优化器 - 通过沿批次维度拆分来自动并行化计算图。默认情况下,此优化器处于关闭状态。
  • 循环优化器 - 通过将循环不变式子计算图提升到循环外并通过移除循环中的冗余堆栈运算来优化计算图控制流。另外,还优化具有静态已知行程计数的循环,并移除条件语句中静态已知的无效分支。
  • 范围分配器优化器 - 引入范围分配器以减少数据移动并合并某些运算。
  • 固定到主机优化器 - 将小型运算交换到 CPU 上。默认情况下,此优化器处于关闭状态。
  • 自动混合精度优化器 - 在适用的情况下将数据类型转换为 float16 以提高性能。目前仅适用于 GPU。
  • 调试剥离器 - 从计算图中剥离与调试运算相关的节点,例如 tf.debugging.Asserttf.debugging.check_numericstf.print。默认情况下,此优化器处于关闭状态。

设置

import numpy as np
import timeit
import traceback
import contextlib


import tensorflow as tf
2022-12-14 21:15:47.322566: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:15:47.322681: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:15:47.322698: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

创建上下文管理器以轻松切换优化器状态。

@contextlib.contextmanager
def options(options):
  old_opts = tf.config.optimizer.get_experimental_options()
  tf.config.optimizer.set_experimental_options(options)
  try:
    yield
  finally:
    tf.config.optimizer.set_experimental_options(old_opts)

比较使用和不使用 Grappler 时的执行性能

TensorFlow 2 及更高版本默认情况下会以 Eager 模式执行。使用 tf.function 可将默认执行切换为“计算图”模式。Grappler 在后台自动运行,以应用上述计算图优化并提高执行性能。

常量折叠优化器

作为一个初步的示例,考虑一个对常量执行运算并返回输出的函数。

def test_function_1():
  @tf.function
  def simple_function(input_arg):
    print('Tracing!')
    a = tf.constant(np.random.randn(2000,2000), dtype = tf.float32)
    c = a
    for n in range(50):
      c = c@a
    return tf.reduce_mean(c+input_arg)

  return simple_function

关闭常量折叠优化器并执行以下函数:

with options({'constant_folding': False}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = test_function_1()
  # Trace once
  x = tf.constant(2.2)
  simple_function(x)
  print("Vanilla execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
{'constant_folding': False, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing!
Vanilla execution: 0.0011523299999680603 s

启用常量折叠优化器,然后再次执行函数以观察函数执行的加速情况。

with options({'constant_folding': True}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = test_function_1()
  # Trace once
  x = tf.constant(2.2)
  simple_function(x)
  print("Constant folded execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
{'constant_folding': True, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing!
Constant folded execution: 0.0009006219997900189 s

调试剥离器优化器

考虑一个检查其输入参数的数值并返回自身的简单函数。

def test_function_2():
  @tf.function
  def simple_func(input_arg):
    output = input_arg
    tf.debugging.check_numerics(output, "Bad!")
    return output
  return simple_func

首先,在调试剥离器优化器关闭的情况下执行该函数。

test_func = test_function_2()
p1 = tf.constant(float('inf'))
try:
  test_func(p1)
except tf.errors.InvalidArgumentError as e:
  traceback.print_exc(limit=2)
2022-12-14 21:16:07.794488: E tensorflow/core/kernels/check_numerics_op.cc:293] abnormal_detected_host @0x7f49f6600100 = {0, 1} Bad!
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_95249/3616845043.py", line 4, in <module>
    test_func(p1)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'CheckNumerics' defined at (most recent call last):
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 531, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmpfs/tmp/ipykernel_95249/3616845043.py", line 4, in <module>
      test_func(p1)
    File "/tmpfs/tmp/ipykernel_95249/2241890286.py", line 5, in simple_func
      tf.debugging.check_numerics(output, "Bad!")
Node: 'CheckNumerics'
Bad! : Tensor had Inf values
     [[{ {node CheckNumerics} }]] [Op:__inference_simple_func_131]

由于 test_funcInf 参数,tf.debugging.check_numerics 引发了参数无效错误。

启用调试剥离器优化器,然后再次执行该函数。

with options({'debug_stripper': True}):
  test_func2 = test_function_2()
  p1 = tf.constant(float('inf'))
  try:
    test_func2(p1)
  except tf.errors.InvalidArgumentError as e:
    traceback.print_exc(limit=2)

调试剥离器优化器从计算图中剥离 tf.debug.check_numerics 节点并执行该函数,而不会引发任何错误。

总结

TensorFlow 运行时会在执行之前使用 Grappler 自动优化计算图。使用 tf.config.optimizer.set_experimental_options 可启用或停用各个计算图优化器。

有关 Grappler 的更多信息,请参阅 TensorFlow 计算图优化