
tfm.nlp.layers.TwoStreamRelativeAttention

Two-stream relative self-attention for XLNet.

Inherits From: MultiHeadRelativeAttention

In XLNet, each token has two associated vectors at each self-attention layer: the content stream (h) and the query stream (g).

The content stream is the self-attention stream as in Transformer-XL; it represents both the context and the content (the token itself).

The query stream only has access to contextual information and the position, but not the content.

This layer shares the same build signature as tf.keras.layers.MultiHeadAttention but has different input/output projections.

Call args

content_stream: Tensor of shape [B, T, dim].
content_attention_bias: Bias Tensor for content-based attention of shape [num_heads, dim].
positional_attention_bias: Bias Tensor for position-based attention of shape [num_heads, dim].
query_stream: Tensor of shape [B, P, dim].
target_mapping: Tensor of shape [B, P, S].
relative_position_encoding: Relative positional encoding Tensor of shape [B, L, dim].
segment_matrix: Optional Tensor of shape [B, S, S + M] representing the segmentation IDs used in XLNet.
segment_encoding: Optional Tensor of shape [2, num_heads, dim] representing the segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter of shape [num_heads, dim], added to the query head when calculating the segment-based attention score used in XLNet.
state: Optional Tensor of shape [B, M, E], where M is the length of the state (memory). If passed, this is also attended over, as in Transformer-XL.
content_attention_mask: A boolean mask of shape [B, T, S] that prevents attention to certain positions for the content attention computation.
query_attention_mask: A boolean mask of shape [B, T, S] that prevents attention to certain positions for the query attention computation.
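A minimal usage sketch follows. All sizes, the zero-valued biases, and the identity-style target_mapping are illustrative assumptions, not values prescribed by the API; in XLNet the two attention biases are trainable variables created by the surrounding encoder, and L = 2 * T is just one common choice for the relative encoding length.

```python
import tensorflow as tf
import tensorflow_models as tfm

B, T, P, num_heads, dim = 2, 8, 4, 4, 64  # illustrative sizes
head_dim = dim // num_heads

layer = tfm.nlp.layers.TwoStreamRelativeAttention(
    num_heads=num_heads, key_dim=head_dim)

content_stream = tf.random.normal([B, T, dim])  # h
query_stream = tf.random.normal([B, P, dim])    # g
# Bidirectional relative encoding; L = 2 * T here (an assumption).
relative_position_encoding = tf.random.normal([B, 2 * T, dim])
# Trainable weights in a real XLNet encoder; zeros here for brevity.
content_attention_bias = tf.zeros([num_heads, head_dim])
positional_attention_bias = tf.zeros([num_heads, head_dim])
# Maps each of the P prediction slots to a sequence position.
target_mapping = tf.one_hot(tf.range(P), depth=T)[tf.newaxis, ...]
target_mapping = tf.tile(target_mapping, [B, 1, 1])  # [B, P, T]

content_out, query_out = layer(
    content_stream=content_stream,
    content_attention_bias=content_attention_bias,
    positional_attention_bias=positional_attention_bias,
    query_stream=query_stream,
    relative_position_encoding=relative_position_encoding,
    target_mapping=target_mapping)
print(content_out.shape, query_out.shape)  # expected: (2, 8, 64) (2, 4, 64)
```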

Methods

call


Compute multi-head relative attention over inputs.

Size glossary

  • Number of heads (H): the number of attention heads.
  • Value size (V): the size of each value embedding per head.
  • Key size (K): the size of each key embedding per head. Equally, the size of each query embedding per head. Typically K <= V.
  • Number of predictions (P): the number of predictions.
  • Batch dimensions (B).
  • Query (target) attention axes shape (T).
  • Value (source) attention axes shape (S), the rank must match the target.
  • Encoding length (L): The relative positional encoding length.

Args
content_stream: The content representation, commonly referred to as h. This serves a role similar to the standard hidden states in Transformer-XL.
content_attention_bias: A trainable bias parameter added to the query head when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query head when calculating the position-based attention score.
query_stream: The query representation, commonly referred to as g. It only has access to contextual information and position, but not content. If not provided, this reduces to MultiHeadRelativeAttention with self-attention.
relative_position_encoding: Relative positional encoding for key and value.
target_mapping: Optional Tensor representing the target mapping used in partial prediction.
segment_matrix: Optional Tensor representing the segmentation IDs used in XLNet.
segment_encoding: Optional Tensor representing the segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the query head when calculating the segment-based attention score.
state: (default None) Optional state. If passed, this is also attended over, as in Transformer-XL and XLNet.
content_attention_mask: (default None) Optional mask that is added to the content attention logits. If state is not None, the mask's source sequence dimension should be extended by M (i.e., cover S + M positions).
query_attention_mask: (default None) Optional mask that is added to the query attention logits. If state is not None, the mask's source sequence dimension should be extended by M (i.e., cover S + M positions).

Returns
content_attention_output, query_attention_output: the results of the computation, both of shape [B, T, E]. Here T is the target sequence length, and E is the last dimension of the query input if output_shape is None; otherwise, the multi-head outputs are projected to the shape specified by output_shape.
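For the mask shapes with state, a hypothetical shape sketch (assuming self-attention, so S == T; the all-ones masks are placeholders for whatever visibility pattern the model actually needs):

```python
import tensorflow as tf

B, T, M, dim = 2, 8, 16, 64  # illustrative sizes
S = T  # self-attention: source length equals target length

state = tf.random.normal([B, M, dim])  # Transformer-XL style memory
# The masks gain M extra source positions so queries can attend to memory.
content_attention_mask = tf.ones([B, T, S + M], dtype=tf.bool)
query_attention_mask = tf.ones([B, T, S + M], dtype=tf.bool)
```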

compute_attention


Computes the attention.

This function defines the computation inside call with projected multi-head Q, K, V, and R (relative position) inputs.

Args
query: Projected query Tensor of shape [B, T, N, key_dim].
key: Projected key Tensor of shape [B, S + M, N, key_dim].
value: Projected value Tensor of shape [B, S + M, N, key_dim].
position: Projected position Tensor of shape [B, L, N, key_dim].
content_attention_bias: Trainable bias parameter added to the query head when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query head when calculating the position-based attention score.
segment_matrix: Optional Tensor representing the segmentation IDs used in XLNet.
segment_encoding: Optional trainable Tensor representing the segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the query head when calculating the segment-based attention score used in XLNet.
attention_mask: (default None) Optional mask that is added to the attention logits. If state is not None, the mask's source sequence dimension should be extended by M.

Returns
attention_output: Multi-head output of the attention computation, of shape [B, S, N, key_dim].
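As a conceptual sketch (not the library's actual code), the content-based and position-based score terms combined here follow the Transformer-XL formulation, where the biases play the roles of u and v; the sizes below are illustrative assumptions:

```python
import tensorflow as tf

B, T, S, M, N, L, key_dim = 2, 8, 8, 0, 4, 16, 16  # illustrative sizes

query = tf.random.normal([B, T, N, key_dim])
key = tf.random.normal([B, S + M, N, key_dim])
position = tf.random.normal([B, L, N, key_dim])
content_attention_bias = tf.zeros([N, key_dim])     # "u" in Transformer-XL
positional_attention_bias = tf.zeros([N, key_dim])  # "v" in Transformer-XL

# Content-based term: (q + u) . k over heads and head dims.
content_scores = tf.einsum('btnh,bsnh->bnts',
                           query + content_attention_bias, key)
# Position-based term: (q + v) . r. The layer then applies a "relative
# shift" so each relative distance lines up with an absolute source
# position, producing a [B, N, T, S + M] tensor that is added to
# content_scores before masking and softmax.
position_scores = tf.einsum('btnh,blnh->bntl',
                            query + positional_attention_bias, position)
```

An optional segment-based term, built from segment_matrix, segment_encoding, and segment_attention_bias, is added in the same fashion for XLNet's relative segment encodings.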