拡張型

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

セットアップ

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
2022-12-14 20:32:52.787665: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay

拡張型

ユーザー定義型を使用すると、プロジェクトが読みやすくなり、モジュール化され、保守しやすくなります。ただし、ほとんどの TensorFlow API では、ユーザー定義の Python 型が限定的にしかサポートされていません。これには、高レベル API(Kerastf.functiontf.SavedModel など)と低レベル API(tf.while_looptf.concat など)の両方が含まれます。 TensorFlow 拡張型を使用して、TensorFlow の API とシームレスに連携するユーザー定義のオブジェクト指向型を作成できます。拡張型を作成するには、単純に tf.experimental.ExtensionType をベースとして Python クラスを定義し、型注釈を使用して各フィールドの型を指定します。

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

tf.experimental.ExtensionType 基底クラスは、標準の Python ライブラリの typing.NamedTuple および @dataclasses.dataclass と同じように機能します。特に、フィールド型の注釈に基づいて、コンストラクタと特別なメソッド(__repr____eq__ など)が自動的に追加されます。

通常、拡張型は次の 2 つのカテゴリのいずれかに分類される傾向があります。

  • データ構造。関連する値のコレクションをグループ化し、それらの値に基づいて役立つ演算を提供できます。データ構造は汎用性が高い場合(上記の TensorGraph の例など)、または特定のモデルに合わせて高度にカスタマイズされている場合があります。

  • テンソルのような型。「テンソル」の概念を特殊化または拡張します。このカテゴリの型には、rankshape、そして通常は dtype があります。 テンソル演算 (tf.stacktf.add、または tf.matmul など)でそれらを使用することは合理的です。MaskedTensorCSRSparseMatrix は、テンソルのような型の例です。

サポートされている API

拡張型は以下の TensorFlow API でサポートされています。

  • Keras: 拡張型は Keras ModelsLayers の入出力として使用できます。
  • tf.data.Dataset: 拡張型は、データセット Datasets に含むことができ、データセット Iterators で返すことができます。
  • TensorFlow Hub: 拡張型は tf.hub の入出力として使用できます。
  • SavedModel: 拡張型は SavedModel 関数の入出力として使用できます。
  • tf.function: 拡張型は、@tf.function デコレータでラップされた関数の引数および戻り値として使用できます。
  • While ループ: 拡張型は tf.while_loop でループ変数として使用でき、while ループの本体の引数および戻り値として使用できます。
  • 条件付き: tf.cond および tf.case を使用して、拡張型を条件付きで選択できます。
  • tf.py_function: 拡張型は引数として使用でき、tf.py_function への func 引数の値を返します。
  • テンソル演算: テンソルの入力(tf.matmultf.gather、および tf.reduce_sum など)を受け入れるほとんどの TensorFlow 演算をサポートするために拡張型を拡張できます。詳細については、以下の「ディスパッチ」セクションに移動してください。
  • 分散ストラテジー: 拡張型はレプリカごとの値として使用できます。

詳細については、以下の「ExtensionTypes をサポートする TensorFlow API」のセクションをご覧ください。

要件

フィールド型

すべてのフィールド(インスタンス変数)を宣言する必要があり、各フィールドに型注釈を指定する必要があります。次の型注釈がサポートされています。

Python 整数 i: int
Python フロート f: float
Python 文字列 s: str
Python ブール値 b: bool
Python None n: None
テンソル形状 shape: tf.TensorShape
テンソル dtype dtype: tf.DType
テンソル t: tf.Tensor
拡張型 mt: MyMaskedTensor
不規則なテンソル rt: tf.RaggedTensor
スパーステンソル st: tf.SparseTensor
インデックススライス s: tf.IndexedSlices
オプションのテンソル o: tf.experimental.Optional
型結合 int_or_float: typing.Union[int, float]
タプル params: typing.Tuple[int, float, tf.Tensor, int]
可変長タプル lengths: typing.Tuple[int, ...]
マッピング tags: typing.Mapping[str, tf.Tensor]
オプションの値 weight: typing.Optional[tf.Tensor]

可変性

拡張型は不変である必要があります。これにより、TensorFlow のグラフトレースメカニズムによって適切に追跡できるようになります。拡張型の値を変更する場合は、代わりに値を変換するメソッドを定義することを検討してください。たとえば、MaskedTensor を変更する set_mask メソッドを定義するのではなく、新しい MaskedTensor を返す replace_mask メソッドを定義できます。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

ExtensionType によって追加される機能

