Tipos de extensión

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHubDescargar libreta

Configuración

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

Tipos de extensiones

Los tipos definidos por el usuario pueden hacer que los proyectos sean más legibles, modulares y mantenibles. Sin embargo, la mayoría de las API de TensorFlow tienen una compatibilidad muy limitada con los tipos de Python definidos por el usuario. Esto incluye API de alto nivel (como Keras , tf.function , tf.SavedModel ) y API de nivel inferior (como tf.while_loop y tf.concat ). Los tipos de extensión de TensorFlow se pueden usar para crear tipos orientados a objetos definidos por el usuario que funcionan sin problemas con las API de TensorFlow. Para crear un tipo de extensión, simplemente defina una clase de Python con tf.experimental.ExtensionType como base y use anotaciones de tipo para especificar el tipo de cada campo.

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

La clase base tf.experimental.ExtensionType funciona de manera similar a typing.NamedTuple y @dataclasses.dataclass de la biblioteca estándar de Python. En particular, agrega automáticamente un constructor y métodos especiales (como __repr__ y __eq__ ) en función de las anotaciones de tipo de campo.

Por lo general, los tipos de extensión tienden a caer en una de dos categorías:

  • Estructuras de datos , que agrupan una colección de valores relacionados y pueden proporcionar operaciones útiles basadas en esos valores. Las estructuras de datos pueden ser bastante generales (como el ejemplo anterior de TensorGraph ); o pueden personalizarse mucho para un modelo específico.

  • Tipos similares a tensores , que especializan o amplían el concepto de "tensor". Los tipos de esta categoría tienen un rank , una shape y, por lo general, un dtype ; y tiene sentido usarlos con operaciones Tensor (como tf.stack , tf.add o tf.matmul ). MaskedTensor y CSRSparseMatrix son ejemplos de tipos similares a tensores.

API compatibles

Los tipos de extensión son compatibles con las siguientes API de TensorFlow:

  • Keras : los tipos de extensión se pueden usar como entradas y salidas para los Models y Layers de Keras.
  • tf.data.Dataset : los tipos de extensión se pueden incluir en Datasets de datos y devolverlos los Iterators de conjuntos de datos.
  • Tensorflow hub : los tipos de extensión se pueden usar como entradas y salidas para los módulos tf.hub .
  • Modelo guardado : los tipos de extensión se pueden usar como entradas y salidas para funciones de SavedModel .
  • tf.function : los tipos de extensión se pueden usar como argumentos y devolver valores para funciones envueltas con el decorador @tf.function .
  • while loops : los tipos de extensión se pueden usar como variables de bucle en tf.while_loop , y se pueden usar como argumentos y valores de retorno para el cuerpo del while-loop.
  • condicionales : los tipos de extensión se pueden seleccionar condicionalmente usando tf.cond y tf.case .
  • py_function : los tipos de extensión se pueden usar como argumentos y devolver valores para el argumento func a tf.py_function .
  • Tensor ops : los tipos de extensión se pueden ampliar para admitir la mayoría de las operaciones de TensorFlow que aceptan entradas de Tensor (p. ej., tf.matmul , tf.gather y tf.reduce_sum ). Consulte la sección " Despacho " a continuación para obtener más información.
  • estrategia de distribución : los tipos de extensión se pueden usar como valores por réplica.

Para obtener más detalles, consulte la sección sobre "API de TensorFlow que admiten ExtensionTypes" a continuación.

Requisitos

tipos de campo

Se deben declarar todos los campos (también conocidos como variables de instancia) y se debe proporcionar una anotación de tipo para cada campo. Se admiten las siguientes anotaciones de tipo:

Escribe Ejemplo
enteros de Python i: int
pitón flota f: float
cadenas de pitón s: str
Booleanos de Python b: bool
Python Ninguno n: None
Formas de tensor shape: tf.TensorShape
Tipos de tensor dtype: tf.DType
tensores t: tf.Tensor
Tipos de extensiones mt: MyMaskedTensor
Tensores irregulares rt: tf.RaggedTensor
Tensores dispersos st: tf.SparseTensor
Sectores indexados s: tf.IndexedSlices
Tensores opcionales o: tf.experimental.Optional
Tipo de uniones int_or_float: typing.Union[int, float]
tuplas params: typing.Tuple[int, float, tf.Tensor, int]
tuplas de longitud Var lengths: typing.Tuple[int, ...]
Asignaciones tags: typing.Mapping[str, tf.Tensor]
Valores opcionales weight: typing.Optional[tf.Tensor]

Mutabilidad

