tf.function によるパフォーマンスの改善

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

TensorFlow 2 の Eager execution はデフォルトで有効になっています。ユーザーインターフェースは直感的で柔軟性に優れていますが(一度限りの演算の実行ははるかに簡単で高速に行われます)、パフォーマンスとデプロイ能力に影響がでることがあります。

プログラムからグラフを作成するには、tf.function を使用できます。変換ツールで Python コードから Python に依存しないデータフローグラフを作成するため、パフォーマンスと移植性に優れたモデルを作成できます。また、SavedModel を使用する際に必要となります。

このチュートリアルでは tf.function と AutoGraph の基本的な特徴についてひととおり確認します。

主に次の内容と推奨事項について説明しています。

  • Eager モードでデバッグしてから、@tf.function でデコレートする。
  • オブジェクトミューテーションまたはリストの追加といった Python 側の効果に依存しないこと。
  • tf.function は TensorFlow 演算子と最も相性が良く、NumPy と Python 呼び出しは定数に変換される。

セットアップ

import tensorflow as tf
2024-01-11 19:25:24.397188: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 19:25:24.397236: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 19:25:24.398868: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

発生する可能性のあるエラーの種類を示すヘルパー関数を定義します。

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

基礎

使い方

定義する Function@tf.function デコレーターを適用するなどして)は、コアの TensorFlow 演算とまったく変わりません。Eager での実行や勾配の計算などを行えます。

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Function をほかの Function 内で使用できます。

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function は、特に小さな演算が多数含まれるグラフでは、Eager コードよりも高速に実行されることがありますが、高価な演算がいくつか含まれるグラフ(畳み込みなど)では、速度の差はあまり見られません。

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.005722227000660496
Function conv: 0.0052488470000753296
Note how there's not much difference in performance for convolutions

トレーシング

このセクションは、Function の内部動作や実装の詳細を説明します。将来的に変更する可能性がありますが、いつなぜトレーシングが発生するのかを理解しておけば、tf.function を効果的に使用しやすくなります。

「トレーシング」とは?

FunctionTensorFlow Graph でプログラムを実行しますが、tf.Graph は、Eager TensorFlow プログラムにユーザーが記述するすべてのものを表現することはできません。たとえば、Python はポリモーフィズムをサポートしていますが、tf.Graph では、その入力に特定のデータ型と次元が必要です。またはコマンドラインの引数を読み取る、エラーを発生させる、より複雑な Python オブジェクトを扱うといったサイドタスクを実施しようとしても、どれも tf.Graph で実行することはできません。

Function はコードを 2 つの段階に分けることで、このギャップの橋渡しの役割を果たします。

  1. トレーシング」と呼ばれる第 1 段階において、Function は新しい tf.Graph を作成します。Python コードは通常通り実行しますが、すべての TensorFlow 演算(2 つのテンソルを加算するなど)は 据え置きとなります。これらは tf.Graph にとらわれるため、実行しません。

  2. 第 2 段階では、最初の段階で据え置きとなったすべての演算を含む tf.Graph が実行されます。この段階は、トレーシングの段階よりもはるかに高速に行われます。

Function は、その入力によっては必ずしも最初の段階で呼び出されたときに実行するわけではありません。この判定がどのように行われるのかについては、以下の「トレーシングの規則」をご覧ください。最初の段階を省略して 2 番目の段階のみを実行できれば、TensorFlow の高いパフォーマンスが発揮されます。

Function がトレーシングしないと判断した場合、トレーシング段階の直後に 第 2 段階が始まるため、Function を呼び出すと、tf.Graph の作成と実行が行われます。後の方で、get_concrete_function を使ってトレーシング段階のみを実行する方法を説明します。

型の異なる引数を Function に渡すと、両方の段階が実行されます。

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

同じ型の引数で Function を繰り返し呼び出すと、生成されるグラフはまったく同じになるため、TensorFlow はトレーシング段階を省略して前にトレーシングしたグラフを再利用することに注意してください。

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