ExtensionType 基底クラスは、次の機能を提供します。

  • コンストラクタ(__init__)。
  • 出力可能な表現メソッド(__repr__)。
  • 等価演算子と不等価演算子(__eq__)。
  • 検証メソッド(__validate__)。
  • 不変性の強制。
  • ネストされた TypeSpec
  • テンソル API ディスパッチのサポート。

この機能のカスタマイズの詳細については、以下の「 ExtensionType のカスタマイズ」セクションに移動してください。

コンストラクタ

ExtensionType によって追加されたコンストラクタは、各フィールドを名前付き引数として(クラス定義にリストされている順序で)受け取ります。このコンストラクタは、各パラメーターを型チェックし、必要に応じて変換します。特に、Tensor フィールドは tf.convert_to_tensor を使用して変換されます。 Tuple フィールドは tuple に変換されます。 Mapping フィールドは不変の dict に変換されます。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

フィールド値を宣言された型に変換できない場合、コンストラクタは TypeError を発生させます。

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got 'NoneType'

フィールドのデフォルト値は、クラスレベルで値を設定することによって指定できます。

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

出力可能な表現

ExtensionType は、クラス名と各フィールドの値を含むデフォルトの出力可能な表現メソッド(__repr__)を追加します。

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

等値演算子

ExtensionType は、2 つの値が同じ型を持ち、すべてのフィールドが等しい場合に等しいと見なすデフォルトの等価演算子(__eq__ および __ne__)を追加します。テンソルフィールドは、同じ形状を持ち、すべての要素に対して要素ごとに等しい場合、等しいと見なされます。

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

注意: いずれかのフィールドに Tensor が含まれている場合、__eq__ は(Python ブール値ではなく)スカラーブール値 Tensor を返す場合があります。

検証メソッド

ExtensionType は、フィールドの検証チェックを実行するためにオーバーライドできる __validate__ メソッドを追加します。コンストラクタが呼び出された後、フィールドが型チェックされ、宣言された型に変換された後に実行されるため、すべてのフィールドの型は宣言された型であると想定できます。

次の例では、MaskedTensor を更新して、そのフィールドの shapedtype を検証します。

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

不変性の強制

ExtensionType__setattr____delattr__ メソッドをオーバーライドして突然変異を防ぎ、拡張型の値が不変であることを保証します。

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

ネストされた TypeSpec

ExtensionType クラスには対応する TypeSpec クラスがあり、これは自動的に作成され、<extension_type_name>.Spec として保存されます。

このクラスは、ネストされたテンソルの値以外の値からすべての情報を取得します。特に、値の TypeSpec は、ネストされたテンソル、ExtensionType、または CompositeTensor をその TypeSpec に置き換えることによって作成されます。

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

TypeSpec 値は明示的に構築することも、 tf.type_spec_from_value を使用して ExtensionType 値から構築することもできます。

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec は、値を静的コンポーネント動的コンポーネントに分割するために TensorFlow によって使用されます。

  • 静的コンポーネント(グラフ構築時に固定される)は tf.TypeSpec でエンコードされます。
  • 動的コンポーネント(グラフが実行されるたびに変化する可能性があります)は、tf.Tensor のリストとしてエンコードされます。

たとえば、tf.function は、引数に以前は見られなかった TypeSpec があるときはいつでも、そのラップされた関数を再トレースします。

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

詳細については、tf.function ガイドをご覧ください。

ExtensionType のカスタマイズ

単純にフィールドとその型を宣言するだけでなく、拡張型は次のことができます。

  • デフォルトの出力可能な表現(__repr__)をオーバーライドします。
  • メソッドを定義します。
  • classmethodstaticmethod を定義します。
  • プロパティを定義します。
  • デフォルトのコンストラクタ(__init__)をオーバーライドします。
  • デフォルトの等価演算子(__eq__)をオーバーライドします。
  • 演算子を定義します(__add____lt__など)。
  • フィールドのデフォルト値を宣言します。
  • サブクラスを定義します。

デフォルトの印刷可能な表現のオーバーライド

拡張型のこのデフォルトの文字列変換演算子をオーバーライドできます。次の例では、MaskedTensor クラスを更新して、値が Eager モードで出力されるときに、より読みやすい文字列表現を生成します。

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>

メソッドの定義

拡張型は、通常の Python クラスと同様に、メソッドを定義できます。たとえば、MaskedTensor 型は、指定された default 値に置き換えられたマスクされた値を持つ self のコピーを返す with_default メソッドを定義できます。メソッドには、オプションで @tf.function デコレータで注釈を付けることができます。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>

classmethodstaticmethod の定義

拡張型は、@classmethod および @staticmethod デコレータを使用してメソッドを定義できます。たとえば、MaskedTensor 型は、任意の要素を特定の値でマスクするファクトリメソッドを定義できます。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[1, _, 2], [3, _, _]]>