Los tipos de extensión deben ser inmutables. Esto garantiza que los mecanismos de seguimiento de gráficos de TensorFlow puedan rastrearlos correctamente. Si desea mutar un valor de tipo de extensión, considere definir métodos que transformen valores. Por ejemplo, en lugar de definir un método set_mask para mutar un MaskedTensor , podría definir un método replace_mask que devuelva un nuevo MaskedTensor :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

Funcionalidad añadida por ExtensionType

La clase base ExtensionType proporciona la siguiente funcionalidad:

  • Un constructor ( __init__ ).
  • Un método de representación imprimible ( __repr__ ).
  • Operadores de igualdad y desigualdad ( __eq__ ).
  • Un método de validación ( __validate__ ).
  • Inmutabilidad forzosa.
  • Un TypeSpec anidado.
  • Soporte de envío de API de tensor.

Consulte la sección "Personalizar tipos de extensión" a continuación para obtener más información sobre cómo personalizar esta funcionalidad.

Constructor

El constructor agregado por ExtensionType toma cada campo como un argumento con nombre (en el orden en que se enumeraron en la definición de clase). Este constructor verificará el tipo de cada parámetro y los convertirá cuando sea necesario. En particular, los campos Tensor se convierten usando tf.convert_to_tensor ; Los campos de Tuple se convierten en tuple s; y los campos de Mapping se convierten en dictados inmutables.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

El constructor genera un TypeError si un valor de campo no se puede convertir a su tipo declarado:

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None

El valor predeterminado para un campo se puede especificar estableciendo su valor en el nivel de clase:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

Representación imprimible

ExtensionType agrega un método de representación imprimible predeterminado ( __repr__ ) que incluye el nombre de la clase y el valor de cada campo:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

Operadores de igualdad

ExtensionType agrega operadores de igualdad predeterminados ( __eq__ y __ne__ ) que consideran dos valores iguales si tienen el mismo tipo y todos sus campos son iguales. Los campos tensoriales se consideran iguales si tienen la misma forma y son iguales para todos los elementos.

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

Método de validación

ExtensionType agrega un método __validate__ , que se puede anular para realizar comprobaciones de validación en los campos. Se ejecuta después de llamar al constructor y después de que los campos se hayan verificado y convertido a sus tipos declarados, por lo que puede asumir que todos los campos tienen sus tipos declarados.

El siguiente ejemplo actualiza MaskedTensor para validar las shape y dtype de sus campos:

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # wrong dtype for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

inmutabilidad forzada

ExtensionType anula los métodos __setattr__ y __delattr__ para evitar la mutación, lo que garantiza que los valores de tipo de extensión sean inmutables.

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

Especificación de tipo anidado

Cada clase ExtensionType tiene una clase TypeSpec correspondiente, que se crea automáticamente y se almacena como <extension_type_name>.Spec .

Esta clase captura toda la información de un valor excepto los valores de cualquier tensor anidado. En particular, el TypeSpec para un valor se crea reemplazando cualquier Tensor, ExtensionType o CompositeTensor anidado con su TypeSpec .

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records dtype and shape, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

Los valores TypeSpec se pueden construir explícitamente o se pueden construir a partir de un valor ExtensionType usando tf.type_spec_from_value :

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TensorFlow usa TypeSpec para dividir los valores en un componente estático y un componente dinámico :

  • El componente estático (que se fija en el momento de la construcción del gráfico) se codifica con tf.TypeSpec .
  • El componente dinámico (que puede variar cada vez que se ejecuta el gráfico) se codifica como una lista de tf.Tensor s.

Por ejemplo, tf.function vuelve sobre su función envuelta cada vez que un argumento tiene un TypeSpec nunca antes visto:

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

Para obtener más información, consulte la Guía de funciones tf .

Personalización de tipos de extensión

Además de simplemente declarar campos y sus tipos, los tipos de extensión pueden:

  • Anule la representación imprimible predeterminada ( __repr__ ).
  • Definir métodos.
  • Defina métodos de clase y métodos estáticos.
  • Definir propiedades.
  • Anule el constructor predeterminado ( __init__ ).
  • Anule el operador de igualdad predeterminado ( __eq__ ).
  • Defina operadores (como __add__ y __lt__ ).
  • Declarar valores predeterminados para los campos.
  • Definir subclases.

Anular la representación imprimible predeterminada

Puede anular este operador de conversión de cadena predeterminado para los tipos de extensión. El siguiente ejemplo actualiza la clase MaskedTensor para generar una representación de cadena más legible cuando los valores se imprimen en modo Eager.

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>

Definición de métodos

