Ver en TensorFlow.org | Ejecutar en Google Colab | Ver fuente en GitHub | Descargar libreta |
Configuración
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
Tipos de extensiones
Los tipos definidos por el usuario pueden hacer que los proyectos sean más legibles, modulares y mantenibles. Sin embargo, la mayoría de las API de TensorFlow tienen una compatibilidad muy limitada con los tipos de Python definidos por el usuario. Esto incluye API de alto nivel (como Keras , tf.function , tf.SavedModel ) y API de nivel inferior (como tf.while_loop
y tf.concat
). Los tipos de extensión de TensorFlow se pueden usar para crear tipos orientados a objetos definidos por el usuario que funcionan sin problemas con las API de TensorFlow. Para crear un tipo de extensión, simplemente defina una clase de Python con tf.experimental.ExtensionType
como base y use anotaciones de tipo para especificar el tipo de cada campo.
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
La clase base tf.experimental.ExtensionType
funciona de manera similar a typing.NamedTuple
y @dataclasses.dataclass
de la biblioteca estándar de Python. En particular, agrega automáticamente un constructor y métodos especiales (como __repr__
y __eq__
) en función de las anotaciones de tipo de campo.
Por lo general, los tipos de extensión tienden a caer en una de dos categorías:
Estructuras de datos , que agrupan una colección de valores relacionados y pueden proporcionar operaciones útiles basadas en esos valores. Las estructuras de datos pueden ser bastante generales (como el ejemplo anterior de
TensorGraph
); o pueden personalizarse mucho para un modelo específico.Tipos similares a tensores , que especializan o amplían el concepto de "tensor". Los tipos de esta categoría tienen un
rank
, unashape
y, por lo general, undtype
; y tiene sentido usarlos con operaciones Tensor (comotf.stack
,tf.add
otf.matmul
).MaskedTensor
yCSRSparseMatrix
son ejemplos de tipos similares a tensores.
API compatibles
Los tipos de extensión son compatibles con las siguientes API de TensorFlow:
- Keras : los tipos de extensión se pueden usar como entradas y salidas para los
Models
yLayers
de Keras. - tf.data.Dataset : los tipos de extensión se pueden incluir en
Datasets
de datos y devolverlos losIterators
de conjuntos de datos. - Tensorflow hub : los tipos de extensión se pueden usar como entradas y salidas para los módulos
tf.hub
. - Modelo guardado : los tipos de extensión se pueden usar como entradas y salidas para funciones de
SavedModel
. - tf.function : los tipos de extensión se pueden usar como argumentos y devolver valores para funciones envueltas con el decorador
@tf.function
. - while loops : los tipos de extensión se pueden usar como variables de bucle en
tf.while_loop
, y se pueden usar como argumentos y valores de retorno para el cuerpo del while-loop. - condicionales : los tipos de extensión se pueden seleccionar condicionalmente usando
tf.cond
ytf.case
. - py_function : los tipos de extensión se pueden usar como argumentos y devolver valores para el argumento
func
atf.py_function
. - Tensor ops : los tipos de extensión se pueden ampliar para admitir la mayoría de las operaciones de TensorFlow que aceptan entradas de Tensor (p. ej.,
tf.matmul
,tf.gather
ytf.reduce_sum
). Consulte la sección " Despacho " a continuación para obtener más información. - estrategia de distribución : los tipos de extensión se pueden usar como valores por réplica.
Para obtener más detalles, consulte la sección sobre "API de TensorFlow que admiten ExtensionTypes" a continuación.
Requisitos
tipos de campo
Se deben declarar todos los campos (también conocidos como variables de instancia) y se debe proporcionar una anotación de tipo para cada campo. Se admiten las siguientes anotaciones de tipo:
Escribe | Ejemplo |
---|---|
enteros de Python | i: int |
pitón flota | f: float |
cadenas de pitón | s: str |
Booleanos de Python | b: bool |
Python Ninguno | n: None |
Formas de tensor | shape: tf.TensorShape |
Tipos de tensor | dtype: tf.DType |
tensores | t: tf.Tensor |
Tipos de extensiones | mt: MyMaskedTensor |
Tensores irregulares | rt: tf.RaggedTensor |
Tensores dispersos | st: tf.SparseTensor |
Sectores indexados | s: tf.IndexedSlices |
Tensores opcionales | o: tf.experimental.Optional |
Tipo de uniones | int_or_float: typing.Union[int, float] |
tuplas | params: typing.Tuple[int, float, tf.Tensor, int] |
tuplas de longitud Var | lengths: typing.Tuple[int, ...] |
Asignaciones | tags: typing.Mapping[str, tf.Tensor] |
Valores opcionales | weight: typing.Optional[tf.Tensor] |
Mutabilidad
Los tipos de extensión deben ser inmutables. Esto garantiza que los mecanismos de seguimiento de gráficos de TensorFlow puedan rastrearlos correctamente. Si desea mutar un valor de tipo de extensión, considere definir métodos que transformen valores. Por ejemplo, en lugar de definir un método set_mask
para mutar un MaskedTensor
, podría definir un método replace_mask
que devuelva un nuevo MaskedTensor
:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
Funcionalidad añadida por ExtensionType
La clase base ExtensionType
proporciona la siguiente funcionalidad:
- Un constructor (
__init__
). - Un método de representación imprimible (
__repr__
). - Operadores de igualdad y desigualdad (
__eq__
). - Un método de validación (
__validate__
). - Inmutabilidad forzosa.
- Un
TypeSpec
anidado. - Soporte de envío de API de tensor.
Consulte la sección "Personalizar tipos de extensión" a continuación para obtener más información sobre cómo personalizar esta funcionalidad.
Constructor
El constructor agregado por ExtensionType
toma cada campo como un argumento con nombre (en el orden en que se enumeraron en la definición de clase). Este constructor verificará el tipo de cada parámetro y los convertirá cuando sea necesario. En particular, los campos Tensor
se convierten usando tf.convert_to_tensor
; Los campos de Tuple
se convierten en tuple
s; y los campos de Mapping
se convierten en dictados inmutables.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor( [[1 2 3] [4 5 6]], shape=(2, 3), dtype=int32)
El constructor genera un TypeError
si un valor de campo no se puede convertir a su tipo declarado:
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None
El valor predeterminado para un campo se puede especificar estableciendo su valor en el nivel de clase:
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)
Representación imprimible
ExtensionType
agrega un método de representación imprimible predeterminado ( __repr__
) que incluye el nombre de la clase y el valor de cada campo:
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, True, False])>)
Operadores de igualdad
ExtensionType
agrega operadores de igualdad predeterminados ( __eq__
y __ne__
) que consideran dos valores iguales si tienen el mismo tipo y todos sus campos son iguales. Los campos tensoriales se consideran iguales si tienen la misma forma y son iguales para todos los elementos.
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True a == b: False a == a.values: False
Método de validación
ExtensionType
agrega un método __validate__
, que se puede anular para realizar comprobaciones de validación en los campos. Se ejecuta después de llamar al constructor y después de que los campos se hayan verificado y convertido a sus tipos declarados, por lo que puede asumir que todos los campos tienen sus tipos declarados.
El siguiente ejemplo actualiza MaskedTensor
para validar las shape
y dtype
de sus campos:
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # wrong dtype for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible
inmutabilidad forzada
ExtensionType
anula los métodos __setattr__
y __delattr__
para evitar la mutación, lo que garantiza que los valores de tipo de extensión sean inmutables.
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
Especificación de tipo anidado
Cada clase ExtensionType
tiene una clase TypeSpec
correspondiente, que se crea automáticamente y se almacena como <extension_type_name>.Spec
.
Esta clase captura toda la información de un valor excepto los valores de cualquier tensor anidado. En particular, el TypeSpec
para un valor se crea reemplazando cualquier Tensor, ExtensionType o CompositeTensor anidado con su TypeSpec
.
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records dtype and shape, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> TensorSpec(shape=(), dtype=tf.string, name=None) ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})
Los valores TypeSpec
se pueden construir explícitamente o se pueden construir a partir de un valor ExtensionType
usando tf.type_spec_from_value
:
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TensorFlow usa TypeSpec
para dividir los valores en un componente estático y un componente dinámico :
- El componente estático (que se fija en el momento de la construcción del gráfico) se codifica con
tf.TypeSpec
. - El componente dinámico (que puede variar cada vez que se ejecuta el gráfico) se codifica como una lista de
tf.Tensor
s.
Por ejemplo, tf.function
vuelve sobre su función envuelta cada vez que un argumento tiene un TypeSpec
nunca antes visto:
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> <<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))
Para obtener más información, consulte la Guía de funciones tf .
Personalización de tipos de extensión
Además de simplemente declarar campos y sus tipos, los tipos de extensión pueden:
- Anule la representación imprimible predeterminada (
__repr__
). - Definir métodos.
- Defina métodos de clase y métodos estáticos.
- Definir propiedades.
- Anule el constructor predeterminado (
__init__
). - Anule el operador de igualdad predeterminado (
__eq__
). - Defina operadores (como
__add__
y__lt__
). - Declarar valores predeterminados para los campos.
- Definir subclases.
Anular la representación imprimible predeterminada
Puede anular este operador de conversión de cadena predeterminado para los tipos de extensión. El siguiente ejemplo actualiza la clase MaskedTensor
para generar una representación de cadena más legible cuando los valores se imprimen en modo Eager.
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
Definición de métodos
Los tipos de extensión pueden definir métodos, como cualquier clase normal de Python. Por ejemplo, el tipo MaskedTensor
podría definir un método with_default
que devuelve una copia de self
con valores enmascarados reemplazados por un valor default
determinado. Los métodos se pueden anotar opcionalmente con el decorador @tf.function
.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
Definición de métodos de clase y métodos estáticos
Los tipos de extensión pueden definir métodos utilizando los decoradores @classmethod
y @staticmethod
. Por ejemplo, el tipo MaskedTensor
podría definir un método de fábrica que enmascare cualquier elemento con un valor dado:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values == value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>
Definición de propiedades
Los tipos de extensión pueden definir propiedades utilizando el decorador @property
, como cualquier clase normal de Python. Por ejemplo, el tipo MaskedTensor
podría definir una propiedad dtype
que es una abreviatura del dtype de los valores:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32
Anulando el constructor predeterminado
Puede anular el constructor predeterminado para los tipos de extensión. Los constructores personalizados deben establecer un valor para cada campo declarado; y después de que regrese el constructor personalizado, se verificará el tipo de todos los campos y los valores se convertirán como se describe anteriormente.
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
Alternativamente, puede considerar dejar el constructor predeterminado como está, pero agregar uno o más métodos de fábrica. P.ej:
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
Anulando el operador de igualdad predeterminado ( __eq__
)
Puede anular el operador __eq__
predeterminado para los tipos de extensión. El siguiente ejemplo actualiza MaskedTensor
para ignorar los elementos enmascarados al comparar la igualdad.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
Uso de referencias directas
Si el tipo de un campo aún no se ha definido, puede usar una cadena que contenga el nombre del tipo en su lugar. En el siguiente ejemplo, la cadena "Node"
se usa para anotar el campo children
porque el tipo de Node
aún no se ha definido (totalmente).
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
Definición de subclases
Los tipos de extensión se pueden subclasificar utilizando la sintaxis estándar de Python. Las subclases de tipo de extensión pueden agregar nuevos campos, métodos y propiedades; y puede anular el constructor, la representación imprimible y el operador de igualdad. El siguiente ejemplo define una clase TensorGraph
básica que usa tres campos Tensor
para codificar un conjunto de bordes entre nodos. Luego define una subclase que agrega un campo Tensor
para registrar un "valor de característica" para cada nodo. La subclase también define un método para propagar los valores de las características a lo largo de los bordes.
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32) After propagating: tf.Tensor([10. 12. 4. 4. -1. 0.], shape=(6,), dtype=float32)
Definición de campos privados
Los campos de un tipo de extensión se pueden marcar como privados prefijándolos con un guión bajo (siguiendo las convenciones estándar de Python). Esto no afecta la forma en que TensorFlow trata los campos de ninguna manera; pero simplemente sirve como una señal para cualquier usuario del tipo de extensión de que esos campos son privados.
Personalización de TypeSpec
de ExtensionType
Cada clase ExtensionType
tiene una clase TypeSpec
correspondiente, que se crea automáticamente y se almacena como <extension_type_name>.Spec
. Para obtener más información, consulte la sección "Especificación de tipos anidados" más arriba.
Para personalizar el TypeSpec
, simplemente defina su propia clase anidada llamada Spec
, y ExtensionType
la usará como base para el TypeSpec
construido automáticamente. Puede personalizar la clase de Spec
de la siguiente manera:
- Anulando la representación imprimible predeterminada.
- Anulando el constructor predeterminado.
- Definición de métodos, métodos de clase, métodos estáticos y propiedades.
El siguiente ejemplo personaliza la clase MaskedTensor.Spec
para que sea más fácil de usar:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
Despacho de API de tensor
Los tipos de extensión pueden ser "similares a tensores", en el sentido de que especializan o amplían la interfaz definida por el tipo tf.Tensor
. Los ejemplos de tipos de extensión similares a tensores incluyen RaggedTensor
, SparseTensor
y MaskedTensor
. Los decoradores de envío se pueden usar para anular el comportamiento predeterminado de las operaciones de TensorFlow cuando se aplican a tipos de extensión tipo tensor. TensorFlow actualmente define tres decoradores de despacho:
-
@tf.experimental.dispatch_for_api(tf_api)
-
@tf.experimental.dispatch_for_unary_elementwise_api(x_type)
-
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
Despacho para una sola API
El decorador tf.experimental.dispatch_for_api
anula el comportamiento predeterminado de una operación de TensorFlow especificada cuando se llama con la firma especificada. Por ejemplo, puede usar este decorador para especificar cómo tf.stack
debe procesar los valores MaskedTensor
:
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
Esto anula la implementación predeterminada para tf.stack
cada vez que se llama con una lista de valores de MaskedTensor
(ya que el argumento de los values
se anota con typing.List[MaskedTensor]
):
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>
Para permitir que tf.stack
maneje listas de valores combinados de MaskedTensor
y Tensor
, puede refinar la anotación de tipo para el parámetro de values
y actualizar el cuerpo de la función de manera adecuada:
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>
Para obtener una lista de las API que se pueden anular, consulte la documentación de la API para tf.experimental.dispatch_for_api
.
Despacho para todas las API elementales unarias
El decorador tf.experimental.dispatch_for_unary_elementwise_apis
anula el comportamiento predeterminado de todas las operaciones de elemento unario (como tf.math.cos
) siempre que el valor del primer argumento (normalmente denominado x
) coincida con la anotación de tipo x_type
. La función decorada debe tomar dos argumentos:
-
api_func
: una función que toma un solo parámetro y realiza la operación por elementos (p. ej.,tf.abs
). -
x
: El primer argumento de la operación por elementos.
El siguiente ejemplo actualiza todas las operaciones de elementos unarios para manejar el tipo MaskedTensor
:
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
Esta función ahora se usará cada vez que se llame a una operación elemento-unario en un MaskedTensor
.
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>
Despacho para todas las API elementales binarias
De manera similar, tf.experimental.dispatch_for_binary_elementwise_apis
se puede usar para actualizar todas las operaciones binarias elementales para manejar el tipo MaskedTensor
:
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>
Para obtener una lista de las API de elementwise que se anulan, consulte la documentación de la API para tf.experimental.dispatch_for_unary_elementwise_apis
y tf.experimental.dispatch_for_binary_elementwise_apis
.
Tipos de extensión por lotes
Un ExtensionType
se puede procesar por lotes si se puede usar una única instancia para representar un lote de valores. Por lo general, esto se logra agregando dimensiones de lote a todos los Tensor
anidados. Las siguientes API de TensorFlow requieren que cualquier entrada de tipo de extensión sea procesable por lotes:
-
tf.data.Dataset
(batch
,unbatch
,from_tensor_slices
) -
tf.Keras
(fit
,evaluate
,predict
) -
tf.map_fn
De manera predeterminada, BatchableExtensionType
crea valores por lotes al agrupar por lotes cualquier Tensor
s, CompositeTensor
s y ExtensionType
s anidados. Si esto no es apropiado para su clase, deberá usar tf.experimental.ExtensionTypeBatchEncoder
para anular este comportamiento predeterminado. Por ejemplo, no sería apropiado crear un lote de valores de tf.SparseTensor
simplemente apilando values
, indices
y campos dense_shape
de tensores dispersos individuales; en la mayoría de los casos, no puede apilar estos tensores, ya que tienen formas incompatibles. ; e incluso si pudiera, el resultado no sería un SparseTensor
válido.
Ejemplo de BatchableExtensionType: Red
Como ejemplo, considere una clase de Network
simple utilizada para el equilibrio de carga, que rastrea cuánto trabajo queda por hacer en cada nodo y cuánto ancho de banda está disponible para mover el trabajo entre los nodos:
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
Para que este tipo se pueda procesar por lotes, cambie el tipo base a BatchableExtensionType
y ajuste la forma de cada campo para incluir dimensiones de lote opcionales. El siguiente ejemplo también agrega un campo de shape
para realizar un seguimiento de la forma del lote. Este campo de shape
no es requerido por tf.data.Dataset
o tf.map_fn
, pero es requerido por tf.Keras
.
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]> batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
Luego puede usar tf.data.Dataset
para iterar a través de un lote de redes:
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
Y también puede usar map_fn
para aplicar una función a cada elemento del lote:
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
API de TensorFlow que admiten ExtensionTypes
@tf.función
tf.function es un decorador que precalcula gráficos de TensorFlow para funciones de Python, lo que puede mejorar sustancialmente el rendimiento de su código de TensorFlow. Los valores de tipo de extensión se pueden usar de forma transparente con @tf.function
.
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Si desea especificar explícitamente input_signature
para tf.function
, puede hacerlo utilizando TypeSpec
del tipo de extensión.
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)
Funciones concretas
Las funciones concretas encapsulan gráficos trazados individuales creados por tf.function
. Los tipos de extensión se pueden usar de forma transparente con funciones concretas.
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Operaciones de flujo de control
Los tipos de extensión son compatibles con las operaciones de flujo de control de TensorFlow:
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>
Flujo de control de autógrafos
Los tipos de extensión también son compatibles con las declaraciones de flujo de control en tf.function (usando autograph). En el siguiente ejemplo, las instrucciones if
y for
se convierten automáticamente en operaciones tf.cond
y tf.while_loop
, que admiten tipos de extensión.
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]> <MaskedTensor [4.5, _, 6.5]>
Keras
tf.keras es la API de alto nivel de TensorFlow para crear y entrenar modelos de aprendizaje profundo. Los tipos de extensión se pueden pasar como entradas a un modelo de Keras, pasar entre capas de Keras y devolverse por los modelos de Keras. Keras actualmente pone dos requisitos en los tipos de extensión:
- Deben ser agrupables (consulte "Tipos de extensión agrupables" más arriba).
- El debe tener un campo o propiedad llamada
shape
. Se supone queshape[0]
es la dimensión del lote.
Las siguientes dos subsecciones brindan ejemplos que muestran cómo se pueden usar los tipos de extensión con Keras.
Ejemplo de Keras: Network
Para el primer ejemplo, considere la clase de Network
definida en la sección "Tipos de extensión por lotes" anterior, que se puede usar para el trabajo de equilibrio de carga entre nodos. Su definición se repite aquí:
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network w/ 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
Puede definir una nueva capa de Keras que procese Network
s.
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above, in "Batchable ExtensionTypes" section.
return balance_work_greedy(inputs)
Luego puede usar estas capas para crear un modelo simple. Para introducir un ExtensionType
en un modelo, puede usar una capa tf.keras.layer.Input
con type_spec
establecido en el TypeSpec
del tipo de extensión. Si el modelo de Keras se utilizará para procesar lotes, type_spec
debe incluir la dimensión del lote.
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
Finalmente, puede aplicar el modelo a una sola red ya un lote de redes.
model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>
Ejemplo de Keras: MaskedTensor
En este ejemplo, MaskedTensor
se amplía para admitir Keras
. la shape
se define como una propiedad que se calcula a partir del campo de values
. Keras requiere que agregue esta propiedad tanto al tipo de extensión como a su TypeSpec
. MaskedTensor
también define una variable __name__
, que será necesaria para la serialización del SavedModel
(a continuación).
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
Luego, los decoradores de despacho se usan para anular el comportamiento predeterminado de varias API de TensorFlow. Dado que estas API son utilizadas por las capas estándar de Keras (como la capa Dense
), anularlas nos permitirá usar esas capas con MaskedTensor
. A los efectos de este ejemplo, matmul
para tensores enmascarados se define para tratar los valores enmascarados como ceros (es decir, para no incluirlos en el producto).
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
Luego puede construir un modelo de Keras que acepte entradas de MaskedTensor
, utilizando capas estándar de Keras:
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3 1/1 [==============================] - 1s 955ms/step - loss: 10.2833 Epoch 2/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 Epoch 3/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 tf.Tensor( [[-0.09944128] [-0.7225147 ] [-1.3020657 ]], shape=(3, 1), dtype=float32)
Modelo guardado
Un modelo guardado es un programa TensorFlow serializado que incluye pesos y cálculos. Se puede construir a partir de un modelo Keras oa partir de un modelo personalizado. En cualquier caso, los tipos de extensión se pueden usar de forma transparente con las funciones y métodos definidos por un modelo guardado.
SavedModel puede guardar modelos, capas y funciones que procesan tipos de extensión, siempre que los tipos de extensión tengan un campo __name__
. Este nombre se usa para registrar el tipo de extensión, por lo que se puede ubicar cuando se carga el modelo.
Ejemplo: guardar un modelo de Keras
Los modelos de Keras que usan tipos de extensión se pueden guardar usando SavedModel
.
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel. INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets <tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[-0.09944128], [-0.7225147 ], [-1.3020657 ]], dtype=float32)>
Ejemplo: guardar un modelo personalizado
El modelo guardado también se puede usar para guardar subclases personalizadas de tf.Module
con funciones que procesan tipos de extensión.
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets <MaskedTensor [_, 200.0, _]>
Cargar un modelo guardado cuando ExtensionType no está disponible
Si carga un SavedModel
que usa un tipo de ExtensionType
, pero ese tipo ExtensionType
no está disponible (es decir, no se ha importado), verá una advertencia y TensorFlow volverá a usar un objeto de "tipo de extensión anónimo". Este objeto tendrá los mismos campos que el tipo original, pero no tendrá ninguna personalización adicional que haya agregado para el tipo, como métodos o propiedades personalizados.
Uso de ExtensionTypes con servicio de TensorFlow
Actualmente, el servicio de TensorFlow (y otros consumidores del diccionario de "firmas" de modelo guardado) requiere que todas las entradas y salidas sean tensores sin formato. Si desea utilizar TensorFlow con un modelo que utiliza tipos de extensión, puede agregar métodos de envoltura que componen o descomponen valores de tipo de extensión a partir de tensores. P.ej:
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets <tf.Tensor: shape=(), dtype=float32, numpy=12.0>
conjuntos de datos
tf.data es una API que le permite crear canalizaciones de entrada complejas a partir de piezas simples y reutilizables. Su estructura de datos central es tf.data.Dataset
, que representa una secuencia de elementos, en la que cada elemento consta de uno o más componentes.
Creación de conjuntos de datos con tipos de extensión
Los conjuntos de datos se pueden crear a partir de valores de tipo de extensión mediante Dataset.from_tensors
, Dataset.from_tensor_slices
o Dataset.from_generator
:
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
<MaskedTensor [0, 1, 2, 3]> <MaskedTensor [4, 5, 6, 7]> <MaskedTensor [8, 9, 10, 11]> <MaskedTensor [12, 13, 14, 15]> <MaskedTensor [16, 17, 18, 19]>
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>
Conjuntos de datos por lotes y sin lotes con tipos de extensión
Los conjuntos de datos con tipos de extensión se pueden procesar por lotes y sin lotes mediante Dataset.batch
y Dataset.unbatch
.
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]> <MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]> <MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>