プロパティの定義

拡張型は、通常の Python クラスと同様に、@property デコレータを使用してプロパティを定義できます。たとえば、MaskedTensor 型は、値の dtype の短縮形である dtype プロパティを定義できます。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

デフォルトのコンストラクタのオーバーライド

拡張型の既定のコンストラクタをオーバーライドできます。カスタムコンストラクタは、宣言されたフィールドごとに値を設定する必要があります。カスタムコンストラクタが戻った後、すべてのフィールドが型チェックされ、値が上記のように変換されます。

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

または、デフォルトのコンストラクタをそのままにして、1 つ以上のファクトリメソッドを追加することも検討できます。例えば、次のとおりです。

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

デフォルトの等価演算子(__eq__)のオーバーライド

拡張型のデフォルトの __eq__ 演算子をオーバーライドできます。次の例では、等しいかどうかを比較するときにマスクされた要素を無視するように MaskedTensor を更新します。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)

注意: 通常、__ne__ をオーバーライドする必要はありません。デフォルトの実装では単に __eq__ を呼び出して結果を否定するだけだからです。

前方参照の使用

フィールドの型がまだ定義されていない場合は、代わりに型の名前を含む文字列を使用できます。次の例では、Node 型がまだ(完全に)定義されていないため、文字列 "Node" を使用して children フィールドに注釈を付けています。

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))

サブクラスの定義

拡張型は、標準の Python 構文を使用してサブクラス化できます。拡張型のサブクラスは、新しいフィールド、メソッド、およびプロパティを追加できます。コンストラクタ、出力可能な表現、および等値演算子をオーバーライドする場合があります。次の例では、3 つの Tensor フィールドを使用してノード間の一連のエッジをエンコードする基本的な TensorGraph クラスを定義します。次に、Tensor フィールドを追加して各ノードの「特徴量値」を記録するサブクラスを定義します。サブクラスは、特徴量値をエッジに沿って伝播するメソッドも定義します。

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)

プライベートフィールドの定義

拡張型のフィールドは、アンダースコアを(標準の Python 規則に従って)プレフィックスとして付けることにより、非公開としてマークすることができます。これは、TensorFlow がフィールドを処理する方法にはまったく影響しません。これらのフィールドがプライベートであることを拡張型のユーザーに通知するだけです。

ExtensionTypeTypeSpec のカスタマイズ

ExtensionType クラスには対応する TypeSpec クラスがあり、これは自動的に作成され、<extension_type_name>.Spec として保存されます。詳細については、上記の「ネストされた TypeSpec」セクションをご覧ください。

TypeSpec をカスタマイズするには、Spec という名前の独自のネストされたクラスを定義するだけで、ExtensionType は自動的に構築された TypeSpec の基礎としてそれを使用します。次の方法で Spec クラスをカスタマイズできます。

  • デフォルトの出力可能な表現のオーバーライド。
  • デフォルトのコンストラクタのオーバーライド。
  • メソッド、classmethodstaticmethod、およびプロパティの定義。

次の例では、MaskedTensor.Spec クラスをカスタマイズして使いやすくしています。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

注意: カスタム Spec クラスは、元の ExtensionType で宣言されなかったインスタンス変数を使用することはできません。

テンソル API ディスパッチ

拡張型は、tf.Tensor 型によって定義されたインターフェースを特殊化または拡張するという意味で、「テンソルのような」ものにすることができます。テンソルのような拡張型の例には、RaggedTensorSparseTensor、および MaskedTensor が含まれます。ディスパッチデコレータは、テンソルのような拡張型に適用された場合に、TensorFlow 演算のデフォルトの動作をオーバーライドするために使用できます。 TensorFlow は現在、3 つのディスパッチデコレータを定義しています。

単一の API のディスパッチ

tf.experimental.dispatch_for_api デコレータは、指定されたシグネチャで呼び出されると、指定された TensorFlow 演算のデフォルトの動作をオーバーライドします。たとえば、このデコレータを使用して、tf.stackMaskedTensor値を処理する方法を指定できます。

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

これは、MaskedTensor 値のリストで呼び出されるたびに、tf.stack のデフォルトの実装をオーバーライドします(values 引数には、typing.List[MaskedTensor] で注釈が付けられているためです)。

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

tf.stack が混在した MaskedTensor 値と Tensor 値のリストを処理できるようにするには、values パラメータの型注釈を設定し直し、関数の本体を適切に更新します。

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

オーバーライドできる API のリストについては、tf.experimental.dispatch_for_api の API ドキュメントをご覧ください。

