When attention logits are calculated, the relative position encoding is
projected to form relative keys. The final logits are the sum of the shifted
relative (position-based) logits and the content-based logits.
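The "shifted relative logits" are typically produced with the pad/reshape/slice trick from Transformer-XL-style implementations. Below is a minimal NumPy sketch of that shift; the shapes [B, H, T, L], the key length klen, and the name rel_shift are illustrative assumptions, not this layer's actual internals:

```python
import numpy as np

def rel_shift(x, klen):
    """Shift raw position logits [B, H, T, L] into aligned form [B, H, T, klen].

    After the shift, each query position t reads the encoding entry for its
    relative distance to key position s, via a reshape/slice trick.
    """
    b, h, t, l = x.shape
    x = x.transpose(2, 3, 0, 1)     # [T, L, B, H]
    x = x.reshape(l, t, b, h)       # reinterpret the buffer as [L, T, B, H]
    x = x[1:]                       # drop one row: [L - 1, T, B, H]
    x = x.reshape(t, l - 1, b, h)   # back to [T, L - 1, B, H], now shifted
    x = x[:, :klen]                 # keep the first klen key positions
    return x.transpose(2, 3, 0, 1)  # [B, H, T, klen]

# Tiny worked example: T=2 queries, L=4 encoding slots, klen=2 keys.
raw = np.arange(1.0, 9.0).reshape(1, 1, 2, 4)
shifted = rel_shift(raw, klen=2)
# shifted[0, 0] == [[3., 4.], [6., 7.]]: row t selects index s - t + (L - T).
```

In other words, the same per-query row of raw logits is re-indexed so that column s lines up with relative distance rather than absolute encoding position.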
Call args
query
Query Tensor of shape [B, T, dim].
value
Value Tensor of shape [B, S, dim].
content_attention_bias
Bias Tensor for content based attention of shape
[num_heads, dim].
positional_attention_bias
Bias Tensor for position based attention of
shape [num_heads, dim].
key
Optional key Tensor of shape [B, S, dim]. If not given, will use
value for both key and value, which is the most common case.
relative_position_encoding
Relative positional encoding Tensor of shape
[B, L, dim].
segment_matrix
Optional Tensor representing segmentation IDs used in
XLNet of shape [B, S, S + M].
segment_encoding
Optional Tensor representing the segmentation encoding
as used in XLNet of shape [2, num_heads, dim].
segment_attention_bias
Optional trainable bias parameter added to the query
head when calculating the segment-based attention score used in XLNet, of
shape [num_heads, dim].
state
Optional Tensor of shape [B, M, E] where M is the length of the
state or memory. If passed, this is also attended over as in Transformer
XL.
attention_mask
A boolean mask of shape [B, T, S] that prevents attention
to certain positions.
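The role of content_attention_bias and positional_attention_bias can be sketched as follows: one bias is added to the query before the content (query-key) logits, the other before the position (query-relative-key) logits. This NumPy sketch only illustrates the two einsums and shapes; the variable names and the choice L = T + S are assumptions, and the position logits would still need a relative shift (not shown) before being added to the content logits:

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, T, S, D = 2, 4, 3, 3, 8        # batch, heads, target len, source len, head dim
L = T + S                            # hypothetical relative-encoding length

q = rng.standard_normal((B, T, H, D))       # projected query
k = rng.standard_normal((B, S, H, D))       # projected key
r = rng.standard_normal((B, L, H, D))       # projected relative keys
content_bias = rng.standard_normal((H, D))   # plays the role of content_attention_bias
position_bias = rng.standard_normal((H, D))  # plays the role of positional_attention_bias

# Content-based logits: (q + content_bias) . k
content_logits = np.einsum('bthd,bshd->bhts', q + content_bias, k)
# Raw position-based logits: (q + position_bias) . r, still indexed by encoding slot.
position_logits = np.einsum('bthd,blhd->bhtl', q + position_bias, r)

assert content_logits.shape == (B, H, T, S)
assert position_logits.shape == (B, H, T, L)
```

After the relative shift aligns position_logits to [B, H, T, S], the two terms are summed, scaled, masked, and passed through a softmax as in standard attention.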
Attributes
kernel_initializer
The kernel initializer. Defaults to variance_scaling.
Compute multi-head relative attention over inputs.
Size glossary
Number of heads (H): the number of attention heads.
Value size (V): the size of each value embedding per head.
Key size (K): the size of each key embedding per head. Equally, the size
of each query embedding per head. Typically K <= V.
Batch dimensions (B).
Query (target) attention axes shape (T).
Value (source) attention axes shape (S); its rank must match the target's.
Encoding length (L): The relative positional encoding length.
Args
query
attention input.
value
attention input.
content_attention_bias
A trainable bias parameter added to the query head
when calculating the content-based attention score.
positional_attention_bias
A trainable bias parameter added to the query
head when calculating the position-based attention score.
key
attention input.
relative_position_encoding
relative positional encoding for key and
value.
segment_matrix
Optional Tensor representing segmentation IDs used in
XLNet.
segment_encoding
Optional Tensor representing the segmentation encoding
as used in XLNet.
segment_attention_bias
Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
state
(default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask
(default None) Optional mask that is added to the attention
logits. If state is not None, the mask's source sequence dimension should
be extended by M so that it also covers the memory positions.
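Extending the mask by M when state is passed can be done by prepending M columns along the source dimension. A small NumPy sketch, assuming a boolean mask where True means "may attend" and that all memory slots are attendable (both assumptions, not requirements of the layer):

```python
import numpy as np

B, T, S, M = 2, 3, 4, 2                            # batch, target, source, memory lengths
attention_mask = np.ones((B, T, S), dtype=bool)    # original [B, T, S] mask

# With a state/memory of length M, keys and values span M + S positions,
# so prepend M columns allowing every query to attend to all memory slots.
memory_mask = np.ones((B, T, M), dtype=bool)
extended_mask = np.concatenate([memory_mask, attention_mask], axis=-1)

assert extended_mask.shape == (B, T, M + S)
```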
Returns
attention_output
The result of the computation, of shape [B, T, E],
where T is the target sequence length and E is the last dimension of the
query input if output_shape is None. Otherwise, the multi-head outputs
are projected to the shape specified by output_shape.