Two-stream relative self-attention for XLNet.
Inherits From: MultiHeadRelativeAttention
```python
tfm.nlp.layers.TwoStreamRelativeAttention(
    kernel_initializer='variance_scaling', **kwargs
)
```
In XLNet, each token has two associated vectors at each self-attention layer, the content stream (h) and the query stream (g).
The content stream is the self-attention stream as in Transformer-XL and represents the context and content (the token itself).
The query stream only has access to contextual information and the position, but not the content.
This layer shares the same build signature as `tf.keras.layers.MultiHeadAttention`, but has different input/output projections.
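Below is a minimal construction sketch. It assumes the layer accepts the standard `tf.keras.layers.MultiHeadAttention` constructor arguments (`num_heads`, `key_dim`), consistent with the shared build signature noted above; the hyperparameter values are illustrative only.

```python
import tensorflow_models as tfm

# Illustrative hyperparameters: 4 heads, each of size 64.
layer = tfm.nlp.layers.TwoStreamRelativeAttention(
    num_heads=4,
    key_dim=64)
```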
Methods
call
```python
call(
    content_stream,
    content_attention_bias,
    positional_attention_bias,
    query_stream,
    relative_position_encoding,
    target_mapping=None,
    segment_matrix=None,
    segment_encoding=None,
    segment_attention_bias=None,
    state=None,
    content_attention_mask=None,
    query_attention_mask=None
)
```
Compute multi-head relative attention over inputs.
| Size glossary | |
|---|---|
| `B` | Batch size. |
| `T` | Target (query) sequence length. |
| `S` | Source (key/value) sequence length. |
| `M` | Length of the attended-over `state` (memory). |
| `N` | Number of attention heads. |
| `L` | Length of the relative position encoding. |
| `E` | Last dimension of the query input. |
| `key_dim` | Size of each attention head. |
| Args | |
|---|---|
| `content_stream` | The content representation, commonly referred to as `h`. This serves a similar role to the standard hidden states in Transformer-XL. |
| `content_attention_bias` | A trainable bias parameter added to the query head when calculating the content-based attention score. |
| `positional_attention_bias` | A trainable bias parameter added to the query head when calculating the position-based attention score. |
| `query_stream` | The query representation, commonly referred to as `g`. This only has access to contextual information and position, but not content. If not provided, this layer reduces to `MultiHeadRelativeAttention` with self-attention. |
| `relative_position_encoding` | Relative positional encoding for key and value. |
| `target_mapping` | Optional `Tensor` representing the target mapping used in partial prediction. |
| `segment_matrix` | Optional `Tensor` representing segmentation IDs used in XLNet. |
| `segment_encoding` | Optional `Tensor` representing the segmentation encoding as used in XLNet. |
| `segment_attention_bias` | Optional trainable bias parameter added to the query head when calculating the segment-based attention score. |
| `state` | Optional state. If passed, this is also attended over as in Transformer-XL and XLNet. Defaults to `None`. |
| `content_attention_mask` | Optional mask added to the content attention logits. If `state` is not `None`, the mask's source sequence dimension should be extended by `M` to cover the state (i.e., length `S + M`). Defaults to `None`. |
| `query_attention_mask` | Optional mask added to the query attention logits. If `state` is not `None`, the mask's source sequence dimension should be extended by `M` to cover the state (i.e., length `S + M`). Defaults to `None`. |
| Returns | |
|---|---|
| `content_attention_output`, `query_attention_output` | The results of the computation, both of shape `[B, T, E]`, where `T` is the target sequence length and `E` is the query input's last dimension if `output_shape` is `None`. Otherwise, the multi-head outputs are projected to the shape specified by `output_shape`. |
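A rough end-to-end sketch of calling the layer follows. The tensor shapes, the `2 * T` relative-encoding length, and the `[N, key_dim]` bias shapes are assumptions made for illustration, not documented requirements.

```python
import tensorflow as tf
import tensorflow_models as tfm

B, T, N, H = 2, 8, 4, 64  # batch, target length, heads, head size (assumed)
hidden = N * H

layer = tfm.nlp.layers.TwoStreamRelativeAttention(num_heads=N, key_dim=H)

content_stream = tf.random.normal([B, T, hidden])  # h: context and content
query_stream = tf.random.normal([B, T, hidden])    # g: context and position only
# Assumed relative-encoding length of 2 * T when no state is attended over.
relative_position_encoding = tf.random.normal([B, 2 * T, hidden])
# In a real model these biases are trainable weights; zeros are placeholders.
content_attention_bias = tf.zeros([N, H])
positional_attention_bias = tf.zeros([N, H])

content_out, query_out = layer(
    content_stream=content_stream,
    content_attention_bias=content_attention_bias,
    positional_attention_bias=positional_attention_bias,
    query_stream=query_stream,
    relative_position_encoding=relative_position_encoding)
# Both outputs have shape [B, T, hidden] when output_shape is None.
```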
compute_attention
```python
compute_attention(
    query,
    key,
    value,
    position,
    content_attention_bias,
    positional_attention_bias,
    segment_matrix=None,
    segment_encoding=None,
    segment_attention_bias=None,
    attention_mask=None
)
```
Computes the attention. This function defines the computation inside `call` with the projected multi-head `Q`, `K`, `V`, and `R` inputs.
| Args | |
|---|---|
| `query` | Projected query `Tensor` of shape `[B, T, N, key_dim]`. |
| `key` | Projected key `Tensor` of shape `[B, S + M, N, key_dim]`. |
| `value` | Projected value `Tensor` of shape `[B, S + M, N, key_dim]`. |
| `position` | Projected position `Tensor` of shape `[B, L, N, key_dim]`. |
| `content_attention_bias` | Trainable bias parameter added to the query head when calculating the content-based attention score. |
| `positional_attention_bias` | Trainable bias parameter added to the query head when calculating the position-based attention score. |
| `segment_matrix` | Optional `Tensor` representing segmentation IDs used in XLNet. |
| `segment_encoding` | Optional trainable `Tensor` representing the segmentation encoding as used in XLNet. |
| `segment_attention_bias` | Optional trainable bias parameter added to the query head when calculating the segment-based attention score used in XLNet. |
| `attention_mask` | Optional mask added to the attention logits. If `state` is not `None`, the mask's source sequence dimension should be extended by `M`. Defaults to `None`. |
| Returns | |
|---|---|
| `attention_output` | Multi-headed output of the attention computation, of shape `[B, S, N, key_dim]`. |
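To make the role of the two bias terms concrete, here is an illustrative einsum sketch of the content-based and position-based score terms in the Transformer-XL formulation. The function name is hypothetical, the relative-shift and softmax steps are only described in comments, and this is not the library's implementation.

```python
import tensorflow as tf

def relative_attention_scores(query, key, position,
                              content_attention_bias,
                              positional_attention_bias):
  """Illustrative sketch of the two score terms, not the library code.

  query: [B, T, N, H], key: [B, S + M, N, H], position: [B, L, N, H].
  """
  # Content-based scores: (q + content bias) . k  -> [B, N, T, S + M]
  content_scores = tf.einsum(
      'btnh,bsnh->bnts', query + content_attention_bias, key)
  # Position-based scores: (q + positional bias) . r  -> [B, N, T, L]
  position_scores = tf.einsum(
      'btnh,blnh->bntl', query + positional_attention_bias, position)
  # In the real computation, position_scores is then "rel-shifted" so that
  # index l aligns with the relative distance t - s, added to content_scores,
  # masked, and softmaxed over the source axis to form attention weights.
  return content_scores, position_scores
```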