Los tipos de extensión pueden definir métodos, como cualquier clase normal de Python. Por ejemplo, el tipo MaskedTensor podría definir un método with_default que devuelve una copia de self con valores enmascarados reemplazados por un valor default determinado. Los métodos se pueden anotar opcionalmente con el decorador @tf.function .

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>

Definición de métodos de clase y métodos estáticos

Los tipos de extensión pueden definir métodos utilizando los decoradores @classmethod y @staticmethod . Por ejemplo, el tipo MaskedTensor podría definir un método de fábrica que enmascare cualquier elemento con un valor dado:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values == value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>

Definición de propiedades

Los tipos de extensión pueden definir propiedades utilizando el decorador @property , como cualquier clase normal de Python. Por ejemplo, el tipo MaskedTensor podría definir una propiedad dtype que es una abreviatura del dtype de los valores:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

Anulando el constructor predeterminado

Puede anular el constructor predeterminado para los tipos de extensión. Los constructores personalizados deben establecer un valor para cada campo declarado; y después de que regrese el constructor personalizado, se verificará el tipo de todos los campos y los valores se convertirán como se describe anteriormente.

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Alternativamente, puede considerar dejar el constructor predeterminado como está, pero agregar uno o más métodos de fábrica. P.ej:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Anulando el operador de igualdad predeterminado ( __eq__ )

Puede anular el operador __eq__ predeterminado para los tipos de extensión. El siguiente ejemplo actualiza MaskedTensor para ignorar los elementos enmascarados al comparar la igualdad.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)

Uso de referencias directas

Si el tipo de un campo aún no se ha definido, puede usar una cadena que contenga el nombre del tipo en su lugar. En el siguiente ejemplo, la cadena "Node" se usa para anotar el campo children porque el tipo de Node aún no se ha definido (totalmente).

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))

Definición de subclases

Los tipos de extensión se pueden subclasificar utilizando la sintaxis estándar de Python. Las subclases de tipo de extensión pueden agregar nuevos campos, métodos y propiedades; y puede anular el constructor, la representación imprimible y el operador de igualdad. El siguiente ejemplo define una clase TensorGraph básica que usa tres campos Tensor para codificar un conjunto de bordes entre nodos. Luego define una subclase que agrega un campo Tensor para registrar un "valor de característica" para cada nodo. La subclase también define un método para propagar los valores de las características a lo largo de los bordes.

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)

Definición de campos privados

Los campos de un tipo de extensión se pueden marcar como privados prefijándolos con un guión bajo (siguiendo las convenciones estándar de Python). Esto no afecta la forma en que TensorFlow trata los campos de ninguna manera; pero simplemente sirve como una señal para cualquier usuario del tipo de extensión de que esos campos son privados.

Personalización de TypeSpec de ExtensionType

Cada clase ExtensionType tiene una clase TypeSpec correspondiente, que se crea automáticamente y se almacena como <extension_type_name>.Spec . Para obtener más información, consulte la sección "Especificación de tipos anidados" más arriba.

Para personalizar el TypeSpec , simplemente defina su propia clase anidada llamada Spec , y ExtensionType la usará como base para el TypeSpec construido automáticamente. Puede personalizar la clase de Spec de la siguiente manera:

  • Anulando la representación imprimible predeterminada.
  • Anulando el constructor predeterminado.
  • Definición de métodos, métodos de clase, métodos estáticos y propiedades.

El siguiente ejemplo personaliza la clase MaskedTensor.Spec para que sea más fácil de usar:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

Despacho de API de tensor

Los tipos de extensión pueden ser "similares a tensores", en el sentido de que especializan o amplían la interfaz definida por el tipo tf.Tensor . Los ejemplos de tipos de extensión similares a tensores incluyen RaggedTensor , SparseTensor y MaskedTensor . Los decoradores de envío se pueden usar para anular el comportamiento predeterminado de las operaciones de TensorFlow cuando se aplican a tipos de extensión tipo tensor. TensorFlow actualmente define tres decoradores de despacho:

Despacho para una sola API

El decorador tf.experimental.dispatch_for_api anula el comportamiento predeterminado de una operación de TensorFlow especificada cuando se llama con la firma especificada. Por ejemplo, puede usar este decorador para especificar cómo tf.stack debe procesar los valores MaskedTensor :

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

Esto anula la implementación predeterminada para tf.stack cada vez que se llama con una lista de valores de MaskedTensor (ya que el argumento de los values se anota con typing.List[MaskedTensor] ):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

Para permitir que tf.stack maneje listas de valores combinados de MaskedTensor y Tensor , puede refinar la anotación de tipo para el parámetro de values y actualizar el cuerpo de la función de manera adecuada:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