すべての利用可能なトレースを確認するには、pretty_printed_concrete_signatures() を使用できます。

print(double.pretty_printed_concrete_signatures())
Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

ここまで、tf.function が TensorFlow のグラフトレーシングロジックにキャッシュされた動的ディスパッチレイヤーを作成するのを見てきました。用語についてより具体的に説明すると、次のように言えます。

  • tf.Graph は、言語に依存しない、生の移植可能な TensorFlow 計算の表現です。
  • ConcreteFunctiontf.Graph をラップします。
  • FunctionConcreteFunction のキャッシュを管理し、入力に適したものを選択します。
  • tf.function は Python 関数をラップし、Function オブジェクトを返します。
  • トレーシングtf.Graph を作成し、それを ConcreteFunction(またはトレース)をラップします。

トレーシングの規則

Function が呼び出されると、各引数の tf.types.experimental.TraceType を使用して呼び出し引数を既存の ConcreteFunction に一致させます。一致する ConcreteFunction が見つかった場合、呼び出しはそれにディスパッチされます。一致するものが見つからない場合、新しい ConcreteFunction がトレースされます。

複数の一致が見つかった場合は、最も具体的なシグネチャが選択されます。マッチングは、たとえば C++ や Java での通常の関数呼び出しと同じように、サブタイプ化によって行われます。例えば、TensorShape([1, 2])TensorShape([None, None]) のサブタイプ化であるため、TensorShape([1, 2]) を使用した tf.function への呼び出しは、TensorShape([None, None]) で生成された ConcreteFunction にディスパッチできます。しかし、TensorShape([1, None]) を持つ ConcreteFunction も存在する場合は、より具体的であるため優先されます。

TraceType は、次のように入力引数から決定されます。

  • Tensor の場合、型は Tensordtypeshape によってパラメータ化されます。階数付けされた形状は、階数付けされていない形状のサブタイプです。固定次元は未知次元のサブタイプです

  • Variable の場合、型は Tensor に似ていますが、変数の一意のリソース ID も含まれています。これは、制御の依存関係を正しく設定するために必要です。

  • Python プリミティブ値の場合、型は自体に対応します。たとえば、値 3TraceType は、int ではなく LiteralTraceType<3> です。

  • listtuple などの Python の順序付きコンテナの場合、型はそれらの要素の型によってパラメータ化されます。たとえば、[1, 2] の型は ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>> であり、[2, 1] の型は ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>> であり、異なります。

  • dict などの Python マッピングの場合、型も同じキーからのマッピングですが、実際の値ではなく値の型へのマッピングです。たとえば、{1: 2, 3: 4} の型は MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>> です。ただし、順序付きコンテナとは異なり、{1: 2, 3: 4}{3: 4, 1: 2} の型は同等です。

  • __tf_tracing_type__ メソッドを実装する Python オブジェクトの場合、型はそのメソッドが返すものです

  • その他の Python オブジェクトの場合、型はジェネリックの TraceType で、マッチング手順は以下のとおりです。

    • まず、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の id() または is を使用します)。オブジェクトが変更された場合でも一致することに注意してください。そのため、Python オブジェクトを tf.function の引数として使用する場合、イミュータブルを使用するのが最適です。
    • 次に、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の == を使用)。

    この手順では、オブジェクトへの weakref のみが維持されるため、オブジェクトが範囲内または削除されていない場合にのみ機能します。

注意: TraceTypeFunction 入力パラメータに基づいているため、グローバル変数と自由変数を変更するだけでは、新しいトレースは作成されません。Python のグローバル変数と自由変数を扱う際の推奨される方法については、こちらのセクションをご覧ください。

リトレーシングの制御

リトレーシングは、Function が 2 つ以上のトレースを作成する際に発生します。これは、TensorFlow が一連の入力ごとに正しいグラフを生成する上で役立ちますが、トレーシングは高価な演算です!Function が呼び出しごとに新しいグラフをリトレーシングすると、コードの実行は tf.function を使用しない場合よりも遅くなってしまいます。

トレーシングの動作を制御するには、次のテクニックを使用できます。

