Pooling head used for EncT5 style models.
```python
tfm.nlp.layers.PerQueryDenseHead(
    num_queries: int,
    features: int,
    use_bias: bool = False,
    kernel_initializer: str = 'glorot_uniform',
    **kwargs
)
```
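To make the constructor arguments concrete, here is a minimal NumPy stand-in. It is a sketch under stated assumptions, not the library's implementation. The class name `PerQueryDenseSketch`, the lazy `build` step, the per-query Glorot-style initialization limit, and the weight layout (a kernel of shape `[num_queries, hidden_size, features]` plus an optional bias of shape `[num_queries, features]`) are all assumptions chosen to match the shapes documented on this page.

```python
import numpy as np


class PerQueryDenseSketch:
    """Hypothetical NumPy stand-in for PerQueryDenseHead (not the tfm API)."""

    def __init__(self, num_queries, features, use_bias=False, seed=0):
        self.num_queries = num_queries
        self.features = features
        self.use_bias = use_bias
        self._rng = np.random.default_rng(seed)
        self.kernel = None  # created lazily, once hidden_size is known
        self.bias = None

    def build(self, hidden_size):
        # Assumed layout: one [hidden_size, features] matrix per query.
        # Glorot-uniform style limit computed per query matrix (an assumption).
        limit = np.sqrt(6.0 / (hidden_size + self.features))
        self.kernel = self._rng.uniform(
            -limit, limit, size=(self.num_queries, hidden_size, self.features)
        )
        if self.use_bias:
            # Assumed: one bias vector per query.
            self.bias = np.zeros((self.num_queries, self.features))

    def __call__(self, inputs):
        if inputs.ndim != 3 or inputs.shape[1] != self.num_queries:
            raise ValueError("expected inputs of shape [bs, num_queries, hidden_size]")
        if self.kernel is None:
            self.build(inputs.shape[-1])
        # Query q is multiplied only by its own matrix kernel[q].
        outputs = np.einsum("bqh,qhf->bqf", inputs, self.kernel)
        if self.use_bias:
            outputs = outputs + self.bias  # broadcasts over the batch axis
        return outputs


head = PerQueryDenseSketch(num_queries=4, features=3, use_bias=True)
print(head(np.ones((2, 4, 8))).shape)  # (2, 4, 3)
```

Creating the weights on the first call mirrors the usual Keras pattern where `hidden_size` is only known once an input has been seen.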
This layer applies a separate dense projection to each query.
For an input of shape `[bs, num_queries, hidden_size]`, it projects each query's `hidden_size`-dimensional vector to `features` dimensions, producing an output of shape `[bs, num_queries, features]`.
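Written out elementwise, the per-query projection can be sketched as follows. The symbols are illustrative: $x$ is the input, $H$ is `hidden_size`, $W^{(q)}$ is a hypothetical `[hidden_size, features]` weight matrix owned by query $q$, and $c^{(q)}$ is a bias term present only when `use_bias=True`, assumed here to be per query as well.

$$
y_{b,q,f} = \sum_{h=1}^{H} x_{b,q,h}\, W^{(q)}_{h,f} + c^{(q)}_{f},
\qquad b = 1,\dots,\mathrm{bs},\quad q = 1,\dots,\mathrm{num\_queries},\quad f = 1,\dots,\mathrm{features}
$$

An ordinary dense layer applied to the last axis would use the same $H \times \mathrm{features}$ matrix for every query; here each query $q$ has its own $W^{(q)}$.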
For example, for classification over a small number of classes, one may set `num_queries` to 1 and `features` to the number of classes. For multilabel classification, one may set `num_queries` to the number of labels and `features` to 2, so that each query performs a binary classification for one label.
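The two configurations in the example above can be sketched in NumPy. This is an illustration under assumptions: the projection is written inline with `np.einsum` rather than calling `tfm`, the weights are random placeholders, and the softmax readouts (over classes for the single-query head, over the two outputs of each query for the multilabel head) are an assumed post-processing step, not something this layer does itself.

```python
import numpy as np


def per_query_dense(inputs, kernel):
    # inputs: [bs, num_queries, hidden], kernel: [num_queries, hidden, features]
    return np.einsum("bqh,qhf->bqf", inputs, kernel)


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
bs, hidden = 2, 16

# Multiclass: a single query whose features are the class logits.
num_classes = 5
x_mc = rng.normal(size=(bs, 1, hidden))
w_mc = rng.normal(size=(1, hidden, num_classes))
class_probs = softmax(per_query_dense(x_mc, w_mc))[:, 0, :]  # [bs, num_classes]

# Multilabel: one query per label, two logits (negative, positive) per query.
num_labels = 3
x_ml = rng.normal(size=(bs, num_labels, hidden))
w_ml = rng.normal(size=(num_labels, hidden, 2))
label_probs = softmax(per_query_dense(x_ml, w_ml))[..., 1]  # P(label on), [bs, num_labels]

print(class_probs.shape, label_probs.shape)  # (2, 5) (2, 3)
```

In the multilabel case each label gets its own projection, so the labels do not compete with each other the way classes do under a single softmax.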
Methods
call
```python
call(
    inputs: tf.Tensor
) -> tf.Tensor
```
Applies the per-query projection to `inputs`.
Args | |
---|---|
`inputs` | A rank-3 Tensor of shape `[bs, num_queries, hidden_size]`. |
Returns | |
---|---|
A Tensor of shape `[bs, num_queries, features]`. |
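The shape contract in the Args and Returns tables can be checked before calling the layer. `per_query_output_shape` is a hypothetical helper, not part of `tfm`; it assumes, as the Args table states, that the input must be rank 3 and that its second dimension must equal `num_queries`.

```python
def per_query_output_shape(input_shape, num_queries, features):
    """Hypothetical helper: output shape [bs, num_queries, features] for an input shape."""
    if len(input_shape) != 3:
        raise ValueError(f"inputs must be rank 3, got rank {len(input_shape)}")
    bs, queries, _hidden_size = input_shape
    if queries != num_queries:
        raise ValueError(f"expected {num_queries} queries, got {queries}")
    # hidden_size is consumed by the projection; only features remains.
    return (bs, num_queries, features)


print(per_query_output_shape((8, 3, 768), num_queries=3, features=2))  # (8, 3, 2)
```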