Para obtener una lista de las API que se pueden anular, consulte la documentación de la API para tf.experimental.dispatch_for_api .

Despacho para todas las API elementales unarias

El decorador tf.experimental.dispatch_for_unary_elementwise_apis anula el comportamiento predeterminado de todas las operaciones de elemento unario (como tf.math.cos ) siempre que el valor del primer argumento (normalmente denominado x ) coincida con la anotación de tipo x_type . La función decorada debe tomar dos argumentos:

  • api_func : una función que toma un solo parámetro y realiza la operación por elementos (p. ej., tf.abs ).
  • x : El primer argumento de la operación por elementos.

El siguiente ejemplo actualiza todas las operaciones de elementos unarios para manejar el tipo MaskedTensor :

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

Esta función ahora se usará cada vez que se llame a una operación elemento-unario en un MaskedTensor .

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

Despacho para todas las API elementales binarias

De manera similar, tf.experimental.dispatch_for_binary_elementwise_apis se puede usar para actualizar todas las operaciones binarias elementales para manejar el tipo MaskedTensor :

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

Para obtener una lista de las API de elementwise que se anulan, consulte la documentación de la API para tf.experimental.dispatch_for_unary_elementwise_apis y tf.experimental.dispatch_for_binary_elementwise_apis .

Tipos de extensión por lotes

Un ExtensionType se puede procesar por lotes si se puede usar una única instancia para representar un lote de valores. Por lo general, esto se logra agregando dimensiones de lote a todos los Tensor anidados. Las siguientes API de TensorFlow requieren que cualquier entrada de tipo de extensión sea procesable por lotes:

De manera predeterminada, BatchableExtensionType crea valores por lotes al agrupar por lotes cualquier Tensor s, CompositeTensor s y ExtensionType s anidados. Si esto no es apropiado para su clase, deberá usar tf.experimental.ExtensionTypeBatchEncoder para anular este comportamiento predeterminado. Por ejemplo, no sería apropiado crear un lote de valores de tf.SparseTensor simplemente apilando values , indices y campos dense_shape de tensores dispersos individuales; en la mayoría de los casos, no puede apilar estos tensores, ya que tienen formas incompatibles. ; e incluso si pudiera, el resultado no sería un SparseTensor válido.

Ejemplo de BatchableExtensionType: Red

Como ejemplo, considere una clase de Network simple utilizada para el equilibrio de carga, que rastrea cuánto trabajo queda por hacer en cada nodo y cuánto ancho de banda está disponible para mover el trabajo entre los nodos:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

Para que este tipo se pueda procesar por lotes, cambie el tipo base a BatchableExtensionType y ajuste la forma de cada campo para incluir dimensiones de lote opcionales. El siguiente ejemplo también agrega un campo de shape para realizar un seguimiento de la forma del lote. Este campo de shape no es requerido por tf.data.Dataset o tf.map_fn , pero es requerido por tf.Keras .

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

Luego puede usar tf.data.Dataset para iterar a través de un lote de redes:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

Y también puede usar map_fn para aplicar una función a cada elemento del lote:

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

API de TensorFlow que admiten ExtensionTypes

@tf.función

tf.function es un decorador que precalcula gráficos de TensorFlow para funciones de Python, lo que puede mejorar sustancialmente el rendimiento de su código de TensorFlow. Los valores de tipo de extensión se pueden usar de forma transparente con @tf.function .

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Si desea especificar explícitamente input_signature para tf.function , puede hacerlo utilizando TypeSpec del tipo de extensión.

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

Funciones concretas

Las funciones concretas encapsulan gráficos trazados individuales creados por tf.function . Los tipos de extensión se pueden usar de forma transparente con funciones concretas.

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Operaciones de flujo de control

Los tipos de extensión son compatibles con las operaciones de flujo de control de TensorFlow:

# Example: using tf.cond to select between two MaskedTensors.  Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

Flujo de control de autógrafos

Los tipos de extensión también son compatibles con las declaraciones de flujo de control en tf.function (usando autograph). En el siguiente ejemplo, las instrucciones if y for se convierten automáticamente en operaciones tf.cond y tf.while_loop , que admiten tipos de extensión.

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras es la API de alto nivel de TensorFlow para crear y entrenar modelos de aprendizaje profundo. Los tipos de extensión se pueden pasar como entradas a un modelo de Keras, pasar entre capas de Keras y devolverse por los modelos de Keras. Keras actualmente pone dos requisitos en los tipos de extensión:

  • Deben ser agrupables (consulte "Tipos de extensión agrupables" más arriba).
  • El debe tener un campo o propiedad llamada shape . Se supone que shape[0] es la dimensión del lote.

