Zeroes out IDs of classes not in `allowed_class_ids`.
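The element-wise rule can be written as the formula below. The notation is introduced here only for illustration: $c_{b,i}$ is the input class ID at batch index $b$ and instance index $i$, and $A$ is the set of IDs in `allowed_class_ids`.

$$
\text{filtered\_class\_ids}_{b,i} =
\begin{cases}
c_{b,i} & \text{if } c_{b,i} \in A,\\
0 & \text{otherwise,}
\end{cases}
\qquad 0 \le b < \text{batch\_size},\quad 0 \le i < \text{num\_instances}.
$$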
```python
tfm.vision.maskrcnn.zero_out_disallowed_class_ids(
    batch_class_ids: tf.Tensor, allowed_class_ids: List[int]
)
```
| Args | |
|---|---|
| `batch_class_ids` | A `[batch_size, num_instances]` int tensor of input class IDs. |
| `allowed_class_ids` | A Python list of the class IDs to keep. |
| Returns | |
|---|---|
| `filtered_class_ids` | A `[batch_size, num_instances]` int tensor in which every class ID not in `allowed_class_ids` is set to 0. |
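The following minimal pure-Python sketch illustrates the documented behavior on nested lists instead of a `tf.Tensor`. The helper name `zero_out_disallowed_class_ids_py` is hypothetical, and this is not the library's implementation. It only mirrors what the Returns table describes: the output has the same shape as the input, and disallowed IDs become 0.

```python
from typing import List


def zero_out_disallowed_class_ids_py(
    batch_class_ids: List[List[int]], allowed_class_ids: List[int]
) -> List[List[int]]:
    """Hypothetical pure-Python illustration, not the TF Model Garden code.

    batch_class_ids: nested list shaped [batch_size, num_instances].
    allowed_class_ids: class IDs that are kept unchanged.
    """
    allowed = set(allowed_class_ids)  # set gives O(1) membership checks
    # Keep each ID that is allowed; replace every other ID with 0.
    return [
        [class_id if class_id in allowed else 0 for class_id in row]
        for row in batch_class_ids
    ]


if __name__ == "__main__":
    batch = [[1, 2, 3], [4, 2, 7]]  # batch_size=2, num_instances=3
    print(zero_out_disallowed_class_ids_py(batch, [2, 3]))
    # [[0, 2, 3], [0, 2, 0]]
```

Converting `allowed_class_ids` to a `set` keeps each membership check constant-time. A tensor implementation would more likely express the same check with a broadcasted comparison against the allowed IDs, followed by a select between the original IDs and zeros.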