すべての単項要素ごとの API のディスパッチ

tf.experimental.dispatch_for_unary_elementwise_apis デコレータは、最初の引数(通常は x という名前)の値が型注釈 x_type と一致する場合はいつでも、すべての単項要素ごとの演算(tf.math.cos など)のデフォルトの動作をオーバーライドします。装飾された関数は、次の 2 つの引数を取る必要があります。

  • api_func: 単一のパラメータを取り、要素ごとの演算を実行する関数(たとえば、tf.abs)。
  • x: 要素ごとの演算の最初の引数。

次の例では、MaskedTensor 型を処理するためにすべての単項要素ごとの演算を更新します。

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

MaskedTensor で単項要素ごとの演算が呼び出されるたびに、この関数が使用されるようになりました。

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

バイナリのすべての要素ごとの API のディスパッチ

同様に、tf.experimental.dispatch_for_binary_elementwise_apis を使用して、MaskedTensor 型を処理するためにすべてのバイナリ要素ごとの演算を更新できます。

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

オーバーライドされる要素ごとの API のリストについては、tf.experimental.dispatch_for_unary_elementwise_apis および tf.experimental.dispatch_for_binary_elementwise_apis の API ドキュメントをご覧ください。

バッチ処理可能な ExtensionType

1 つのインスタンスを使用して値のバッチを表すことができる場合、ExtensionTypeバッチ可能です。通常、これはネストされたすべての Tensor にバッチディメンションを追加することによって実現されます。次の TensorFlow API では、拡張型の入力がバッチ可能である必要があります。

デフォルトでは、BatchableExtensionType は、ネストされた TensorCompositeTensor、およびExtensionType をバッチ処理することにより、バッチ処理された値を作成します。これがクラスに適していない場合は、tf.experimental.ExtensionTypeBatchEncoder を使用してこのデフォルトの動作をオーバーライドする必要があります。例えば、個々のスパーステンソルの valuesindices、および dense_shape を単純にスタックしてtf.SparseTensor 値のバッチを作成することは適切ではありません。ほとんどの場合、これらのテンソルの形状には互換性がないため、スタックできません。たとえできたとしても、結果は有効な SparseTensor にはなりません。

注意: BatchableExtensionTypeは、tf.stacktf.concattf.slice などのディスパッチャを自動的に定義しません。クラスをこれらの API でサポートする必要がある場合は、上記のディスパッチデコレータを使用してください。

BatchableExtensionType の例: Network

例として、負荷分散に使用される単純な Network クラスを考えてみましょう。これは、各ノードで実行するために残っている作業の量と、ノード間で作業を移動するために使用できる帯域幅を追跡します。

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

この型をバッチ処理可能にするには、ベースタイプを BatchableExtensionType に変更し、各フィールドの形状を調整してオプションのバッチの次元を含めます。次の例では、バッチ形状を追跡するための shape フィールドも追加します。この shape フィールドは tf.data.Dataset または tf.map_fn では必要ありませんが、tf.keras では必要です

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

その後、tf.data.Dataset を使用して、ネットワークのバッチを反復処理できます。

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

また、map_fn を使用して、関数を各バッチ要素に適用することもできます。

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

ExtensionType をサポートする TensorFlow API

@tf.function

tf.function は、Python 関数の TensorFlow グラフを事前計算するデコレータで、TensorFlow コードのパフォーマンスを大幅に改善できます。拡張型の値は、@tf.function でデコレートされた関数で透過的に使用できます。

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

tf.functioninput_signature を明示的に指定する場合は、拡張型の TypeSpec を使用して指定できます。

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

具象関数

具象関数は、tf.function で構築された個別のトレース済みグラフをカプセル化します。拡張型は、具象関数で透過的に使用できます。

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

制御フロー演算

拡張型は以下の TensorFlow の制御フロー演算でサポートされています。

# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

Autograph 制御フロー

拡張型は、tf.function の制御フローステートメントでもサポートされます(autograph を使用)。次の例では、if ステートメントと for ステートメントは、拡張型をサポートする tf.cond および tf.while_loop 演算に自動的に変換されます。

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras は、ディープラーニングモデルを構築およびトレーニングするための TensorFlow の高レベル API です。拡張型は、入力として Keras モデルに渡され、Keras レイヤー間で渡され、Keras モデルによって返されます。 Keras は現在、拡張型に対して 2 つの要件を課しています。

  • バッチ可能でなければならない(上記の「バッチ処理可能な ExtensionType」に移動してください)。
  • shape という名前のフィールドまたはプロパティが必要である。shape[0] はバッチの次元と見なされます。

次の 2 つのサブセクションでは、Keras で拡張型を使用する方法についての例を示します。