Las siguientes dos subsecciones brindan ejemplos que muestran cómo se pueden usar los tipos de extensión con Keras.

Ejemplo de Keras: Network

Para el primer ejemplo, considere la clase de Network definida en la sección "Tipos de extensión por lotes" anterior, que se puede usar para el trabajo de equilibrio de carga entre nodos. Su definición se repite aquí:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network w/ 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

Puede definir una nueva capa de Keras que procese Network s.

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above, in "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)

Luego puede usar estas capas para crear un modelo simple. Para introducir un ExtensionType en un modelo, puede usar una capa tf.keras.layer.Input con type_spec establecido en el TypeSpec del tipo de extensión. Si el modelo de Keras se utilizará para procesar lotes, type_spec debe incluir la dimensión del lote.

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

Finalmente, puede aplicar el modelo a una sola red ya un lote de redes.

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

Ejemplo de Keras: MaskedTensor

En este ejemplo, MaskedTensor se amplía para admitir Keras . la shape se define como una propiedad que se calcula a partir del campo de values . Keras requiere que agregue esta propiedad tanto al tipo de extensión como a su TypeSpec . MaskedTensor también define una variable __name__ , que será necesaria para la serialización del SavedModel (a continuación).

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

Luego, los decoradores de despacho se usan para anular el comportamiento predeterminado de varias API de TensorFlow. Dado que estas API son utilizadas por las capas estándar de Keras (como la capa Dense ), anularlas nos permitirá usar esas capas con MaskedTensor . A los efectos de este ejemplo, matmul para tensores enmascarados se define para tratar los valores enmascarados como ceros (es decir, para no incluirlos en el producto).

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

Luego puede construir un modelo de Keras que acepte entradas de MaskedTensor , utilizando capas estándar de Keras:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 955ms/step - loss: 10.2833
Epoch 2/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
tf.Tensor(
[[-0.09944128]
 [-0.7225147 ]
 [-1.3020657 ]], shape=(3, 1), dtype=float32)

Modelo guardado

Un modelo guardado es un programa TensorFlow serializado que incluye pesos y cálculos. Se puede construir a partir de un modelo Keras oa partir de un modelo personalizado. En cualquier caso, los tipos de extensión se pueden usar de forma transparente con las funciones y métodos definidos por un modelo guardado.

SavedModel puede guardar modelos, capas y funciones que procesan tipos de extensión, siempre que los tipos de extensión tengan un campo __name__ . Este nombre se usa para registrar el tipo de extensión, por lo que se puede ubicar cuando se carga el modelo.

Ejemplo: guardar un modelo de Keras

Los modelos de Keras que usan tipos de extensión se pueden guardar usando SavedModel .

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.09944128],
       [-0.7225147 ],
       [-1.3020657 ]], dtype=float32)>

Ejemplo: guardar un modelo personalizado

El modelo guardado también se puede usar para guardar subclases personalizadas de tf.Module con funciones que procesan tipos de extensión.

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
<MaskedTensor [_, 200.0, _]>

Cargar un modelo guardado cuando ExtensionType no está disponible

Si carga un SavedModel que usa un tipo de ExtensionType , pero ese tipo ExtensionType no está disponible (es decir, no se ha importado), verá una advertencia y TensorFlow volverá a usar un objeto de "tipo de extensión anónimo". Este objeto tendrá los mismos campos que el tipo original, pero no tendrá ninguna personalización adicional que haya agregado para el tipo, como métodos o propiedades personalizados.

Uso de ExtensionTypes con servicio de TensorFlow

Actualmente, el servicio de TensorFlow (y otros consumidores del diccionario de "firmas" de modelo guardado) requiere que todas las entradas y salidas sean tensores sin formato. Si desea utilizar TensorFlow con un modelo que utiliza tipos de extensión, puede agregar métodos de envoltura que componen o descomponen valores de tipo de extensión a partir de tensores. P.ej:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

conjuntos de datos

tf.data es una API que le permite crear canalizaciones de entrada complejas a partir de piezas simples y reutilizables. Su estructura de datos central es tf.data.Dataset , que representa una secuencia de elementos, en la que cada elemento consta de uno o más componentes.

Creación de conjuntos de datos con tipos de extensión

Los conjuntos de datos se pueden crear a partir de valores de tipo de extensión mediante Dataset.from_tensors , Dataset.from_tensor_slices o Dataset.from_generator :

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

Conjuntos de datos por lotes y sin lotes con tipos de extensión

Los conjuntos de datos con tipos de extensión se pueden procesar por lotes y sin lotes mediante Dataset.batch y Dataset.unbatch .

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>