tf.keras.layers.MultiHeadAttention

MultiHeadAttention layer.

Inherits From: Layer, Module

This is an implementation of multi-headed attention based on "Attention Is All You Need". If query, key, and value are the same, then this is self-attention. Each timestep in query attends to the corresponding sequence in key and returns a fixed-width vector.

This layer first projects query, key and value. These are (effectively) a list of tensors of length num_heads, where the corresponding shapes are [batch_size, <query dimensions>, key_dim], [batch_size, <key/value dimensions>, key_dim], [batch_size, <key/value dimensions>, value_dim].

Then, the layer takes the dot product of the query and key tensors and scales the result. A softmax over these scores gives the attention probabilities. The value tensors are then interpolated by these probabilities and concatenated back into a single tensor.

Finally, the result tensor, whose last dimension is value_dim, can be passed through a linear projection and returned.
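The steps above (project, scale the dot product, softmax, interpolate, project back) can be sketched in plain NumPy. This is a minimal illustration, not the layer's actual implementation: the function name, the weight names (wq, wk, wv, wo) and their layouts are assumptions made for this example, and biases and dropout are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(query, key, value, wq, wk, wv, wo):
    """Hypothetical helper.

    query: [B, T, d_q]; key: [B, S, d_k]; value: [B, S, d_v].
    wq: [d_q, N, key_dim], wk: [d_k, N, key_dim],
    wv: [d_v, N, value_dim], wo: [N, value_dim, d_out].
    """
    # 1. Project the inputs into N heads.
    q = np.einsum("btd,dnh->bnth", query, wq)
    k = np.einsum("bsd,dnh->bnsh", key, wk)
    v = np.einsum("bsd,dnh->bnsh", value, wv)
    # 2. Scaled dot product between every query and key position.
    key_dim = wq.shape[-1]
    scores = np.einsum("bnth,bnsh->bnts", q, k) / np.sqrt(key_dim)
    # 3. Softmax over the key positions gives attention probabilities.
    probs = softmax(scores, axis=-1)
    # 4. Interpolate the value tensors with those probabilities.
    context = np.einsum("bnts,bnsh->bnth", probs, v)
    # 5. Merge the heads and apply the output projection in one einsum
    #    (equivalent to concatenating heads, then a dense layer).
    output = np.einsum("bnth,nhd->btd", context, wo)
    return output, probs

rng = np.random.default_rng(0)
B, T, S, D, N, key_dim = 3, 8, 4, 16, 2, 2
target = rng.normal(size=(B, T, D))
source = rng.normal(size=(B, S, D))
wq = rng.normal(size=(D, N, key_dim))
wk = rng.normal(size=(D, N, key_dim))
wv = rng.normal(size=(D, N, key_dim))  # value_dim taken equal to key_dim here
wo = rng.normal(size=(N, key_dim, D))

output, weights = multi_head_attention(target, source, source, wq, wk, wv, wo)
print(output.shape)   # (3, 8, 16)
print(weights.shape)  # (3, 2, 8, 4)
```

The printed shapes have the same layout as the cross-attention example in the Examples section, with a concrete batch size of 3 in place of None.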

Examples:

Performs 1D cross-attention over two sequence inputs. Also returns the attention weights for each head.

>>> import tensorflow as tf
>>> layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
...                                return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
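The layer's call also accepts an attention_mask argument, a boolean tensor of shape [batch_size, target_length, source_length] in which a true entry means the query position may attend to that key position. The NumPy sketch below shows one common way such a mask can be applied to the scores before the softmax. The helper name and the large negative fill value are assumptions for illustration, not the layer's internals.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Hypothetical helper.

    scores: [B, N, T, S] attention scores.
    mask:   [B, T, S] boolean, True where attention is allowed.
    """
    # Broadcast the mask over the head axis and push disallowed
    # positions to a large negative score so they get zero weight.
    masked = np.where(mask[:, None, :, :], scores, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(1, 2, 8, 4))   # [B, N, T, S]
mask = np.ones((1, 8, 4), dtype=bool)
mask[:, :, -1] = False                    # hide the last source position

probs = masked_softmax(scores, mask)
print(probs.shape)                        # (1, 2, 8, 4)
print(float(probs[..., -1].max()))        # 0.0
```

Every row still sums to 1 over the three visible source positions, which is the usual way to keep padding tokens from receiving attention.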

Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

>>> layer = tf.keras.layers.MultiHeadAttention(
...     num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
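Conceptually, attending over axes 2 and 3 treats the grid spanned by those axes as one flattened sequence, while axis 1 behaves like an extra batch dimension. The single-head NumPy sketch below illustrates that reading under simplifying assumptions: there are no learned projections, and the function name is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_over_axes_2_3(q, k, v):
    """q, k, v: [B, A, H, W, D]; attention runs over the (H, W) grid."""
    B, A, H, W, D = q.shape
    # Flatten the two attention axes into one sequence of H*W positions.
    qf = q.reshape(B, A, H * W, D)
    kf = k.reshape(B, A, H * W, D)
    vf = v.reshape(B, A, H * W, D)
    # Each (batch, A) slice attends only within its own grid.
    scores = np.einsum("batd,basd->bats", qf, kf) / np.sqrt(D)
    probs = softmax(scores, axis=-1)
    out = np.einsum("bats,basd->batd", probs, vf)
    # Restore the original grid layout.
    return out.reshape(B, A, H, W, D)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3, 4, 16))
out = attention_over_axes_2_3(x, x, x)
print(out.shape)  # (2, 5, 3, 4, 16)
```

Because axis 1 is not an attention axis, changing the input at index 0 along that axis leaves the outputs at the other indices untouched.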

Args:

num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the query feature dim.
attention_axes: Axes over which the attention is applied. None means attention over all axes except batch, heads, and features.
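To see how these arguments interact, the sketch below derives the weight shapes that the projections would need. It assumes each projection is a single dense map with one kernel and an optional bias, that value_dim falls back to key_dim when unset, and that the output width falls back to the query feature dim. The function and dictionary key names are hypothetical, so treat the result as an estimate to compare against layer.count_params() rather than a guarantee.

```python
import math

def projection_shapes(query_feat, key_feat, value_feat, num_heads, key_dim,
                      value_dim=None, output_dim=None, use_bias=True):
    """Hypothetical helper returning assumed kernel/bias shapes."""
    value_dim = key_dim if value_dim is None else value_dim
    output_dim = query_feat if output_dim is None else output_dim
    shapes = {
        "query_kernel": (query_feat, num_heads, key_dim),
        "key_kernel": (key_feat, num_heads, key_dim),
        "value_kernel": (value_feat, num_heads, value_dim),
        "output_kernel": (num_heads, value_dim, output_dim),
    }
    if use_bias:
        shapes.update({
            "query_bias": (num_heads, key_dim),
            "key_bias": (num_heads, key_dim),
            "value_bias": (num_heads, value_dim),
            "output_bias": (output_dim,),
        })
    return shapes

# Same configuration as the examples: 16 input features, 2 heads, key_dim=2.
shapes = projection_shapes(16, 16, 16, num_heads=2, key_dim=2)
total = sum(math.prod(s) for s in shapes.values())
print(shapes["output_kernel"])  # (2, 2, 16)
print(total)                    # 284
```

Note that key_dim and value_dim are per-head sizes, so the width of the concatenated heads is num_heads * value_dim, which need not equal the input feature size.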