tfm.nlp.layers.PerQueryDenseHead

Pooling head used for EncT5-style models.

This module projects each query with its own, separate projection.

For an input of shape [bs, num_queries, hidden_size], it projects each query to features dimensions, producing an output of shape [bs, num_queries, features].
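The shape transformation above can be sketched in NumPy as a single einsum. This is a minimal sketch, not the library's implementation: it assumes the per-query weights are stacked into one kernel of shape [num_queries, hidden_size, features] and the bias into [num_queries, features]. The helper name per_query_dense is hypothetical.

```python
import numpy as np


def per_query_dense(inputs, kernel, bias=None):
    """Hypothetical helper: projects each query with its own weight matrix.

    inputs:  [bs, num_queries, hidden_size]
    kernel:  [num_queries, hidden_size, features]  (assumed layout)
    bias:    [num_queries, features] or None
    returns: [bs, num_queries, features]
    """
    # "bqh,qhf->bqf": query q of every batch row is multiplied by kernel[q] only.
    outputs = np.einsum("bqh,qhf->bqf", inputs, kernel)
    if bias is not None:
        # The per-query bias broadcasts over the batch dimension.
        outputs = outputs + bias
    return outputs


bs, num_queries, hidden_size, features = 4, 3, 8, 5
rng = np.random.default_rng(0)
x = rng.normal(size=(bs, num_queries, hidden_size))
w = rng.normal(size=(num_queries, hidden_size, features))
b = np.zeros((num_queries, features))
y = per_query_dense(x, w, b)
print(y.shape)  # (4, 3, 5)
```

Unlike a shared dense layer, which would apply one [hidden_size, features] matrix to every query, each slice kernel[q] here is used only for query q.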

For example, for multi-class classification one may set num_queries to 1 and features to the number of classes. For multi-label classification, one may set num_queries to the number of labels and features to 2, so that each query represents a binary classification for one label.
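The two configurations in the example above can be illustrated on random stand-in logits. This sketch assumes the head outputs raw logits, applies a softmax over the features axis, and treats feature index 1 as "label present" in the multi-label case; that index convention is a hypothetical choice, not something the layer defines.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Numerically stable softmax along `axis`.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


bs = 2
rng = np.random.default_rng(0)

# Multi-class: num_queries=1, features=num_classes -> logits [bs, 1, num_classes].
num_classes = 4
multiclass_logits = rng.normal(size=(bs, 1, num_classes))
class_probs = softmax(multiclass_logits[:, 0, :])  # drop the query axis: [bs, num_classes]
predicted_class = class_probs.argmax(axis=-1)      # [bs]

# Multi-label: num_queries=num_labels, features=2 -> logits [bs, num_labels, 2].
num_labels = 3
multilabel_logits = rng.normal(size=(bs, num_labels, 2))
label_probs = softmax(multilabel_logits)           # softmax over the 2 features per label
label_on = label_probs[..., 1] > 0.5               # [bs, num_labels] booleans (index 1 assumed "on")
```

In the multi-label case each label gets its own two-way decision, so labels are predicted independently of one another.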

Args
num_queries Number of queries (the learnable embeddings in the input sequence) from the decoder.
features int, the number of output features. Each query will be projected to this number of features with its own projection.
use_bias Whether to add a bias to the output.
kernel_initializer Initializer for the dense layer kernels.
**kwargs Keyword arguments.
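The arguments above determine how many weights the head holds. The helper below is hypothetical and assumes one [hidden_size, features] matrix per query, stacked into a single kernel (which kernel_initializer would initialize), plus, when use_bias is set, one bias vector of length features per query. The value of hidden_size is chosen arbitrarily for illustration.

```python
import math


def per_query_dense_weight_shapes(num_queries, hidden_size, features, use_bias=True):
    # Hypothetical helper; the stacked-kernel layout is an assumption.
    shapes = {"kernel": (num_queries, hidden_size, features)}
    if use_bias:
        shapes["bias"] = (num_queries, features)
    return shapes


def count_params(shapes):
    # Total number of scalar parameters across all weight tensors.
    return sum(math.prod(shape) for shape in shapes.values())


# Multi-label setup: 3 labels, 2 features each, hidden_size picked arbitrarily.
shapes = per_query_dense_weight_shapes(num_queries=3, hidden_size=768, features=2)
print(shapes)                # {'kernel': (3, 768, 2), 'bias': (3, 2)}
print(count_params(shapes))  # 4614
```

Because every query has its own matrix, the parameter count grows linearly with num_queries, whereas a single shared dense layer would not.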

Methods

call


Implements call(): applies the per-query projection to inputs.

Args
inputs A rank-3 Tensor of shape [bs, num_queries, hidden_size].

Returns
A Tensor of shape [bs, num_queries, features].
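As a cross-check of the call() contract above, the per-query projection can also be written as an explicit loop that applies an independent dense projection to each query slot. This is a hypothetical NumPy reference under the assumed kernel layout [num_queries, hidden_size, features], not the library's code. The input checks mirror the documented rank-3 requirement.

```python
import numpy as np


def per_query_dense_reference(inputs, kernel, bias=None):
    """Hypothetical loop-based reference for the documented call() shapes.

    outputs[:, q, :] = inputs[:, q, :] @ kernel[q] (+ bias[q])
    """
    inputs = np.asarray(inputs)
    kernel = np.asarray(kernel)
    if inputs.ndim != 3:
        raise ValueError(f"expected a rank-3 input, got rank {inputs.ndim}")
    bs, num_queries, hidden_size = inputs.shape
    if kernel.ndim != 3 or kernel.shape[:2] != (num_queries, hidden_size):
        raise ValueError("kernel must have shape [num_queries, hidden_size, features]")
    features = kernel.shape[2]
    outputs = np.empty((bs, num_queries, features), dtype=np.result_type(inputs, kernel))
    for q in range(num_queries):
        # Each query slot uses only its own projection matrix.
        outputs[:, q, :] = inputs[:, q, :] @ kernel[q]
        if bias is not None:
            outputs[:, q, :] += bias[q]
    return outputs


x = np.ones((2, 3, 4))
k = np.ones((3, 4, 5))
out = per_query_dense_reference(x, k)
print(out.shape)  # (2, 3, 5)
```

A loop like this is easy to read but slow. A vectorized form, for example an einsum over the query axis, computes the same result in one call.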