tf.keras.layers.MultiHeadAttention(
    num_heads, key_dim, value_dim=None, dropout=0.0, use_bias=True,
    output_shape=None, attention_axes=None, kernel_initializer='glorot_uniform',
    bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    **kwargs
)
This is an implementation of multi-headed attention based on "Attention is all you Need". If query, key, and value are the same, then this is self-attention. Each timestep in query attends to the corresponding sequence in key, and returns a fixed-width vector.
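For instance, a minimal self-attention call (the shapes here are illustrative assumptions) passes the same tensor as both query and value; key defaults to value when omitted.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
x = tf.keras.Input(shape=[8, 16])
self_attn_output = layer(x, x)  # query == value, key defaults to value
print(self_attn_output.shape)
(None, 8, 16)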
This layer first projects query, key, and value. These are (effectively) a list of tensors of length num_attention_heads, where the corresponding shapes are [batch_size, <query dimensions>, key_dim], [batch_size, <key/value dimensions>, key_dim], and [batch_size, <key/value dimensions>, value_dim].
Then, the query and key tensors are dot-producted and scaled. These are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities, then concatenated back to a single tensor.
Finally, the result tensor, whose last dimension is value_dim, can take a linear projection and be returned.
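The per-head computation described above can be sketched with plain einsum ops. The shapes, random weights, and variable names below are illustrative assumptions only, and the sketch omits the biases and dropout the layer's internal projections actually use.

import tensorflow as tf

B, T, S, D, H, key_dim, value_dim = 2, 8, 4, 16, 2, 2, 2
query = tf.random.normal([B, T, D])
key = tf.random.normal([B, S, D])
value = tf.random.normal([B, S, D])
wq = tf.random.normal([D, H, key_dim])      # query projection
wk = tf.random.normal([D, H, key_dim])      # key projection
wv = tf.random.normal([D, H, value_dim])    # value projection
wo = tf.random.normal([H, value_dim, D])    # output projection
q = tf.einsum('btd,dhk->bhtk', query, wq)
k = tf.einsum('bsd,dhk->bhsk', key, wk)
v = tf.einsum('bsd,dhv->bhsv', value, wv)
scores = tf.einsum('bhtk,bhsk->bhts', q, k) / tf.sqrt(float(key_dim))
probs = tf.nn.softmax(scores, axis=-1)           # attention probabilities
heads = tf.einsum('bhts,bhsv->bhtv', probs, v)   # interpolate the values
output = tf.einsum('bhtv,hvd->btd', heads, wo)   # concatenate heads and project
print(output.shape)
(2, 8, 16)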
Performs 1D cross-attention over two sequence inputs with an attention mask. Returns the additional attention weights over heads.
layer = MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape)
(None, 8, 16)
print(weights.shape)
(None, 2, 8, 4)
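To actually pass a mask, the attention_mask argument takes a boolean tensor of shape (batch, target_length, source_length); the concrete shapes and all-ones mask below are illustrative assumptions.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.random.normal([3, 8, 16])      # (batch, target_len, dim)
source = tf.random.normal([3, 4, 16])      # (batch, source_len, dim)
mask = tf.ones([3, 8, 4], dtype=tf.bool)   # True = position may be attended to
output = layer(target, source, attention_mask=mask)
print(output.shape)
(3, 8, 16)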
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
output_tensor = layer(input_tensor, input_tensor)
print(output_tensor.shape)
(None, 5, 3, 4, 16)
|num_heads|Number of attention heads.|
|key_dim|Size of each attention head for query and key.|
|value_dim|Size of each attention head for value.|
|use_bias|Boolean, whether the dense layers use bias vectors/matrices.|
|output_shape|The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the key feature dim.|
|attention_axes|Axes over which the attention is applied. None means attention over all axes, but batch, heads, and features.|
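As a usage sketch for output_shape (the value 32 below is an arbitrary assumption), the final projection maps to the requested feature size instead of the input feature dim.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, output_shape=32)
query = tf.keras.Input(shape=[8, 16])
value = tf.keras.Input(shape=[4, 16])
output = layer(query, value)
print(output.shape)
(None, 8, 32)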