tf.experimental.BatchableExtensionType

An ExtensionType that can be batched and unbatched.

Inherits From: ExtensionType

BatchableExtensionTypes can be used with APIs that require batching or unbatching, including Keras, tf.data.Dataset, and tf.map_fn. For example:

>>> import tensorflow as tf
>>> class Vehicle(tf.experimental.BatchableExtensionType):
...   top_speed: tf.Tensor
...   mpg: tf.Tensor
>>> batch = Vehicle([120, 150, 80], [30, 40, 12])
>>> tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
...           fn_output_signature=tf.int32).numpy()
array([3600, 6000,  960], dtype=int32)
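The same batched value can also go through tf.data, one of the other batching APIs listed above. The sketch below assumes TensorFlow 2.x in eager mode and reuses the `Vehicle` type from the example (redefined so the block runs on its own). It relies only on `tf.data.Dataset.from_tensor_slices`, `Dataset.batch`, and `Dataset.unbatch`.

```python
import tensorflow as tf

# Same Vehicle type as in the example above, redefined so this block is self-contained.
class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

batch = Vehicle([120, 150, 80], [30, 40, 12])

# from_tensor_slices unbatches `batch` along its first dimension, so each
# dataset element is a Vehicle whose fields are scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(batch)
speeds = [int(v.top_speed) for v in dataset]

# .batch(2) stacks consecutive elements back into Vehicle values whose fields
# are vectors; the final batch holds the single leftover element.
rebatched = list(dataset.batch(2))
for v in rebatched:
  print(v.top_speed.numpy(), v.mpg.numpy())

# .unbatch() reverses the batching and yields scalar-field Vehicles again.
roundtrip = [int(v.mpg) for v in dataset.batch(2).unbatch()]
```

Each of these steps goes through the type's batch encoder, which is why a plain ExtensionType (without the Batchable variant) would not be accepted here.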

These APIs use an ExtensionTypeBatchEncoder to encode ExtensionType values. The default encoder assumes that values can be stacked, unstacked, or concatenated by stacking, unstacking, or concatenating each nested Tensor, ExtensionType, CompositeTensor, or TensorShape field. Extension types for which this does not hold must override __batch_encoder__ with a custom ExtensionTypeBatchEncoder. See tf.experimental.ExtensionTypeBatchEncoder for more details.
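To make the default encoder's field-by-field assumption concrete, here is a framework-free model of it. This is not TensorFlow's ExtensionTypeBatchEncoder API: `VehicleRecord`, `stack`, `unstack`, and `concat` are hypothetical names, and plain Python lists stand in for tensors.

```python
from dataclasses import dataclass, fields
from typing import Any, List


@dataclass
class VehicleRecord:
    # Hypothetical stand-in for an extension type; each field plays the role
    # of a tf.Tensor field, with Python lists standing in for tensors.
    top_speed: Any
    mpg: Any


def stack(values: List[Any]) -> Any:
    """Batch: gather each field across all values (adds a leading dimension)."""
    cls = type(values[0])
    return cls(**{f.name: [getattr(v, f.name) for v in values] for f in fields(cls)})


def unstack(batch: Any) -> List[Any]:
    """Unbatch: split every field along its leading dimension."""
    cls = type(batch)
    columns = {f.name: getattr(batch, f.name) for f in fields(cls)}
    size = len(next(iter(columns.values())))
    return [cls(**{name: col[i] for name, col in columns.items()}) for i in range(size)]


def concat(batches: List[Any]) -> Any:
    """Concatenate already-batched values field by field along the leading dimension."""
    cls = type(batches[0])
    return cls(**{f.name: [x for b in batches for x in getattr(b, f.name)] for f in fields(cls)})


vehicles = [VehicleRecord(120, 30), VehicleRecord(150, 40), VehicleRecord(80, 12)]
batch = stack(vehicles)
print(batch)  # VehicleRecord(top_speed=[120, 150, 80], mpg=[30, 40, 12])
print(unstack(batch) == vehicles)  # True
```

The scheme works only because every field can be handled independently. A type whose fields depend on each other (for example, one field holding offsets into another) would end up with inconsistent fields after field-wise concatenation, which is the kind of case where a custom encoder is needed.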

Methods

__eq__

Return self==value.

__ne__

Return self!=value.
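The entries above only restate the operators. As a hedged sketch, based on the TensorFlow extension-types guide's description that two values are equal when they have the same type and every field is equal (tensor fields must match in shape and elementwise), the comparison below uses the `Vehicle` type from the earlier example. It assumes TensorFlow 2.x in eager mode and wraps results in `bool()` because the comparison may return a scalar bool tensor rather than a Python bool.

```python
import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

a = Vehicle([120, 150], [30, 40])
b = Vehicle([120, 150], [30, 41])  # differs from `a` in one mpg element

# bool() turns either a Python bool or an eager scalar bool tensor into a plain flag.
same = bool(a == Vehicle([120, 150], [30, 40]))
different = bool(a == b)
not_equal = bool(a != b)
print(same, different, not_equal)
```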