固定の input_signaturetf.function に渡す

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'TypeError'>:
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/3657259638.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/3657259638.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).

柔軟性のために未知の次元を使用する

TensorFlow は形状に基づいてテンソルを一致させるため、ワイルドカードとして None 次元を使用することで、Function が可変サイズの入力にトレースを再利用できるようになります。可変サイズの入力は、長さの異なるシーケンスがある場合や、バッチごとに画像のサイズが異なる場合に発生します(例として、TransformerDeep Dream チュートリアルをご覧ください)。

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

Python リテラルの代わりにテンソルを渡す

通常、Python 引数は、num_layers=10 または training=True または nonlinearity='relu' などのように、ハイパーパラメータとグラフ構造の制御に使用されます。そのため、Python 引数が変わると、当然グラフをリトレースする必要が出てきます。

しかし、Python 引数がグラフ構造の制御に使用されていない場合もあります。こういった場合、Python の値の変化によってリトレーシングがトリガーされますが、これは不要です。この、AutoGraph が動的にアンロールするトレーニングループを例に見てみましょう。トレースが何度も行われますが、生成されたグラフはまったく同じであるため、リトレーシングは不要と言えます。

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

リトレーシングを強制する必要がある場合は、新しい Function を作成します。トレースは絶対に、各 Function オブジェクト間で共有されることはありません。

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

トレースプロトコルを使用する

可能であれば、代わりに Python 型を tf.experimental.ExtensionType に変換することをお勧めします。さらに、ExtensionTypeTraceType は、それに関連付けられた tf.TypeSpec です。したがって、必要に応じて、デフォルトの tf.TypeSpec を単純にオーバーライドして、ExtensionTypeTracing Protocol を制御できます。詳細については、拡張型ガイドの ExtensionType の TypeSpec のカスタマイズセクションをご覧ください。

それ以外の場合は、Function が特定の Python 型に関していつ再トレースする必要があるかを直接制御するために、Tracing Protocol を自分で実装できます。

@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor

class Fruit:
  flavor = tf.constant([0, 0])

class Apple(Fruit):
  flavor = tf.constant([1, 2])

class Mango(Fruit):
  flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out out of scope by then and deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit):
    self.fruit_type = type(fruit)
    self.fruit_value = fruit

  def is_subtype_of(self, other):
      # True if self subtypes `other` and `other`'s type matches FruitTraceType.
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

  def most_specific_common_supertype(self, others):
      # `self` is the specific common supertype if all input types match it.
      return self if all(self == other for other in others) else None

  def placeholder_value(self, placeholder_context=None):
      # Use the fruit itself instead of the type for correct tracing.
      return self.fruit_value

  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type

  def __hash__(self):
    return hash(self.fruit_type)

class FruitWithTraceType:

  def __tf_tracing_type__(self, context):
    return FruitTraceType(self)

class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])

class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])

# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>

具象関数の取得

関数がトレースされるたびに新しい具象関数が作成されますが、get_concrete_function を使うことで、具象関数を直接取得できます。

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

ConcreteFunction を出力すると、入力引数(型付き)とその出力型の概要が表示されます。

print(double_strings)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

また、具象関数のシグネチャを直接取得することもできます。

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

