TensorFlow graph optimization with Grappler

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

Overview

TensorFlow uses both graph and eager executions to execute computations. A tf.Graph contains a set of tf.Operation objects (ops) which represent units of computation and tf.Tensor objects which represent the units of data that flow between ops.

Grappler is the default graph optimization system in the TensorFlow runtime. Grappler applies optimizations in graph mode (within tf.function) to improve the performance of your TensorFlow computations through graph simplifications and other high-level optimizations such as inlining function bodies to enable inter-procedural optimizations. Optimizing the tf.Graph also reduces the device peak memory usage and improves hardware utilization by optimizing the mapping of graph nodes to compute resources.

Use tf.config.optimizer.set_experimental_options() for finer control over your tf.Graph optimizations.

Available graph optimizers

Grappler performs graph optimizations through a top-level driver called the MetaOptimizer. The following graph optimizers are available with TensorFlow:

  • Constant folding optimizer - Statically infers the value of tensors when possible by folding constant nodes in the graph and materializes the result using constants.
  • Arithmetic optimizer - Simplifies arithmetic operations by eliminating common subexpressions and simplifying arithmetic statements.
  • Layout optimizer - Optimizes tensor layouts to execute data format dependent operations such as convolutions more efficiently.
  • Remapper optimizer - Remaps subgraphs onto more efficient implementations by replacing commonly occurring subgraphs with optimized fused monolithic kernels.
  • Memory optimizer - Analyzes the graph to inspect the peak memory usage for each operation and inserts CPU-GPU memory copy operations for swapping GPU memory to CPU to reduce the peak memory usage.
  • Dependency optimizer - Removes or rearranges control dependencies to shorten the critical path for a model step or enables other optimizations. Also removes nodes that are effectively no-ops such as Identity.
  • Pruning optimizer - Prunes nodes that have no effect on the output from the graph. It is usually run first to reduce the size of the graph and speed up processing in other Grappler passes.
  • Function optimizer - Optimizes the function library of a TensorFlow program and inlines function bodies to enable other inter-procedural optimizations.
  • Shape optimizer - Optimizes subgraphs that operate on shape and shape related information.
  • Autoparallel optimizer - Automatically parallelizes graphs by splitting along the batch dimension. This optimizer is turned OFF by default.
  • Loop optimizer - Optimizes the graph control flow by hoisting loop-invariant subgraphs out of loops and by removing redundant stack operations in loops. Also optimizes loops with statically known trip counts and removes statically known dead branches in conditionals.
  • Scoped allocator optimizer - Introduces scoped allocators to reduce data movement and to consolidate some operations.
  • Pin to host optimizer - Swaps small operations onto the CPU. This optimizer is turned OFF by default.
  • Auto mixed precision optimizer - Converts data types to float16 where applicable to improve performance. Currently applies only to GPUs.
  • Debug stripper - Strips nodes related to debugging operations such as tf.debugging.Assert, tf.debugging.check_numerics, and tf.print from the graph. This optimizer is turned OFF by default.

Setup

import numpy as np
import timeit
import traceback
import contextlib


import tensorflow as tf
2024-07-19 02:31:31.091219: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-19 02:31:31.112299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-19 02:31:31.118548: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Create a context manager to easily toggle optimizer states.

@contextlib.contextmanager
def options(options):
  old_opts = tf.config.optimizer.get_experimental_options()
  tf.config.optimizer.set_experimental_options(options)
  try:
    yield
  finally:
    tf.config.optimizer.set_experimental_options(old_opts)

Compare execution performance with and without Grappler

TensorFlow 2 and beyond executes eagerly by default. Use tf.function to switch the default execution to Graph mode. Grappler runs automatically in the background to apply the graph optimizations above and improve execution performance.

Constant folding optimizer

As a preliminary example, consider a function which performs operations on constants and returns an output.

def test_function_1():
  @tf.function
  def simple_function(input_arg):
    print('Tracing!')
    a = tf.constant(np.random.randn(2000,2000), dtype = tf.float32)
    c = a
    for n in range(50):
      c = c@a
    return tf.reduce_mean(c+input_arg)

  return simple_function

Turn off the constant folding optimizer and execute the function:

with options({'constant_folding': False}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = test_function_1()
  # Trace once
  x = tf.constant(2.2)
  simple_function(x)
  print("Vanilla execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
{'constant_folding': False, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721356293.603556  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.607452  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.611198  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.614432  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.626336  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.629850  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.633353  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.636156  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.639683  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.643010  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.646513  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356293.649416  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.893828  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.895879  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.897860  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.899953  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.902068  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.903973  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.905811  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.907772  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.909769  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.911660  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.913492  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.915472  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.953722  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.955707  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.957610  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.959623  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.961631  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.963555  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.965419  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.967402  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.969409  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.971847  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.974144  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1721356294.976532  112419 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
Vanilla execution: 0.0017471689998274087 s

Enable the constant folding optimizer and execute the function again to observe a speed-up in function execution.

with options({'constant_folding': True}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = test_function_1()
  # Trace once
  x = tf.constant(2.2)
  simple_function(x)
  print("Constant folded execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
{'constant_folding': True, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing!
Constant folded execution: 0.0008535509996363544 s

Debug stripper optimizer

Consider a simple function that checks the numeric value of its input argument and returns it.

def test_function_2():
  @tf.function
  def simple_func(input_arg):
    output = input_arg
    tf.debugging.check_numerics(output, "Bad!")
    return output
  return simple_func

First, execute the function with the debug stripper optimizer turned off.

test_func = test_function_2()
p1 = tf.constant(float('inf'))
try:
  test_func(p1)
except tf.errors.InvalidArgumentError as e:
  traceback.print_exc(limit=2)
2024-07-19 02:31:50.754627: E tensorflow/core/kernels/check_numerics_op.cc:299] abnormal_detected_host @0x7f49e6c00100 = {0, 1} Bad!
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_112419/3616845043.py", line 4, in <module>
    test_func(p1)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node CheckNumerics defined at (most recent call last):
  File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main

  File "/usr/lib/python3.9/runpy.py", line 87, in _run_code

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever

  File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once

  File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes

  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code

  File "/tmpfs/tmp/ipykernel_112419/3616845043.py", line 4, in <module>

  File "/tmpfs/tmp/ipykernel_112419/2241890286.py", line 5, in simple_func

Bad! : Tensor had Inf values
     [[{ {node CheckNumerics} }]] [Op:__inference_simple_func_128]

tf.debugging.check_numerics raises an invalid argument error because of the Inf argument to test_func.

Enable the debug stripper optimizer and execute the function again.

with options({'debug_stripper': True}):
  test_func2 = test_function_2()
  p1 = tf.constant(float('inf'))
  try:
    test_func2(p1)
  except tf.errors.InvalidArgumentError as e:
    traceback.print_exc(limit=2)

The debug stripper optimizer strips the tf.debug.check_numerics node from the graph and executes the function without raising any errors.

Summary

The TensorFlow runtime uses Grappler to optimize graphs automatically before execution. Use tf.config.optimizer.set_experimental_options to enable or disable the various graph optimizers.

For more information on Grappler, see TensorFlow Graph Optimizations.