在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 Github 上查看源代码 | 下载笔记本 |
安装
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
2022-12-14 20:14:08.129645: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
扩展程序类型
用户定义的类型可以使项目的可读性、模块化、可维护程度更高。但是,大多数 TensorFlow API 对于用户定义的 Python 类型的支持却非常有限。这包括高级 API(如 Keras、tf.function、tf.SavedModel)和低级 API(如 tf.while_loop
和 tf.concat
)。TensorFlow 扩展程序类型可用于创建能够与 TensorFlow 的 API 无缝协作的用户定义的面向对象类型。要创建扩展程序类型,只需定义一个以 tf.experimental.ExtensionType
为基础的 Python 类,并使用类型注解来指定每个字段的类型。
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
tf.experimental.ExtensionType
基类的工作方式类似于标准 Python 库中的 typing.NamedTuple
和 @dataclasses.dataclass
。特别是,它会根据字段类型注解自动添加构造函数和特殊方法(例如 __repr__
和 __eq__
)。
通常,扩展程序类型往往属于以下两个类别之一:
数据结构,会将一组相关的值组合在一起,并且可以基于这些值提供有用的运算。数据结构可以十分常规(例如上面的
TensorGraph
示例),也可以针对特定模型进行高度定制。类张量类型,限定或延伸了“张量”的概念。此类别中的类型具有
rank
、shape
,通常还有dtype
;并且将它们与张量运算(例如tf.stack
、tf.add
或tf.matmul
)一起使用是合理的。MaskedTensor
和CSRSparseMatrix
是类张量类型的示例。
支持的 API
以下 TensorFlow API 支持扩展程序类型:
- Keras:扩展程序类型可以用作 Keras
Models
和Layers
的输入和输出。 - tf.data.Dataset:扩展程序类型可以包含在
Datasets
中,并由数据集Iterators
返回。 - TensorFlow Hub:扩展程序类型可以用作
tf.hub
模块的输入和输出。 - SavedModel:扩展程序类型可以用作
SavedModel
函数的输入和输出。 - tf.function:扩展程序类型可以用作使用
@tf.function
装饰器包装的函数的参数和返回值。 - While 循环:扩展程序类型可以用作
tf.while_loop
中的循环变量,也可以用作 while 循环体的参数和返回值。 - 条件:可以使用
tf.cond
和tf.case
有条件地选择扩展程序类型。 tf.py_function
:扩展程序类型可以用作tf.py_function
的参数以及针对func
参数的返回值。- 张量运算:扩展程序类型可扩展以支持大多数接受张量输入的 TensorFlow 运算(例如,
tf.matmul
、tf.gather
和tf.reduce_sum
)。如需了解详情,请转到下面的调度部分。 - 分布策略:扩展程序类型可以用作按副本值。
有关详情,请参阅下面的“支持 ExtensionType 的 TensorFlow API”部分。
要求
字段类型
必须声明所有字段(实例变量),并且必须为每个字段提供类型注解。支持以下类型注解:
类型 | 示例 |
---|---|
Python 整数 | i: int |
Python 浮点数 | f: float |
Python 字符串 | s: str |
Python 布尔值 | b: bool |
Python None | n: None |
张量形状 | shape: tf.TensorShape |
张量数据类型 | dtype: tf.DType |
张量 | t: tf.Tensor |
扩展程序类型 | mt: MyMaskedTensor |
不规则张量 | rt: tf.RaggedTensor |
稀疏张量 | st: tf.SparseTensor |
索引切片 | s: tf.IndexedSlices |
可选张量 | o: tf.experimental.Optional |
类型联合 | int_or_float: typing.Union[int, float] |
元组 | params: typing.Tuple[int, float, tf.Tensor, int] |
可变长度元组 | lengths: typing.Tuple[int, ...] |
映射 | tags: typing.Mapping[str, tf.Tensor] |
可选值 | weight: typing.Optional[tf.Tensor] |
可变性
扩展程序类型必须是不可变的。这可以确保它们能够被 TensorFlow 的计算图跟踪机制正确跟踪。如果您发现自己想要改变扩展程序类型值,请考虑改为定义用于转换值的方法。例如,与其定义 set_mask
方法来改变 MaskedTensor
,您可以定义用于返回新的 MaskedTensor
的 set_mask
方法:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
ExtensionType
添加的功能
ExtensionType
基类提供了以下功能:
- 构造函数 (
__init__
)。 - 可打印表示方法 (
__repr__
)。 - 相等和不等运算符 (
__eq__
)。 - 验证方法 (
__validate__
)。 - 强制不变性。
- 嵌套
TypeSpec
。 - 张量 API 调度支持。
有关自定义此功能的更多信息,请转到下面的“自定义 ExtensionType
”部分。
构造函数
ExtensionType
添加的构造函数会将每个字段作为命名参数(按照它们在类定义中的排列顺序)。此构造函数将对每个形参进行类型检查,并在必要时对其进行转换。特别是,Tensor
字段会使用 tf.convert_to_tensor
进行转换;Tuple
字段会被转换为 tuple
;Mapping
字段会被转换为不可变字典。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
tf.Tensor( [[1 2 3] [4 5 6]], shape=(2, 3), dtype=int32)
如果字段值无法转换为其声明的类型,构造函数将引发 TypeError
:
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got 'NoneType'
可以通过在类级别设置字段的值来指定字段的默认值:
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)
可打印表示
ExtensionType
添加了一个默认的可打印表示方法 (__repr__
),其中包括类名和每个字段的值:
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, True, False])>)
相等运算符
ExtensionType
添加了默认相等运算符 (__eq__
和 __ne__
),如果两个值具有相同的类型并且其所有字段都相等,则认为二者相等。如果张量字段具有相同的形状并且对所有元素均符合逐元素相等,则认为张量字段相等。
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True a == b: False a == a.values: False
注:如果任何字段包含 Tensor
,则 __eq__
可能会返回标量布尔 Tensor
(而非 Python 布尔值)。
验证方法
ExtensionType
添加了一个 __validate__
方法,此方法可重写以对字段执行验证检查。它会在调用构造函数之后,以及在字段经过类型检查并转换为其声明的类型之后运行,因此它可以假定所有字段都具有其声明的类型。
以下示例会更新 MaskedTensor
以验证其字段的 shape
和 dtype
:
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # Wrong `dtype` for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible
强制不变性
ExtensionType
会重写 __setattr__
和 __delattr__
方法以防止变更,从而确保扩展程序类型值不可变。
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
嵌套 TypeSpec
每个 ExtensionType
类都有一个对应的 TypeSpec
类,它会自动创建并存储为 <extension_type_name>.Spec
。
此类会从值中捕获所有信息,除了任何嵌套张量的值。特别是,值的 TypeSpec
是通过将任何嵌套张量、ExtensionType 或 CompositeTensor 替换为其 TypeSpec
来创建的。
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
TensorSpec(shape=(), dtype=tf.string, name=None) ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})
TypeSpec
值可以显式构造,也可以使用 tf.type_spec_from_value
从 ExtensionType
值构造:
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TensorFlow 会使用 TypeSpec
将值划分为静态组件和动态组件:
- 静态组件(在计算图构建时固定不变)使用
tf.TypeSpec
进行编码。 - 动态组件(每次运行计算图时都会发生变化)被编码为
tf.Tensor
的列表。
例如,每当参数具有以前未见过的 TypeSpec
时,tf.function
都会回溯它的包装函数:
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))
有关详情,请参阅 tf.function 指南。
自定义 ExtensionType
除了简单地声明字段及其类型外,扩展程序类型还可以:
- 重写默认可打印表示 (
__repr__
)。 - 定义方法。
- 定义类方法和静态方法。
- 定义属性。
- 重写默认构造函数 (
__init__
)。 - 重写默认相等运算符 (
__eq__
)。 - 定义运算符(例如
__add__
和__lt__
)。 - 声明字段的默认值。
- 定义子类。
重写默认可打印表示
您可以为扩展程序类型重写此默认字符串转换运算符。以下示例会更新 MaskedTensor
类以在 Eager 模式下打印值时生成更具可读性的字符串表示。
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
定义方法
与任何常规 Python 类一样,扩展程序类型也可以定义方法。例如,MaskedTensor
类型可以定义 with_default
方法,该方法会返回一个 self
的副本,其中掩码值会被替换为给定的 default
值。可以选择使用 @tf.function
装饰器注解方法。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
定义类方法和静态方法
扩展程序类型可以使用 @classmethod
和 @staticmethod
装饰器定义方法。例如,MaskedTensor
类型可以定义能够使用给定值来遮盖任何元素的工厂方法:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values != value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[1, _, 2], [3, _, _]]>
定义属性
与任何常规 Python 类一样,扩展程序类型也可以使用 @property
装饰器定义属性。例如,MaskedTensor
类型可以定义 dtype
属性,它是值的数据类型的简写形式:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32
重写默认构造函数
您可以重写扩展程序类型的默认构造函数。自定义构造函数必须为每个声明的字段均设置一个值;并且在自定义构造函数返回后,所有字段都将进行类型检查,并将按上述方式转换值。
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
或者,您可以考虑保留默认构造函数,但添加一个或多个工厂方法。例如:
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
重写默认相等运算符 (__eq__
)
您可以重写扩展程序类型的默认 __eq__
运算符。以下示例会更新 MaskedTensor
以在比较相等性时忽略遮盖元素。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
注:您通常不需要重写 __ne__
,因为其默认实现只需调用 __eq__
并对结果求反。
使用前向引用
如果字段的类型尚未定义,您可以改用包含类型名称的字符串。在以下示例中,字符串 "Node"
用于注解 children
字段,因为 Node
类型尚未(完全)定义。
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
定义子类
扩展程序类型可以使用标准 Python 语法进行子类化。扩展程序类型子类可以添加新字段、方法和属性;并且可以重写构造函数、可打印表示和相等运算符。以下示例定义了一个基本的 TensorGraph
类,使用三个 Tensor
字段来编码节点之间的一组边。然后,它会定义一个子类,添加一个 Tensor
字段来记录每个节点的“特征值”。该子类还会定义一个沿着边传播特征值的方法。
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32) After propagating: tf.Tensor([10. 12. 4. 4. -1. 0.], shape=(6,), dtype=float32)
定义私有字段
扩展程序类型的字段可以通过在前面加上下划线来标记为私有(遵循标准 Python 惯例)。这不会影响 TensorFlow 处理字段的任何方式;但只为向扩展程序类型的任何用户表明这些字段为私有。
自定义 ExtensionType 的 TypeSpec
每个 ExtensionType
类都有一个对应的 TypeSpec
类,后者是自动创建的并被存储为 <extension_type_name>.Spec
。有关详情,请参阅上面的“嵌套 TypeSpec”部分。
要自定义 TypeSpec
,只需定义您自己的名为 Spec
的嵌套类,ExtensionType
将使用它作为自动构造的 TypeSpec
的基础。您可以通过以下方式自定义 Spec
类:
- 重写默认可打印表示。
- 重写默认构造函数。
- 定义方法、类方法、静态方法和属性。
以下示例自定义了 MaskedTensor.Spec
类以使其更加易于使用:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
注:自定义 Spec
类不能使用任何未在原始 ExtensionType
中声明的实例变量。
张量 API 调度
扩展程序类型可以是“类张量”,因为它们限定或延伸了 tf.Tensor
类型定义的接口。类张量扩展程序类型的示例包括 RaggedTensor
、SparseTensor
和 MaskedTensor
。当应用于类张量扩展程序类型时,调度装饰器可用于重写 TensorFlow 运算的默认行为。TensorFlow 目前定义了三个调度装饰器:
@tf.experimental.dispatch_for_api(tf_api)
@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
单个 API 的调度
在使用指定签名进行调用时,tf.experimental.dispatch_for_api
装饰器会重写指定 TensorFlow 运算的默认行为。例如,您可以使用此装饰器来指定 tf.stack
应如何处理 MaskedTensor
值:
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
每当使用 MaskedTensor
值的列表调用tf.stack
时,这都会重写它的默认实现(因为 values
参数使用 typing.List[MaskedTensor]
注解):
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>
要允许 tf.stack
处理混合的 MaskedTensor
和 Tensor
值的列表,您可以优化 values
形参的类型注解并适当地更新函数体:
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>
有关可重写 API 的列表,请参阅 tf.experimental.dispatch_for_api
的 API 文档。
所有一元逐元素 API 的调度
只要第一个参数(通常命名为 x
)的值与类型注解 x_type
相匹配,tf.experimental.dispatch_for_unary_elementwise_apis
装饰器就会重写所有一元逐元素运算(例如 tf.math.cos
)的默认行为。装饰函数应接受两个参数:
api_func
:接受单个形参并执行逐元素运算的函数(例如tf.abs
)。x
:逐元素运算的第一个参数。
以下示例会更新所有一元逐元素运算以处理 MaskedTensor
类型:
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
现在,只要在 MaskedTensor
上调用一元逐元素运算,就会使用此函数。
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>
所有二进制逐元素 API 的调度
同样,tf.experimental.dispatch_for_binary_elementwise_apis
可用于更新所有二进制逐元素运算以处理 MaskedTensor
类型:
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>
有关被重写的逐元素 API 的列表,请转到 tf.experimental.dispatch_for_unary_elementwise_apis
和 tf.experimental.dispatch_for_binary_elementwise_apis
的 API 文档。
可批处理 ExtensionType
如果单个实例可用于表示一批值,则 ExtensionType
为可批处理。通常,这可以通过向所有嵌套 Tensor
添加批量维度来实现。以下 TensorFlow API 要求任何扩展程序类型的输入都可批处理:
tf.data.Dataset
(batch
、unbatch
、from_tensor_slices
)tf.keras
(fit
、evaluate
、predict
)tf.map_fn
默认情况下,BatchableExtensionType
会通过批处理任何嵌套的 Tensor
、CompositeTensor
和 ExtensionType
来创建批处理值。如果这不适合您的类,那么您将需要使用 tf.experimental.ExtensionTypeBatchEncoder
来重写此默认行为。例如,通过简单地堆叠各个稀疏张量的 values
、indices
和 dense_shape
字段来创建一批 tf.SparseTensor
值是不合适的 – 在大多数情况下,您不能堆叠这些张量,因为它们具有不兼容的形状;即便可以,结果也不会是有效的 SparseTensor
。
注:BatchableExtensionType
不会自动为 tf.stack
、tf.concat
、tf.slice
等定义调度器。如果您的类需要这些 API 的支持,请使用上述调度装饰器。
BatchableExtensionType 示例:Network
例如,请思考用于负载均衡的简单 Network
类,用于跟踪每个节点还有多少剩余工作,以及有多少带宽可用于在节点之间移动工作:
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
要使此类型可批处理,请将基本类型更改为 BatchableExtensionType
,并调整每个字段的形状来包含可选的批次维度。以下示例还添加了一个 shape
字段来跟踪批次形状。tf.data.Dataset
或 tf.map_fn
不需要此 shape
字段,但 tf.keras
需要。
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]> batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
然后,您可以使用 tf.data.Dataset
迭代一批网络:
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
您还可以使用 map_fn
对每个批处理元素应用函数:
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
支持 ExtensionType 的 TensorFlow API
@tf.function
tf.function 是预计算 Python 函数 TensorFlow 计算图的装饰器,可以大幅改善 TensorFlow 代码的性能。扩展程序类型能够透明地与 @tf.function
装饰的函数一起使用。
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
如果您希望为 tf.function
明确指定 input_signature
,则可以使用扩展程序类型的 TypeSpec
执行此操作。
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)
具体函数
具体函数封装通过 tf.function
构建的各个跟踪计算图。扩展程序类型可以透明地与具体函数一起使用。
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
控制流运算
TensorFlow 的控制流运算支持扩展程序类型:
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>
Autograph 控制流
tf.function 中的控制流语句也支持扩展程序类型(使用 autograph)。在以下示例中,if
语句和 for
语句会自动转换为支持扩展程序类型的 tf.cond
和 tf.while_loop
运算。
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]> <MaskedTensor [4.5, _, 6.5]>
Keras
tf.keras 是 TensorFlow 用于构建和训练深度学习模型的高级 API。扩展程序类型可以作为输入传递给 Keras 模型,在 Keras 层之间传递,并由 Keras 模型返回。Keras 目前对扩展程序类型具有两项要求:
- 它们必须可批处理(请转到上面的“可批处理
ExtensionType
”)。 - 它们必须具有名为
shape
的字段或属性。假定shape[0]
为批次维度。
以下两个小节提供了展示如何将扩展程序类型与 Keras 一起使用的示例。
Keras 示例:Network
对于第一个示例,请思考上面“可批处理 ExtensionType”部分定义的 Network
类,它可以用于节点之间的负载均衡工作。这里再次给出它的定义:
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network with 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
您可以定义用于处理 Network
的新 Keras 层。
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above in the "Batchable `ExtensionType`s" section.
return balance_work_greedy(inputs)
然后,您可以使用这些层来创建一个简单的模型。要将 ExtensionType
馈送给模型,您可以使用 tf.keras.layer.Input
层并将 type_spec
设置为扩展程序类型的 TypeSpec
。如果 Keras 模型将用于处理批次,那么 type_spec
必须包含批次维度。
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
最后,您可以将模型应用于单个网络和一批网络。
model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>
Keras 示例:MaskedTensor
在此示例中,MaskedTensor
进行了扩展以支持 Keras
。shape
定义为从 values
字段计算的属性。Keras 要求您将此属性添加到扩展程序类型及其 TypeSpec
。MaskedTensor
还定义了 SavedModel
序列化所需的 __name__
变量(如下)。
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
接下来,调度装饰器会用于重写多个 TensorFlow API 的默认行为。由于这些 API 会由标准 Keras 层(例如 Dense
层)使用,对其进行重写,我们就能够将这些层与 MaskedTensor
一起使用。出于本示例的目的,我们定义了掩码张量的 matmul
以将掩码值视为零(即,不将它们包含在乘积中)。
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
然后,您可以使用标准 Keras 层构建一个接受 MaskedTensor
输入的 Keras 模型:
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3 1/1 [==============================] - 3s 3s/step - loss: 0.6819 Epoch 2/3 1/1 [==============================] - 0s 5ms/step - loss: 0.6239 Epoch 3/3 1/1 [==============================] - 0s 5ms/step - loss: 0.5903 tf.Tensor( [[ 0.18340722] [-0.08917451] [ 1.3972318 ]], shape=(3, 1), dtype=float32)
SavedModel
SavedModel 是序列化 TensorFlow 程序,包括权重和计算。它可以通过 Keras 模型或自定义模型构建。在任何一种情况下,扩展程序类型都可以透明地与 SavedModel 定义的函数和方法一起使用。
SavedModel 可以保存用于处理扩展程序类型的模型、层和函数,只要扩展程序类型具有 __name__
字段即可。此名称用于注册扩展程序类型,以便在加载模型时进行定位。
示例:保存 Keras 模型
可以使用 SavedModel
来保存使用扩展程序类型的 Keras 模型。
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkj9uvl1_/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkj9uvl1_/assets <tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[ 0.18340722], [-0.08917451], [ 1.3972318 ]], dtype=float32)>
示例:保存自定义模型
SavedModel 还可用于保存包含用于处理扩展程序类型的函数的自定义 tf.Module
子类。
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpq2zpskdx/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpq2zpskdx/assets <MaskedTensor [_, 200.0, _]>
在 ExtensionType 不可用时加载 SavedModel
如果您加载使用 ExtensionType
的 SavedModel
,但该 ExtensionType
不可用(即尚未导入),您将看到一条警告,并且 TensorFlow 将回退到使用“匿名扩展程序类型”对象。此对象将具有与原始类型相同的字段,但将缺少您为该类型添加的任何后续自定义内容,例如自定义方法或属性。
ExtensionType
与 TensorFlow Serving 一起使用
目前,TensorFlow Serving(以及 SavedModel“签名”字典的其他使用者)要求所有输入和输出都是原始张量。如果您希望将 TensorFlow Serving 与使用扩展程序类型的模型一起使用,可以添加用于组合或分解张量的扩展程序类型值的封装容器方法。 例如:
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbv2z853b/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbv2z853b/assets <tf.Tensor: shape=(), dtype=float32, numpy=12.0>
数据集
tf.data 是一个 API,可用于通过简单的可重用代码块构建复杂的输入流水线。它的核心数据结构是 tf.data.Dataset
,表示一系列元素,每个元素包含一个或多个分量。
使用扩展程序类型构建数据集
可以使用 Dataset.from_tensors
、Dataset.from_tensor_slices
或 Dataset.from_generator
从扩展程序类型值构建数据集:
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
<MaskedTensor [0, 1, 2, 3]> <MaskedTensor [4, 5, 6, 7]> <MaskedTensor [8, 9, 10, 11]> <MaskedTensor [12, 13, 14, 15]> <MaskedTensor [16, 17, 18, 19]>
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>
使用扩展程序类型批处理和取消批处理数据集
可以使用 Dataset.batch
和 Dataset.unbatch
对具有扩展程序类型的数据集进行批处理和取消批处理。
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]> <MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]> <MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>