Keras の例: Network

最初の例として、上記の「バッチ処理可能な ExtensionType」セクションで定義された Network クラスを考えてみましょう。これは、ノード間の作業の負荷分散に使用できます。ここではその定義が繰り返されています。

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network with 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

Network を処理する新しい Keras レイヤーを定義できます。

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above in the "Batchable `ExtensionType`s" section.
    return balance_work_greedy(inputs)

次に、これらのレイヤーを使用して単純なモデルを作成できます。ExtensionType をモデルにフィードするには、type_spec が拡張型の TypeSpec に設定された tf.keras.layer.Input レイヤーを使用できます。Keras モデルをバッチで使用する場合、type_spec をバッチの次元に含める必要があります。

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

最後に、モデルを単一のネットワークとネットワークのバッチに適用できます。

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

Keras の例: MaskedTensor

この例では、MaskedTensorKeras をサポートするように拡張されています。shape は、values フィールドから計算されるプロパティとして定義されます。Keras では、このプロパティを拡張型とその TypeSpec の両方に追加する必要があります。MaskedTensor__name__ 変数も定義します。これは、SavedModel のシリアル化(下記)に必要です。

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

次に、ディスパッチデコレータを使用して、いくつかの TensorFlow API のデフォルトの動作をオーバーライドします。これらの API は標準の Keras レイヤー(Dense レイヤーなど)で使用されるため、これらをオーバーライドすると、これらのレイヤーを MaskedTensor で使用できるようになります。この例では、マスクされたテンソルの matmul は、マスクされた値をゼロとして扱う(つまり、それらを積に含めない)ように定義されています。

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

次に、標準の Keras レイヤーを使用して、MaskedTensor 入力を受け入れる Kerasモデルを構築できます。

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 1s/step - loss: 0.7110
Epoch 2/3
1/1 [==============================] - 0s 6ms/step - loss: 0.6215
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 0.5670
tf.Tensor(
[[ 0.20307031]
 [-0.32614586]
 [ 1.0157624 ]], shape=(3, 1), dtype=float32)

SavedModel

SavedModel は、シリアル化された TensorFlow プログラムで、重みと計算の両方が含まれます。Keras モデルまたはカスタムモデルから構築できます。いずれの場合でも、拡張型は SavedModel によって定義された関数とメソッドで透過的に使用できます。

SavedModel は、拡張型に __name__ フィールドがある限り、拡張型を処理するモデル、レイヤー、および関数を保存できます。この名前は拡張型を登録するために使用されるため、モデルを読み込む際に見つけることができます。

例: Keras モデルを保存する

拡張型を使用する Keras モデルは、SavedModel を使用して保存できます。

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3_ax5e_0/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3_ax5e_0/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[ 0.20307031],
       [-0.32614586],
       [ 1.0157624 ]], dtype=float32)>

例: カスタムモデルを保存する

SavedModel は、拡張型を処理する関数を持つカスタム tf.Module サブクラスを保存するためにも使用できます。

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp9apon9h2/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp9apon9h2/assets
<MaskedTensor [_, 200.0, _]>

ExtensionType が利用できない場合に SavedModel を読み込む

ExtensionType を使用する SavedModel を読み込んだけれども、その ExtensionType が利用できない(つまり、インポートされていない)場合、警告が表示され、TensorFlow は「匿名拡張型」オブジェクトの使用にフォールバックします。このオブジェクトには元の型と同じフィールドがありますが、カスタムメソッドやプロパティなど、型に追加したカスタマイズはありません。

TensorFlow Serving で ExtensionType を使用する

現在、TensorFlow Serving(および SavedModel の「シグネチャ」ディクショナリの他のコンシューマー)は、すべての入力と出力が生のテンソルである必要があります。拡張型を使用するモデルで TensorFlow Serving を使用する場合は、テンソルから拡張型の値を構成または分解するラッパーメソッドを追加できます。例えば、次のとおりです。

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpts_yhd88/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpts_yhd88/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

Dataset

tf.data は、単純で再利用可能なピースから複雑な入力パイプラインを構築できる API です。そのコアデータ構造は tf.data.Dataset で、一連の要素を表し、その各要素には 1 つ以上のコンポーネントが含まれます。

拡張型を使用した Dataset の構築

Dataset.from_tensorsDataset.from_tensor_slices、または Dataset.from_generator を使用して、拡張型の値からデータセットを構築できます。

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

拡張型を使用した Dataset のバッチ処理とバッチ処理解除

拡張型を持つデータセットは、Dataset.batch および Dataset.unbatch を使用してバッチおよびバッチ解除できます。

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>