An ItemSelector implementation that randomly selects items in a batch.
text.RandomItemSelector(
    max_selections_per_batch,
    selection_rate,
    unselectable_ids=None,
    shuffle_fn=None
)
RandomItemSelector randomly selects items in a batch, subject to the
given restrictions (max_selections_per_batch, selection_rate and
unselectable_ids).
Example:
import tensorflow as tf
import tensorflow_text as tf_text

vocab = ["[UNK]", "[MASK]", "[RANDOM]", "[CLS]", "[SEP]",
         "abc", "def", "ghi"]
# Note that commonly in masked language model work, there are
# special tokens we don't want to mask, like CLS, SEP, and probably
# any OOV (out-of-vocab) tokens here called UNK.
# Note that if e.g. there are bucketed OOV tokens in the code,
# that might be a use case for overriding `get_selectable()` to
# exclude a range of IDs rather than enumerating them.
tf.random.set_seed(1234)
selector = tf_text.RandomItemSelector(
    max_selections_per_batch=2,
    selection_rate=0.2,
    unselectable_ids=[0, 3, 4])  # indices of UNK, CLS, SEP
selection = selector.get_selection_mask(
    tf.ragged.constant([[3, 5, 7, 7], [4, 6, 7, 5]]), axis=1)
print(selection)
<tf.RaggedTensor [[False, False, False, True], [False, False, True, False]]>
The selection skips the first element of each row (the CLS and SEP token
IDs) and picks random elements from the rest of each segment. Running with
a different random seed may produce a different selection.
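To make the restrictions concrete, the following pure-Python sketch mimics the selection on plain lists. It is an illustration, not the library's implementation: the helper name `sketch_selection_mask`, its `seed` parameter, and the per-row budget of ceil(selectable_count * selection_rate), capped at max_selections_per_batch, are assumptions chosen to be consistent with the example output (3 selectable items at a rate of 0.2 give one selection per row). The positions picked will generally differ from the TensorFlow result because the random generators differ.

```python
import math
import random


def sketch_selection_mask(rows, max_selections, selection_rate,
                          unselectable_ids, seed=None):
    """Pure-Python illustration of random item selection (hypothetical helper)."""
    rng = random.Random(seed)
    blocked = set(unselectable_ids)
    mask = []
    for row in rows:
        # Positions whose IDs are allowed to be selected.
        candidates = [i for i, tok in enumerate(row) if tok not in blocked]
        # Assumed budget: ceil(candidates * rate), capped at the maximum
        # and at the number of candidates actually available.
        budget = min(math.ceil(len(candidates) * selection_rate),
                     max_selections, len(candidates))
        chosen = set(rng.sample(candidates, budget))
        # The mask has the same nested shape as the input rows.
        mask.append([i in chosen for i in range(len(row))])
    return mask


mask = sketch_selection_mask(
    [[3, 5, 7, 7], [4, 6, 7, 5]],
    max_selections=2, selection_rate=0.2,
    unselectable_ids=[0, 3, 4], seed=1234)
print(mask)
```

Under these assumptions each row gets exactly one True value, and it never falls on the leading CLS or SEP position.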
| Args | Description |
| --- | --- |
| max_selections_per_batch | An int giving the maximum number of items to mask out. |
| selection_rate | The rate at which items are randomly selected. |
| unselectable_ids | (optional) A list of Python ints or a 1D Tensor of ints. These IDs will not be masked. |
| shuffle_fn | (optional) A function that shuffles a 1D Tensor. Defaults to tf.random.shuffle. |
| Attributes |
| --- |
| max_selections_per_batch |
| selection_rate |
| shuffle_fn |
| unselectable_ids |
Methods
get_selectable
get_selectable(
    input_ids, axis
)
Return a boolean mask of items that can be chosen for selection.
The default implementation marks all items whose IDs are not in the
unselectable_ids list. This can be overridden if there is a need for
a more complex or algorithmic approach for selectability.
| Args | Description |
| --- | --- |
| input_ids | A RaggedTensor. |
| axis | The axis to apply selection on. |

| Returns |
| --- |
| A RaggedTensor with dtype of bool and with shape input_ids.shape[:axis]. Its values are True if the corresponding item (or broadcasted subitems) should be considered for masking. In the default implementation, all input_ids items that are not listed in unselectable_ids (from the class arg) are considered selectable. |
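The selectability rule can be illustrated on plain Python lists. The sketch below is not the library's code: `default_selectable` mirrors the documented default (every ID not listed in unselectable_ids is selectable), while `range_selectable` is a hypothetical stand-in for the kind of override mentioned earlier, blocking a contiguous ID range such as bucketed OOV IDs. Both function names and the chosen ID range are assumptions made for illustration.

```python
def default_selectable(rows, unselectable_ids):
    """Mark every ID that is not listed in unselectable_ids."""
    blocked = set(unselectable_ids)
    return [[tok not in blocked for tok in row] for row in rows]


def range_selectable(rows, first_blocked_id, last_blocked_id):
    """Hypothetical override: block an inclusive ID range, e.g. OOV buckets."""
    return [[not (first_blocked_id <= tok <= last_blocked_id) for tok in row]
            for row in rows]


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]

print(default_selectable(rows, unselectable_ids=[0, 3, 4]))
# [[False, True, True, True], [False, True, True, True]]

print(range_selectable(rows, first_blocked_id=6, last_blocked_id=7))
# [[True, True, False, False], [True, False, False, True]]
```

A range check avoids enumerating every blocked ID, which matters when the blocked set is large or defined by a rule rather than a list.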
get_selection_mask
get_selection_mask(
    input_ids, axis
)
Returns a mask of items that have been selected.
The default implementation simply returns all items not excluded by
get_selectable.
| Args | Description |
| --- | --- |
| input_ids | A RaggedTensor. |
| axis | (optional) An int detailing the dimension to apply selection on. Default is the 1st dimension. |

| Returns |
| --- |
| A RaggedTensor with shape input_ids.shape[:axis]. Its values are True if the corresponding item (or broadcasted subitems) should be selected. |
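One common use of a selection mask in masked language model work is to replace the selected IDs with the mask token. The sketch below shows that step on plain Python lists, using the input and the printed mask from the class example. `apply_mask` is a hypothetical helper, not part of the library, and `MASK_ID = 1` assumes the example vocab, where "[MASK]" sits at index 1.

```python
MASK_ID = 1  # assumed: index of "[MASK]" in the example vocab


def apply_mask(rows, selection, mask_id=MASK_ID):
    """Replace each selected ID with mask_id; leave the others unchanged."""
    return [[mask_id if picked else tok for tok, picked in zip(row, flags)]
            for row, flags in zip(rows, selection)]


input_ids = [[3, 5, 7, 7], [4, 6, 7, 5]]
selection = [[False, False, False, True], [False, False, True, False]]

masked = apply_mask(input_ids, selection)
print(masked)  # [[3, 5, 7, 1], [4, 6, 1, 5]]
```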