互換性のない型で具象トレースを使用すると、エラーが発生します。

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Python 引数は、具象関数の入力シグネチャで特別に扱われていることに気づいたかもしれません。TensorFlow 2.3 より前では、Python 引数は単に具象関数のシグネチャから削除されていましたが、TensorFlow 2.3 からはシグネチャに残されたまま、トレーシング中に値セットを取るように制約されています。

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
  b (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
  TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
Captures:
  None
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
ValueError: Can not cast 3 to Literal[2]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1183, in _call_impl
    return self._call_with_flat_signature(args, kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1234, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).
Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.

グラフの取得

それぞれの具象関数は、tf.Graph を囲む呼び出し可能なラッパーです。通常、実際の tf.Graph オブジェクトを取得する必要はないにしろ、具象関数から簡単に取得することが可能です。

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

デバッグ

一般的に、コードのデバックは、tf.function 内で行うよりも、Eager モードで行う方が簡単です。Eager モードでは、tf.function でデコレートする前に、コードがエラーなく実行することを確認しておく必要があります。デバッグプロセスを支援する目的で、tf.config.run_functions_eagerly(True) を呼び出すと、tf.function をグローバルに無効にして、有効にし直すことができます。

tf.function 内でのみ出現する問題を追跡する場合、次のようなヒントがあります。

  • 従来のシンプルな Python print 呼び出しは、トレーシング中にのみ実行されるため、関数が(リ)トレーシングされるときに追跡しやすくなります。
  • tf.print 呼び出しは毎回実行するため、実行中の中間値の追跡に役立ちます。
  • tf.debugging.enable_check_numerics は、NaN と Inf がいつ作成されるかを簡単に追跡できます。
  • pdbPython デバッガ)は、トレーシング中に何が起きているのかを理解する上で役立ちます。(注意: pdb が示すのは、AutoGraph 変換ソースコードです。)

AutoGraph 変換

AutoGraph は、tf.function 内でデフォルトで利用できるようになっているライブラリで、Python の Eager コードのサブセットとグラフ対応の TensorFlow 演算に変換します。これには、ifforwhile などの制御フローが含まれます。

tf.condtf.while_loop などの TensorFlow 演算は機能し続けますが、制御フローは、Python で記述された場合の方が書きやすく理解しやすいことがほとんどです。

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.645081878 0.344026446 0.0305811167 0.918421388 0.0816819668]
[0.568349838 0.331067264 0.0305715837 0.725149751 0.0815007836]
[0.514146328 0.319479436 0.03056206 0.620089 0.0813208073]
[0.473169476 0.309036136 0.0305525456 0.551190078 0.0811420083]
[0.440756589 0.299559951 0.0305430386 0.501411617 0.0809643865]
[0.414271355 0.290909827 0.0305335429 0.463226587 0.0807879269]
[0.39209339 0.282972 0.0305240545 0.43271029 0.0806126148]
[0.373163491 0.275653511 0.0305145755 0.407583863 0.0804384425]
[0.356755644 0.268877536 0.030505104 0.386419207 0.0802654]
[0.342353106 0.262580037 0.0304956455 0.368269116 0.0800934657]
[0.329576522 0.256707132 0.0304861926 0.352476776 0.0799226314]
[0.318140209 0.251213 0.0304767471 0.338570237 0.0797528848]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.30782428, 0.24605857, 0.03046731, 0.32620034, 0.07958422],
      dtype=float32)>

興味があれば、AutoGraph が生成するコードを検査できます。

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

条件文

AutoGraph は if <condition> 文を相当する tf.cond 呼び出しに変換します。この置換は、<condition> がテンソルである場合に行われます。テンソルでない場合は、if 文は Python の条件文として実行されます。

Python 条件文はトレーシング中に実行するため、条件文のブランチが 1 つだけグラフに追加されます。AutoGraph を使用しない場合、データに依存する制御フローが存在すると、トレーシングされたこのグラフは別のブランチを取ることができません。

tf.cond は、条件文の両方のブランチをトレーシングし、実行時に動的に 1 つのブランチを選択してグラフに追加します。トレーシングには意図しない副作用がある場合があります。詳細は、AutoGraph のトレーシング効果をご覧ください。

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

AutoGraph 変換の if 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。

ループ

AutoGraph は、一部の for 文と while 文を相当する tf.while_loop などの TensorFlow のループ演算に変換します。変換されない場合、for または while ループは Python ループとして実行されます。

この置き換えは、次の場合に行われます。

  • for x in y: y がテンソルである場合、tf.while_loop に変換されます。ytf.data.Dataset である特別なケースでは、tf.data.Dataset 演算の組み合わせが生成されます。
  • while <condition>: <condition> がテンソルである場合、tf.while_loop に変換されます。

Python ループは、トレーシング中に実行され、ループのいてレーションごとに、tf.Graph に追加の演算が追加されます。

