Utility `tf.function` that converts a 2D padded tensor to ragged.
`@tf.function`
`tfq.padded_to_ragged2d(masked_state)`
Converts a `tf.Tensor` of shape `[batch, dim, dim]`, padded with `-2`, to a `tf.RaggedTensor` using 2D boolean masking.
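The 2D boolean masking described above can be sketched with standard TensorFlow ops. This is a minimal illustration under stated assumptions, not TFQ's own implementation: the function name `padded_to_ragged2d_sketch` is hypothetical, and it assumes padding entries are exactly `-2`. It first drops `-2` entries within each row, then drops rows that consisted entirely of padding.

```python
import tensorflow as tf


@tf.function
def padded_to_ragged2d_sketch(masked_state):
    """Hypothetical sketch: strip -2 padding from a [batch, dim, dim] tensor."""
    pad = tf.constant(-2, dtype=masked_state.dtype)
    # True wherever an entry is real data rather than -2 padding.
    col_mask = tf.not_equal(masked_state, pad)
    # Drop padding entries within each row -> shape [batch, dim, (cols)].
    masked = tf.ragged.boolean_mask(masked_state, col_mask)
    # A row is kept if it holds at least one non-padding entry.
    row_mask = tf.reduce_any(col_mask, axis=-1)
    # Drop all-padding rows -> shape [batch, (rows), (cols)].
    return tf.ragged.boolean_mask(masked, row_mask)


padded = tf.constant([
    [[1.0, -2.0], [-2.0, -2.0]],  # a 1x1 matrix padded to 2x2
    [[0.5, 0.5], [0.5, 0.5]],     # an unpadded 2x2 matrix
])
ragged = padded_to_ragged2d_sketch(padded)
print(ragged.to_list())  # [[[1.0]], [[0.5, 0.5], [0.5, 0.5]]]
```

Using `tf.not_equal` against a constant cast to the input dtype keeps the sketch usable for both real and complex inputs, since the padding value is an exact sentinel rather than a computed quantity.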
| Args | |
|---|---|
| `masked_state` | `tf.Tensor` of rank 3 with `-2` padding. |
| Returns | |
|---|---|
| `state_ragged` | `tf.RaggedTensor` of rank 3 with no `-2` padding, where the two non-batch dimensions are now ragged instead of padded. |
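To see what the returned value encodes, the hypothetical round trip below builds a `tf.RaggedTensor` of the same form directly with `tf.ragged.constant` (it does not call `tfq`) and re-pads it with `RaggedTensor.to_tensor`. Re-padding with `-2` reproduces a padded input only when that input was padded to exactly the largest row and column counts in the batch; that is an assumption of this example.

```python
import tensorflow as tf

# A ragged value shaped like the documented return: rank 3, a uniform batch
# dimension, and ragged row and column dimensions.
state_ragged = tf.ragged.constant([
    [[1.0]],
    [[0.5, 0.5], [0.5, 0.5]],
])
print(state_ragged.shape.as_list())  # [2, None, None]

# Re-pad every example to the largest rows/cols in the batch, filling with -2.
repadded = state_ragged.to_tensor(default_value=-2.0)
print(repadded.numpy().tolist())
# [[[1.0, -2.0], [-2.0, -2.0]], [[0.5, 0.5], [0.5, 0.5]]]
```

The `None` entries in the static shape mark the ragged dimensions; only the batch dimension keeps a fixed size.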