tf.experimental.BatchableExtensionType
An ExtensionType that can be batched and unbatched.
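Inherits From: ExtensionType
Compat alias for migration: tf.compat.v1.experimental.BatchableExtensionType (see the TensorFlow migration guide).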
tf.experimental.BatchableExtensionType(
    *args, **kwargs
)
Used in the notebooks: the Extension types guide.
BatchableExtensionTypes can be used with APIs that require batching or unbatching, including Keras, tf.data.Dataset, and tf.map_fn. E.g.:
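import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

# A single value holding a batch of three vehicles.
batch = Vehicle([120, 150, 80], [30, 40, 12])
# map_fn unbatches `batch`, applies the lambda to each vehicle, and stacks the results.
tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
          fn_output_signature=tf.int32).numpy()
# array([3600, 6000, 960], dtype=int32)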
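The same kind of batched value can also be fed to tf.data.Dataset. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the variable names fleet, ds, singles, and batched are illustrative only. It unbatches a Vehicle into per-vehicle elements and then re-batches them.

```python
import tensorflow as tf


class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor


# One batched value: each field has shape [3].
fleet = Vehicle([120, 150, 80], [30, 40, 12])

# from_tensor_slices unbatches the value along its first dimension,
# yielding three Vehicle values whose fields are scalars.
ds = tf.data.Dataset.from_tensor_slices(fleet)
singles = [(int(v.top_speed.numpy()), int(v.mpg.numpy())) for v in ds]
print(singles)  # [(120, 30), (150, 40), (80, 12)]

# batch(2) re-batches consecutive elements; the last batch holds the single
# leftover element because drop_remainder defaults to False.
batched = [b.top_speed.numpy().tolist() for b in ds.batch(2)]
print(batched)  # [[120, 150], [80]]
```

Because the elements are Vehicle values rather than raw tensors, field access such as v.top_speed keeps working after each batching or unbatching step.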
These APIs use an ExtensionTypeBatchEncoder to encode ExtensionType values. The default encoder assumes that values can be stacked, unstacked, or concatenated simply by stacking, unstacking, or concatenating every nested Tensor, ExtensionType, CompositeTensor, or TensorShape field. Extension types for which this does not hold must override __batch_encoder__ with a custom ExtensionTypeBatchEncoder. See tf.experimental.ExtensionTypeBatchEncoder for more details.
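To illustrate what "every nested field" means for the default encoder, here is a minimal sketch assuming TensorFlow 2.x with eager execution. Engine and Car are hypothetical types, and both are declared batchable so that the nested engine field can be unbatched along with the mpg tensor.

```python
import tensorflow as tf


class Engine(tf.experimental.BatchableExtensionType):
  horsepower: tf.Tensor


class Car(tf.experimental.BatchableExtensionType):
  engine: Engine  # nested ExtensionType field, handled by the default encoder
  mpg: tf.Tensor


# A batch of two cars: engine.horsepower and mpg both have shape [2].
cars = Car(Engine([300, 150]), [20, 35])

# Slicing the batch unstacks every nested field together, so each element
# is a Car whose engine.horsepower and mpg are scalars.
ds = tf.data.Dataset.from_tensor_slices(cars)
per_car = [(int(c.engine.horsepower.numpy()), int(c.mpg.numpy())) for c in ds]
print(per_car)  # [(300, 20), (150, 35)]
```

No custom encoder is needed here because each field can be unstacked independently along the batch dimension.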
Methods
__eq__
__eq__(
    other
)
Return self==other.
__ne__
__ne__(
    other
)
Return self!=other.
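A short sketch of the comparison operators, assuming TensorFlow 2.x with eager execution and a hypothetical Vehicle type like the one in the example above. Wrapping each comparison in bool() keeps the check independent of whether the operator returns a scalar boolean tensor or a plain Python bool.

```python
import tensorflow as tf


class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor


a = Vehicle([120, 150], [30, 40])
b = Vehicle([120, 150], [30, 40])  # same field values as a
c = Vehicle([120, 150], [30, 41])  # differs from a in one mpg entry

# bool() converts the comparison result to a Python bool.
same = bool(a == b)
different = bool(a != c)
print(same, different)
```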
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.