TensorFlow ループはループの本体をトレーシングし、実行時に実行する反復回数を動的に選択します。ループ本体は、生成された tf.Graph に一度だけ出現します。

AutoGraph 変換の for 文と while 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。

Python データのループ

一般的な落とし穴は、tf.function 内で Python/NumPy データをループする際にあります。このループは、トレーシングプロセス中に実行し、ループのイテレーションごとにモデルのコピーを tf.Graph に追加してしまいます。

トレーニングループ全体を tf.function にラップしたいのであれば、データを tf.data.Dataset としてラップし、AutoGraph にトレーニングループを動的に展開させるようにするのが最も安全な方法です。

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph

Python/NumPy データをデータセットにラップする際は、tf.data.Dataset.from_generatortf.data.Dataset.from_tensor_slices の違いに注意してください。前者は、データを Python に維持し、tf.py_function 経由で取得するため、パフォーマンスに問題がありますが、後者は、データのコピーをグラフ内の大型の tf.constant() ノードとしてバンドル化するため、メモリに問題が現れます。

データを消費するには、TFRecordDatasetCsvDataset などを介してファイルからデータを読み取るのが最も効果的な方法です。そうすれば、Python を使わずに、TensorFlow 自体でデータの非同期読み込みとプリフェッチを管理できるようになります。詳細は、「tf.data: TensorFlow 入力パイプラインを構築する」ガイドをご覧ください。

ループでの値の累積

ループの反復ごとに値を累積していくのは一般的なパターンです。通常は、Python のリストに追加したり、Python ディクショナリにエントリを追加したりして行われますが、これらは Python の副作用であるため、動的に展開されるループでは期待どおりに動作しません。動的に展開されるループの結果を累積する場合は、tf.TensorArray を使用してください。

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.84532976, 0.3812127 , 0.66792417, 0.5683036 ],
        [0.9275713 , 0.91194606, 1.1768652 , 1.0822074 ],
        [1.3573077 , 1.9015719 , 1.8221701 , 1.7729309 ]],

       [[0.02114761, 0.6834538 , 0.11220467, 0.7569938 ],
        [0.8682761 , 0.7105073 , 0.92193854, 1.433193  ],
        [1.173238  , 1.036192  , 1.3617182 , 2.1727595 ]]], dtype=float32)>

制限事項

TensorFlow の Function には、設計上、いくつかの制限事項があり、Python 関数を Function に変換する際には、注意が必要です。

Python の副作用の実行

Function 内での出力、リストへのアペンド、グローバル変数のミューテーションといった副作用は、2 回実行されたり、まったく実行しなかったりといったように、予測のつかない動作をすることがあります。また、入力セットで Function を初めて呼び出した場合にのみ実行し、以降では、Python コードを実行せずに、トレーシング済みの tf.Graph が再実行されてしまうこともあります。

基本的に、ロジックでは Python の副作用に依存しないようにし、トレースをデバッグするためだけに使用することをお勧めします。呼び出しごとに TensorFlow ランタイムが確実にコードを実行できるようにするには、tf.datatf.printtf.summarytf.Variable.assigntf.TensorArray などの TensorFlow API を使用するのが最善の方法です。

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Function の呼び出しごとに Python コードを実行する場合は、tf.py_function が脱出口です。tf.py_function には移植性がなく、特にパフォーマンスに優れているわけでもなく、SavedModel で保存できなければ、分散型(マルチ GPU、TPU)の環境でうまく動作するわけでもありません。また、tf.py_function はグラフに組み込む必要もあるため、すべての入力/出力をテンソルにキャストしてしまいます。

Python のグローバル変数と自由変数の変更

Python のグローバル変数と自由変数の変更は、Python の副作用としてみなされるため、トレーシング中にのみ発生します。

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

