
A transformation that batches ragged elements into tf.RaggedTensors.

This transformation combines multiple consecutive elements of the input dataset into a single element.

Like tf.data.Dataset.batch, the components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, set the drop_remainder argument to True to prevent the smaller batch from being produced.
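The outer-dimension rule above can be checked with a few lines of plain Python. This is only a sketch of the arithmetic, not the library's implementation: the helper name batch_sizes is hypothetical and no TensorFlow call is involved.

```python
def batch_sizes(num_elements, batch_size, drop_remainder=False):
    """Return the outer dimension of each batch, in order.

    Hypothetical helper that models the rule described above:
    full batches have size batch_size, and a final smaller batch of
    size N % batch_size appears only when drop_remainder is False.
    """
    full_batches, remainder = divmod(num_elements, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes


print(batch_sizes(6, 2))                       # [2, 2, 2]
print(batch_sizes(7, 3))                       # [3, 3, 1]
print(batch_sizes(7, 3, drop_remainder=True))  # [3, 3]
```

With drop_remainder set to True every batch has the same outer dimension, at the cost of discarding up to batch_size - 1 trailing elements.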

Unlike tf.data.Dataset.batch, the input elements to be batched may have different shapes, and each batch will be encoded as a tf.RaggedTensor. Example:

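import numpy as np
import tensorflow as tf
# Six elements of increasing length: [], [0], [0, 1], ..., [0, 1, 2, 3, 4].
dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
dataset = dataset.map(lambda x: tf.range(x))
dataset = dataset.apply(
    tf.data.experimental.dense_to_ragged_batch(batch_size=2))
for batch in dataset:
  print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>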

Args:
batch_size: A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
row_splits_dtype: The dtype that should be used for the row_splits of any new ragged tensors. Existing tf.RaggedTensor elements do not have their row_splits dtype changed.
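For readers unfamiliar with row_splits, the following plain-Python sketch illustrates the encoding that row_splits_dtype refers to: a ragged batch can be stored as one flat list of values plus a vector of row boundaries, where row i spans values[row_splits[i]:row_splits[i + 1]]. The helper name to_row_splits is hypothetical and the code does not call TensorFlow; it only models the layout.

```python
def to_row_splits(rows):
    """Flatten variable-length rows into (values, row_splits).

    Hypothetical helper for illustration. row_splits has one more
    entry than there are rows and always starts at 0.
    """
    values = []
    row_splits = [0]
    for row in rows:
        values.extend(row)
        # Each boundary is the running total of values seen so far.
        row_splits.append(len(values))
    return values, row_splits


values, row_splits = to_row_splits([[0, 1], [0, 1, 2]])
print(values)      # [0, 1, 0, 1, 2]
print(row_splits)  # [0, 2, 5]
```

In this picture, row_splits_dtype only selects the integer type used for the boundary vector; it does not affect the values themselves.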

Returns:
A Dataset transformation function, which can be passed to tf.data.Dataset.apply.