tf.experimental.ExtensionTypeBatchEncoder

Class used to encode and decode extension type values for batching.

In order to be batched and unbatched by APIs such as tf.data.Dataset, tf.keras, and tf.map_fn, extension type values must be encoded as a list of tf.Tensors, where stacking, unstacking, or concatenating these encoded tensors and then decoding the result must be equivalent to stacking, unstacking, or concatenating the original values. ExtensionTypeBatchEncoders are responsible for implementing this encoding.

The default ExtensionTypeBatchEncoder that is used by BatchableExtensionType assumes that extension type values can be stacked, unstacked, or concatenated by simply stacking, unstacking, or concatenating every nested Tensor, ExtensionType, CompositeTensor, and TensorShape field.

Extension types whose values cannot be batched field by field in this way must set __batch_encoder__ to a custom encoder that overrides the batch, unbatch, encode, and decode methods. For example:

import tensorflow as tf

class CustomBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
  pass  # Override batch(), unbatch(), encode(), and decode().

class CustomType(tf.experimental.BatchableExtensionType):
  x: tf.Tensor
  y: tf.Tensor
  shape: tf.TensorShape
  __batch_encoder__ = CustomBatchEncoder()

For example, tf.RaggedTensor and tf.SparseTensor both use custom batch encodings which define ops to "box" and "unbox" individual values into tf.variant tensors.
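
When every field can be stacked independently, no custom encoder is needed. The following is a minimal sketch of the default behavior, assuming a hypothetical Pair type (not part of TensorFlow) and using tf.data.Dataset to unbatch and re-batch its values:

```python
import tensorflow as tf

# Hypothetical extension type used only for illustration.
class Pair(tf.experimental.BatchableExtensionType):
    x: tf.Tensor
    y: tf.Tensor

# A single Pair value holding a batch of three elements along axis 0.
pairs = Pair(
    x=tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    y=tf.constant([10, 20, 30]),
)

# from_tensor_slices unbatches the value; batch(2) stacks elements again.
dataset = tf.data.Dataset.from_tensor_slices(pairs)
first_batch = next(iter(dataset.batch(2)))

print(first_batch.x.shape, first_batch.y.shape)
```

Here from_tensor_slices splits pairs along the outer axis of each field, and batch(2) stacks two elements back together, so first_batch is itself a Pair whose fields carry a leading batch axis.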

Methods

batch

Returns the TypeSpec representing a batch of values described by spec.

The default definition returns a TypeSpec that is equal to spec, except that an outer axis with size batch_size is added to every nested TypeSpec and TensorShape field. Subclasses may override this default definition, when necessary.

Args
spec The TypeSpec for an individual value.
batch_size An int indicating the number of values that are batched together, or None if the batch size is not known.

Returns
A TypeSpec for a batch of values.
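
A short sketch of the default batch behavior, calling the encoder directly on the TypeSpec of a hypothetical Pair type (the type and variable names are illustrative and not part of TensorFlow):

```python
import tensorflow as tf

# Hypothetical extension type used only for illustration.
class Pair(tf.experimental.BatchableExtensionType):
    x: tf.Tensor
    y: tf.Tensor

encoder = tf.experimental.ExtensionTypeBatchEncoder()
value = Pair(x=tf.constant([1.0, 2.0]), y=tf.constant(10))
spec = tf.type_spec_from_value(value)

# Known batch size: an outer axis of size 4 is added to each nested spec.
batched_spec = encoder.batch(spec, 4)
# Unknown batch size: the outer axis has size None.
unknown_spec = encoder.batch(spec, None)

print(batched_spec.x.shape, batched_spec.y.shape)
print(unknown_spec.x.shape, unknown_spec.y.shape)
```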

decode

Decodes value from a batchable tensor encoding.

See encode for a description of the default encoding. Subclasses may override this default definition, when necessary.

Args
spec The TypeSpec for the result value. If encoded values with spec s were batched, then spec should be s.batch(batch_size); or if encoded values with spec s were unbatched, then spec should be s.unbatch().
encoded_value A nest of values returned by encode; or a nest of values that was formed by stacking, unstacking, or concatenating the corresponding elements of values returned by encode.

Returns
A value compatible with spec.

encode

Encodes value as a nest of batchable Tensors or CompositeTensors.

The default definition returns a flat tuple of all the Tensors, CompositeTensors, and ExtensionTypes from a depth-first traversal of value's fields. Subclasses may override this default definition, when necessary.

Args
spec The TypeSpec of the value to encode.
value A value compatible with spec.
minimum_rank The minimum rank for the returned Tensors, CompositeTensors, and ExtensionType values. This can be used to ensure that the encoded values can be unbatched this number of times. If minimum_rank>0, then t.shape[:minimum_rank] must be compatible for all values t returned by encode.

Returns
A nest (as defined by tf.nest) of tf.Tensors, batchable tf.CompositeTensors, or tf.ExtensionTypes. Stacking, unstacking, or concatenating these encoded values and then decoding the result must be equivalent to stacking, unstacking, or concatenating the original values.
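
The round-trip requirement for encode and decode can be illustrated with the default encoder. This is a sketch under stated assumptions: Pair is a hypothetical type, and the manual tf.stack call stands in for the stacking that a batching API would normally perform.

```python
import tensorflow as tf

# Hypothetical extension type used only for illustration.
class Pair(tf.experimental.BatchableExtensionType):
    x: tf.Tensor
    y: tf.Tensor

encoder = tf.experimental.ExtensionTypeBatchEncoder()
a = Pair(x=tf.constant([1.0, 2.0]), y=tf.constant(10))
b = Pair(x=tf.constant([3.0, 4.0]), y=tf.constant(20))
spec = tf.type_spec_from_value(a)

# Encode each value as a flat tuple of tensors.
encoded_a = encoder.encode(spec, a)
encoded_b = encoder.encode(spec, b)

# Stack the corresponding encoded tensors, as a batching API would.
stacked = tuple(tf.stack([ta, tb]) for ta, tb in zip(encoded_a, encoded_b))

# Decode with the batched spec to recover a batched Pair.
batch = encoder.decode(encoder.batch(spec, 2), stacked)

print(batch.x.shape, batch.y.shape)
```

Decoding uses encoder.batch(spec, 2) rather than spec, because the encoded tensors were stacked after encoding.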

encoding_specs

Returns a nest of TypeSpec(s) describing the encoding for spec.

See encode for a description of the default encoding. Subclasses may override this default definition, when necessary.

Args
spec The TypeSpec whose encoding should be described.

Returns
A nest (as defined by tf.nest) of tf.TypeSpec, describing the values that are returned by self.encode(spec, ...). All TypeSpecs in this nest must be batchable.
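
As a sketch of what the default encoding looks like for a simple type, the example below asks the default encoder for the encoding specs of a hypothetical Pair type (an illustrative name, not part of TensorFlow):

```python
import tensorflow as tf

# Hypothetical extension type used only for illustration.
class Pair(tf.experimental.BatchableExtensionType):
    x: tf.Tensor
    y: tf.Tensor

encoder = tf.experimental.ExtensionTypeBatchEncoder()
value = Pair(x=tf.constant([1.0, 2.0]), y=tf.constant(10))
spec = tf.type_spec_from_value(value)

# One spec per tensor that encode(spec, value) would return.
encoding = encoder.encoding_specs(spec)
for s in encoding:
    print(s)
```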

unbatch

Returns the TypeSpec for a single unbatched element in spec.

The default definition returns a TypeSpec that is equal to spec, except that the outermost axis is removed from every nested TypeSpec and TensorShape field. Subclasses may override this default definition, when necessary.

Args
spec The TypeSpec for a batch of values.

Returns
A TypeSpec for an individual value.
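
A minimal sketch of the default unbatch behavior, starting from the TypeSpec of a batched value. Pair is a hypothetical type, and the shapes are chosen only for illustration:

```python
import tensorflow as tf

# Hypothetical extension type used only for illustration.
class Pair(tf.experimental.BatchableExtensionType):
    x: tf.Tensor
    y: tf.Tensor

encoder = tf.experimental.ExtensionTypeBatchEncoder()

# A Pair holding a batch of four elements along the outermost axis.
batch = Pair(x=tf.zeros([4, 2]), y=tf.zeros([4], dtype=tf.int32))
batched_spec = tf.type_spec_from_value(batch)

# The outermost axis is removed from each nested spec.
element_spec = encoder.unbatch(batched_spec)

print(element_spec.x.shape, element_spec.y.shape)
```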