場合によっては気づきにくい予期しない動作が発生することがあります。以下の例では、counter は変数のインクリメントを保護することを目的としています。ただし、これは Python 整数であり、TensorFlow オブジェクトではないため、その値は最初のトレース中にキャプチャされます。tf.function を使用すると、assign_add が下のグラフに無条件に記録されます。したがって、v は、tf.function が呼び出されるたびに 1 ずつ増加します。この問題は、tf.function デコレータを使用して Grpah モードの Tensorflow コードを Tensorflow 2 に移行しようとする場合、Python の副作用 (例では counter ) を使用して、実行する演算を決定すると (例では、assign_add )によく発生します。通常、ユーザーは、疑わしい数値結果を確認したり、予想よりもパフォーマンスが大幅に低下した場合に、このことに気付きます(たとえば、保護された演算に非常にコストがかかる場合)。

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

このような動作を回避し、期待される動作を実現するためには、tf.init_scope を使用して演算を関数グラフの外に移動します。これにより、変数のインクリメントがトレース時間中に 1 回だけ実行されるようになります。init_scope には、制御フローのクリアや勾配テープなどの他の副作用があることに注意してください。init_scope を使用すると非常に複雑になり、現実的に管理できない場合があります。

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

まとめると、経験則として、Function の外側で機能する整数またはリストのようなコンテナなどの Python オブジェクトのミューテーションは避けてください。代わりに、引数と TF オブジェクトを使用しましょう。たとえば、「ループでの値の累積」セクションには、リストのような演算を実装する方法の一例が示されています。

一部のケースでは、tf.Variable である場合に状態をキャプチャして操作することができます。Keras モデルの重みは、このようにして、同じ ConcreteFunction への呼び出しの繰り返しで更新されています。

Python イテレータとジェネレータの使用

ジェネレータやイテレータなどの多くの Python 機能は、Python ランタイムに依存して状態を追跡しています。一般的に、これらのコンストラクトは Eager モードでも期待どおりに動作しますが、Python の副作用の例であるため、トレーシング中にしか発生しません。

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

TensorFlow にリストコントラクト用の特別な tf.TensorArray があるように、イテレーション用にも特別な tf.data.Iterator があります。概要は、AutoGraph 変換をご覧ください。また、tf.data API を使って、ジェネレータのパターンを実装できます。

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

tf.function のすべての出力は値を返す必要がある

tf.Variableを除いて、tf.function はすべての出力を返す必要があります。戻り値を使用せずに関数からテンソルに直接アクセスしようとすると、「リーク」が発生します。

たとえば、以下の関数は、Python グローバル x を介してテンソル a を「リーク」します。

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'SymbolicTensor' object has no attribute 'numpy'

リークされた値も返される場合でもリークします。

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 701, in start
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 523, in process_one
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    File "/tmpfs/tmp/ipykernel_178944/566849597.py", line 7, in <module>
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 832, in __call__
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 888, in _call
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    File "/tmpfs/tmp/ipykernel_178944/566849597.py", line 4, in leaky_function
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1478, in binary_op_wrapper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1871, in _add_dispatch
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def

The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140038351396288), which is out of scope.

通常、このようなリークは、Python ステートメントまたはデータ構造を使用するときに発生します。アクセスできないテンソルがリークするだけでなく、このようなステートメントは Python の副作用としてカウントされ、すべての関数呼び出しで実行されないことがあるため、間違っている可能性があります。

また、一般的に外部 Python コレクションまたはオブジェクトの変更によりローカルテンソルがリークすることもあります。

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

再帰的な tf.function はサポートされていない

再帰的な Function はサポートされていないので、無限ループを引き起こす可能性があります。以下に例を示します。

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_178944/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

再帰的な Function が正しく動作しているように見えても、Python 関数は複数回トレースされ、パフォーマンスに影響を与える可能性があります。以下に例を示します。

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

既知の問題

Function が正しく評価していない場合、以下の既知の問題が該当する可能性があります。これらの問題は、今後修正される予定です。

Python のグローバル変数と自由変数への依存

