MultiHeadAttention layer.
tf.keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer='glorot_uniform',
    bias_initializer='zeros',
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)
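A quick instantiation sketch (the tensor shapes and the output_shape value below are illustrative assumptions, not defaults): by default the result is projected back to the query feature dimension, while output_shape overrides the last dimension of the output.
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
query = tf.random.normal([3, 8, 16])  # (batch_size, target_len, features)
value = tf.random.normal([3, 4, 16])  # (batch_size, source_len, features)
print(layer(query, value).shape)      # (3, 8, 16), back to the query feature dim

projected = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4,
                                               output_shape=32)
print(projected(query, value).shape)  # (3, 8, 32)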
This is an implementation of multi-headed attention as described in the paper
"Attention is all you Need" (Vaswani et al., 2017).
If query, key, value are the same, then
this is self-attention. Each timestep in query attends to the
corresponding sequence in key, and returns a fixed-width vector.
This layer first projects query, key and value. These are
(effectively) a list of tensors of length num_attention_heads, where the
corresponding shapes are (batch_size, <query dimensions>, key_dim),
(batch_size, <key/value dimensions>, key_dim),
(batch_size, <key/value dimensions>, value_dim).
Then, the query and key tensors are dot-producted and scaled. These are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities, then concatenated back to a single tensor.
Finally, the result tensor, whose last dimension is value_dim, can take a linear projection and be returned.
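As a rough sketch of the computation just described (assuming query, key and value have already been projected for a single head; the layer itself adds the learned projections, dropout and optional masking):
import tensorflow as tf

batch_size, target_len, source_len, key_dim, value_dim = 2, 8, 4, 16, 16
query = tf.random.normal([batch_size, target_len, key_dim])
key = tf.random.normal([batch_size, source_len, key_dim])
value = tf.random.normal([batch_size, source_len, value_dim])

# Dot-product and scale the query and key tensors, then softmax to obtain
# attention probabilities.
scores = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(
    tf.cast(key_dim, tf.float32))
probs = tf.nn.softmax(scores, axis=-1)

# Interpolate the value tensors by these probabilities.
attended = tf.matmul(probs, value)
print(attended.shape)  # (2, 8, 16)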
When using MultiHeadAttention inside a custom Layer, the custom Layer must implement build() and call MultiHeadAttention's _build_from_signature(). This enables weights to be restored correctly when the model is loaded.
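A minimal sketch of that pattern, assuming a self-attention wrapper (the SelfAttentionBlock name and the shapes below are illustrative, not part of the API):
import tensorflow as tf

class SelfAttentionBlock(tf.keras.layers.Layer):
    """Hypothetical custom Layer wrapping MultiHeadAttention."""

    def __init__(self, num_heads=2, key_dim=2, **kwargs):
        super().__init__(**kwargs)
        self._attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim)

    def build(self, input_shape):
        # Build the attention weights from the input signature so they exist
        # (and can be restored) before the layer is first called.
        self._attention._build_from_signature(query=input_shape,
                                              value=input_shape)
        super().build(input_shape)

    def call(self, inputs):
        # Self-attention: query, key and value are the same tensor.
        return self._attention(inputs, inputs)

block = SelfAttentionBlock()
print(block(tf.keras.Input(shape=[8, 16])).shape)  # (None, 8, 16)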
Examples:
Performs 1D cross-attention over two sequence inputs with an attention mask. Returns the additional attention weights over heads.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
...                                return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
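The layer call can also take an attention_mask, a boolean tensor of shape (batch_size, target_len, source_len) where True marks positions that may be attended to and False prevents attention. A minimal sketch with eager tensors (the shapes and the all-True mask are illustrative only):
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
query = tf.random.normal([3, 8, 16])  # (batch_size, target_len, features)
value = tf.random.normal([3, 4, 16])  # (batch_size, source_len, features)
# All-True mask: every query position may attend to every source position.
mask = tf.ones([3, 8, 4], dtype=tf.bool)
output_tensor, weights = layer(query, value,
                               attention_mask=mask,
                               return_attention_scores=True)
print(output_tensor.shape)  # (3, 8, 16)
print(weights.shape)        # (3, 2, 8, 4)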
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)