Function は、Python 引数の新しい値で呼び出された時に新しい ConcreteFunction を作成しますが、Python クロージャ、グローバル変数、またはその Function の非ローカル変数に対しては作成しません。Function への呼び出しごとに値が変化する場合でも、Function はトレーシングされたときの値をそのまま使用してしまいます。これは、通常の Python 関数の動作とは異なります。

このため、外側の名前を閉じる代わりに引数を使用する関数プログラミングの様式をお勧めします。

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

グローバル値を更新する別の方法として、それを tf.Variable にし、代わりに Variable.assign メソッドを使用することができます。

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

Python オブジェクトへの依存

カスタム Python オブジェクトを引数として tf.function に渡すことはサポートされていますが、ある制限が伴います。

特徴量を最大限にカバーするには、オブジェクトを tf.function に渡す前に拡張型に変換することを検討してください。Python プリミティブ型と tf.nest 対応構造も使用できます。

ただし、トレーシングのルールで説明されるように、カスタム TraceType がカスタム Python クラスによって提供されない場合、tf.function はインスタンスベースの等価性を使用するように強制されてしまいます。そのため、変更された属性と同じオブジェクトを渡しても、新しいトレースは作成されません

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

変更されたモデルインスタンスの評価に同じ Function を使用すると、元のモデルと同じインスタンスに基づく TraceTypeがあるために不具合が生じます。

そのため、ミュータブルオブジェクト属性に依存しない Function を書くか、そのような属性を Function に伝達するオブジェクトにトレーシングプロトコルを実装することをお勧めします。

この方法が困難な場合は、回避策として、オブジェクトを変更するたびに新しい Function がリトレーシングを行うようにする方法が挙げられます。

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

リトレーシングにはコストがかかるため、tf.Variable をオブジェクト属性として使用することができます。こうすることで、リトレーシングを行わずに、ミュートして(変更はしません)同様の効果を得ることができます。

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

tf.Variables の作成

Function は、最初の呼び出しで 1 回作成され、後続の関数呼び出しで再利用されるシングルトン tf.Variable のみをサポートします。以下のコードスニペットは、すべての関数呼び出しで新しい tf.Variable を作成します。これにより、ValueError 例外が発生します。

例:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_178944/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

この制限を回避するために使用される一般的なパターンは、Python None 値で開始し、値が None の場合は条件付きで tf.Variable を作成することです。

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

複数の Keras オプティマイザとの使用

2 つ以上の Keras オプティマイザを tf.function で使用しようとすると、「ValueError: tf.function only supports singleton tf.Variables created on the first call. 」というエラーが発生することがあります。このエラーは、オプティマイザが初めて勾配を適用する際に、内部的に tf.Variables を作成するために発生するものです。

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705001133.576006  179113 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_178944/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_178944/950644149.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_178944/950644149.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1223, in apply_gradients  **
        return super().apply_gradients(grads_and_vars, name=name)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 638, in apply_gradients
        self.build(trainable_variables)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 145, in build
        self.add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1125, in add_variable_from_reference
        return super().add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 513, in add_variable_from_reference
        variable = tf.Variable(

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

トレーニング中にオプティマイザを変更する必要がある場合は、回避策として、オプティマイザごとに新しい Function を作成し、ConcreteFunction を直接呼び出すようにすることができます。

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)

複数の Keras モデルとの使用

また、別のモデルインスタンスを同一の Function に渡す際に、「ValueError: tf.function only supports singleton tf.Variables created on the first call.」というエラーも発生することがあります。

このエラーは、Keras モデル(入力形状が定義されていない)と Keras レイヤーが、初めて呼び出されるときに tf.Variables を作成するために発生するものです。これらの変数をすでに呼び出された Function 内で初期化しようとしているのでしょう。このエラーを回避するには、モデルをトレーニングする前に、model.build(input_shape) を呼び出して、すべての重みを初期化するようにしてください。

参考資料

Function のエクスポートと読み込みの方法については、SavedModel ガイドをご覧ください。トレーシングの後に実行するグラフの最適化については、Grappler ガイドをご覧ください。データパイプラインの最適化方法とモデルのプロファイリングについては、Profiler ガイドをご覧ください。