This tutorial trains a transformer model to translate Portuguese into English, using a Portuguese-English translation dataset.
This is an advanced example that assumes knowledge of text generation and attention.
This tutorial demonstrates how to build a transformer model and most of its components from scratch using low-level TensorFlow and Keras functionality. You could shorten some of this code by using built-in APIs such as tf.keras.layers.MultiHeadAttention.
The core idea behind a transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. A transformer stacks self-attention layers, which are explained below in the sections Scaled dot product attention and Multi-head attention.
A transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:
- It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, StarCraft units).
- Layer outputs can be calculated in parallel, instead of in series as in an RNN.
- Distant items can affect each other's output without passing through many RNN steps or convolution layers (see Scene Memory Transformer for example).
- It can learn long-range dependencies, which are a challenge in many sequence tasks.
The downsides of this architecture are:
- For a time series, the output for a time step is calculated from the entire history instead of only the inputs and the current hidden state. This may be less efficient.
- If the input does have a temporal/spatial relationship, like text, some positional encoding must be added or the model will effectively see a bag of words.
After training the model in this notebook, you will be able to input a Portuguese sentence and get back its English translation.
Setup
pip install tensorflow_datasets
pip install -U 'tensorflow-text==2.8.*'
import logging
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
# Import tf_text to load the ops used by the tokenizer saved model
import tensorflow_text # pylint: disable=unused-import
logging.getLogger('tensorflow').setLevel(logging.ERROR) # suppress warnings
Download the Dataset
Use TensorFlow datasets to load the Portuguese-English translation dataset from the TED Talks Open Translation Project.
This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']
The tf.data.Dataset object returned by TensorFlow datasets yields pairs of text examples:
for pt_examples, en_examples in train_examples.batch(3).take(1):
  for pt in pt_examples.numpy():
    print(pt.decode('utf-8'))
  print()
  for en in en_examples.numpy():
    print(en.decode('utf-8'))
e quando melhoramos a procura , tiramos a única vantagem da impressão , que é a serendipidade .
mas e se estes fatores fossem ativos ?
mas eles não tinham a curiosidade de me testar .

and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n't test for curiosity .
Text tokenization & detokenization
You can't train a model directly on text. The text needs to be converted to some numeric representation first. Typically, you convert the text to sequences of token IDs, which are used as indices into an embedding.
The Subword tokenizer tutorial builds subword tokenizers (text.BertTokenizer) optimized for this dataset and exports them in a saved_model.
Download, unzip, and import the saved_model:
model_name = 'ted_hrlr_translate_pt_en_converter'
tf.keras.utils.get_file(
    f'{model_name}.zip',
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',
    cache_dir='.', cache_subdir='', extract=True
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/ted_hrlr_translate_pt_en_converter.zip
'./ted_hrlr_translate_pt_en_converter.zip'
tokenizers = tf.saved_model.load(model_name)
The tf.saved_model contains two text tokenizers, one for English and one for Portuguese. Both have the same methods:
[item for item in dir(tokenizers.en) if not item.startswith('_')]
['detokenize', 'get_reserved_tokens', 'get_vocab_path', 'get_vocab_size', 'lookup', 'tokenize', 'tokenizer', 'vocab']
The tokenize method converts a batch of strings to a ragged batch of token IDs (one row per string). This method splits punctuation, lowercases, and unicode-normalizes the input before tokenizing. That standardization is not visible here because the input data is already standardized.
for en in en_examples.numpy():
  print(en.decode('utf-8'))
and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n't test for curiosity .
encoded = tokenizers.en.tokenize(en_examples)
for row in encoded.to_list():
  print(row)
[2, 72, 117, 79, 1259, 1491, 2362, 13, 79, 150, 184, 311, 71, 103, 2308, 74, 2679, 13, 148, 80, 55, 4840, 1434, 2423, 540, 15, 3]
[2, 87, 90, 107, 76, 129, 1852, 30, 3]
[2, 87, 83, 149, 50, 9, 56, 664, 85, 2512, 15, 3]
The detokenize method attempts to convert these token IDs back to human-readable text:
round_trip = tokenizers.en.detokenize(encoded)
for line in round_trip.numpy():
  print(line.decode('utf-8'))
and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n ' t test for curiosity .
The lower-level lookup method converts from token IDs to token text:
tokens = tokenizers.en.lookup(encoded)
tokens
<tf.RaggedTensor [[b'[START]', b'and', b'when', b'you', b'improve', b'search', b'##ability', b',', b'you', b'actually', b'take', b'away', b'the', b'one', b'advantage', b'of', b'print', b',', b'which', b'is', b's', b'##ere', b'##nd', b'##ip', b'##ity', b'.', b'[END]'],
 [b'[START]', b'but', b'what', b'if', b'it', b'were', b'active', b'?', b'[END]'],
 [b'[START]', b'but', b'they', b'did', b'n', b"'", b't', b'test', b'for', b'curiosity', b'.', b'[END]']]>
Here you can see the "subword" aspect of the tokenizers. The word "searchability" is decomposed into "search ##ability" and the word "serendipity" into "s ##ere ##nd ##ip ##ity".
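As a rough model of how a learned subword vocabulary is applied, the plain-Python sketch below uses WordPiece-style greedy longest-match-first lookup, where "##" marks a piece that continues a word. The function name wordpiece and the tiny toy_vocab are hypothetical, chosen only so the example reproduces the two splits shown above; the real tokenizer also handles punctuation, normalization, and a much larger vocabulary.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split one word by greedy longest-match-first; continuation pieces get "##"."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            # No prefix of the remainder is in the vocabulary.
            return [unk]
        pieces.append(match)
        start = end
    return pieces

# A tiny, hypothetical vocabulary (the real one has thousands of entries).
toy_vocab = {"search", "##ability", "s", "##ere", "##nd", "##ip", "##ity", "test"}

print(wordpiece("searchability", toy_vocab))  # ['search', '##ability']
print(wordpiece("serendipity", toy_vocab))    # ['s', '##ere', '##nd', '##ip', '##ity']
print(wordpiece("curious", toy_vocab))        # ['[UNK]']
```

Because the longest match wins, "search" is chosen over the shorter "s" even though both are in the vocabulary.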
Now take a minute to investigate the distribution of tokens per example in the dataset:
lengths = []
for pt_examples, en_examples in train_examples.batch(1024):
  # Tokenize each language with its own tokenizer.
  pt_tokens = tokenizers.pt.tokenize(pt_examples)
  lengths.append(pt_tokens.row_lengths())
  en_tokens = tokenizers.en.tokenize(en_examples)
  lengths.append(en_tokens.row_lengths())
  print('.', end='', flush=True)
all_lengths = np.concatenate(lengths)
plt.hist(all_lengths, np.linspace(0, 500, 101))
plt.ylim(plt.ylim())
max_length = max(all_lengths)
plt.plot([max_length, max_length], plt.ylim())
plt.title(f'Max tokens per example: {max_length}');
MAX_TOKENS = 128
Setup input pipeline
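To build an input pipeline suitable for training, define some functions to transform the dataset.
Define a function that drops a batch when its longest example, in either language, has MAX_TOKENS or more tokens. Because this filter runs after batching and padding, tf.shape(...)[1] is the padded length of the whole batch, so one long example causes the entire batch to be dropped:
def filter_max_tokens(pt, en):
  num_tokens = tf.maximum(tf.shape(pt)[1], tf.shape(en)[1])
  return num_tokens < MAX_TOKENS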
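The batch-level behaviour of filter_max_tokens can be mimicked in plain Python with made-up token IDs. pad_batch and keep_batch are hypothetical helpers, not part of the tutorial; pad_batch pads with 0 the way the ragged-to-dense conversion does by default.

```python
MAX_TOKENS = 128

def pad_batch(rows):
    """Right-pad a ragged batch of token-ID lists with 0 (like RaggedTensor.to_tensor())."""
    width = max(len(r) for r in rows)
    return [r + [0] * (width - len(r)) for r in rows]

def keep_batch(pt_rows, en_rows, max_tokens=MAX_TOKENS):
    """Mirror filter_max_tokens: compare the padded width (axis 1) of both batches."""
    num_tokens = max(len(pad_batch(pt_rows)[0]), len(pad_batch(en_rows)[0]))
    return num_tokens < max_tokens

short = [2, 10, 3]              # a 3-token example
long_ = [2] + [7] * 200 + [3]   # a 202-token example

print(keep_batch([short, short], [short, short]))  # True
print(keep_batch([short, long_], [short, short]))  # False: the short example is dropped too
```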
Define a function that tokenizes the batches of raw text:
def tokenize_pairs(pt, en):
  pt = tokenizers.pt.tokenize(pt)
  # Convert from ragged to dense, padding with zeros.
  pt = pt.to_tensor()
  en = tokenizers.en.tokenize(en)
  # Convert from ragged to dense, padding with zeros.
  en = en.to_tensor()
  return pt, en
Here's a simple input pipeline that processes, shuffles and batches the data:
BUFFER_SIZE = 20000
BATCH_SIZE = 64
def make_batches(ds):
  return (
      ds
      .cache()
      .shuffle(BUFFER_SIZE)
      .batch(BATCH_SIZE)
      .map(tokenize_pairs, num_parallel_calls=tf.data.AUTOTUNE)
      .filter(filter_max_tokens)
      .prefetch(tf.data.AUTOTUNE))
train_batches = make_batches(train_examples)
val_batches = make_batches(val_examples)
Positional encoding
Attention layers see their input as a set of vectors, with no sequential order. This model also doesn't contain any recurrent or convolutional layers. Because of this, a "positional encoding" is added to give the model some information about the relative position of the tokens in the sentence.
The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning are closer to each other, but they do not encode the relative position of tokens in a sentence. After the positional encoding is added, tokens are closer to each other in the d-dimensional space based on both the similarity of their meaning and their position in the sentence.
The formula for calculating the positional encoding is as follows:
\[\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model} })} \]
\[\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model} })} \]
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates
def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
  pos_encoding = angle_rads[np.newaxis, ...]
  return tf.cast(pos_encoding, dtype=tf.float32)
n, d = 2048, 512
pos_encoding = positional_encoding(n, d)
print(pos_encoding.shape)
pos_encoding = pos_encoding[0]
# Juggle the dimensions for the plot
pos_encoding = tf.reshape(pos_encoding, (n, d//2, 2))
pos_encoding = tf.transpose(pos_encoding, (2, 1, 0))
pos_encoding = tf.reshape(pos_encoding, (d, n))
plt.pcolormesh(pos_encoding, cmap='RdBu')
plt.ylabel('Depth')
plt.xlabel('Position')
plt.colorbar()
plt.show()
(1, 2048, 512)
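The original Transformer paper motivates the sinusoidal form by noting that, for any fixed offset k, PE(pos + k) is a linear function of PE(pos): each (sin, cos) pair is rotated by an angle that depends only on k. The plain-Python sketch below checks this numerically for a small d_model. pe and shift are hypothetical helper names; pe interleaves sin and cos the same way positional_encoding does.

```python
import math

def pe(pos, d_model):
    """Sinusoidal positional encoding for one position, following the formula above."""
    vec = []
    for j in range(d_model):
        i = j // 2  # dimensions 2i and 2i+1 share one frequency
        angle = pos / 10000 ** (2 * i / d_model)
        vec.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
    return vec

def shift(vec, k, d_model):
    """Rotate each (sin, cos) pair of vec by the angle that k positions add."""
    out = []
    for i in range(d_model // 2):
        w = 1 / 10000 ** (2 * i / d_model)
        s, c = vec[2 * i], vec[2 * i + 1]
        # sin(a + b) = sin a cos b + cos a sin b;  cos(a + b) = cos a cos b - sin a sin b
        out.append(s * math.cos(k * w) + c * math.sin(k * w))
        out.append(c * math.cos(k * w) - s * math.sin(k * w))
    return out

d = 8
print([round(x, 3) for x in pe(0, d)])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
# PE(5 + 3) is obtained from PE(5) by a fixed rotation that depends only on the offset 3.
print(all(abs(a - b) < 1e-9 for a, b in zip(shift(pe(5, d), 3, d), pe(8, d))))  # True
```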
Masking
Mask all the pad tokens in the batch of sequences. This ensures that the model does not treat padding as input. The mask indicates where the pad value 0 is present: it outputs a 1 at those locations, and a 0 otherwise.
def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
<tf.Tensor: shape=(3, 1, 1, 5), dtype=float32, numpy= array([[[[0., 0., 1., 1., 0.]]], [[[0., 0., 0., 1., 1.]]], [[[1., 1., 1., 0., 0.]]]], dtype=float32)>
The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.
This means that to predict the third token, only the first and second tokens will be used. Similarly, to predict the fourth token, only the first, second, and third tokens will be used, and so on.
def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
<tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[0., 1., 1.], [0., 0., 1.], [0., 0., 0.]], dtype=float32)>
Scaled dot product attention
The attention function used by a transformer takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:
\[\Large{Attention(Q, K, V) = softmax_k\left(\frac{QK^T}{\sqrt{d_k} }\right) V} \]
The dot-product attention is scaled by a factor of the square root of the depth. This is done because, for large values of depth, the dot product grows large in magnitude, pushing the softmax into regions where it has very small gradients (a very "hard" softmax).
For example, suppose the components of Q and K are independent with a mean of 0 and a variance of 1. Each entry of \(QK^T\) is a sum of \(d_k\) such products, so it has a mean of 0 and a variance of \(d_k\). Dividing by \(\sqrt{d_k}\) therefore gives a consistent variance of 1 regardless of the value of \(d_k\). If the variance is too low, the output may be too flat to optimize effectively. If the variance is too high, the softmax may saturate at initialization, making it difficult to learn.
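A quick numerical check of the variance argument, assuming the entries of q and k are independent standard normal values (a stronger assumption than "mean 0, variance 1", used here only to generate data). score_variance is a hypothetical helper; the exact numbers depend on the seed and number of trials, but the raw variance should track d_k while the scaled variance stays near 1.

```python
import random
import statistics

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def score_variance(d_k, trials=2000, seed=0):
    """Empirical variance of q.k, raw and divided by sqrt(d_k).

    Assumes the entries of q and k are i.i.d. standard normal.
    """
    rng = random.Random(seed)
    raw = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        raw.append(dot(q, k))
    scaled = [s / d_k ** 0.5 for s in raw]
    return statistics.pvariance(raw), statistics.pvariance(scaled)

for d_k in (4, 16, 64):
    raw_var, scaled_var = score_variance(d_k)
    print(f"d_k={d_k:3d}  raw variance={raw_var:6.1f}  scaled variance={scaled_var:.2f}")
```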
The mask is multiplied by -1e9 (close to negative infinity). The result is added to the scaled matrix product of Q and K immediately before the softmax, so masked positions get very large negative logits. Because the softmax of a very large negative input is close to zero, those positions receive (near-)zero attention weight.
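The effect of adding mask * -1e9 can be seen on a single row of logits. This plain-Python sketch uses the same convention as create_padding_mask (1 marks a position to hide); the logits are made-up values.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, 3.0]
mask = [0.0, 0.0, 1.0, 1.0]  # 1 marks positions to hide (e.g. padding)

masked = [x + m * -1e9 for x, m in zip(logits, mask)]
weights = softmax(masked)
print([round(w, 4) for w in weights])  # [0.7311, 0.2689, 0.0, 0.0]
```

The two unmasked positions share all of the probability mass, exactly as if the masked logits were absent.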
def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.
  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Can be None.
  Returns:
    output, attention_weights
  """
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)
  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
  return output, attention_weights
Because the softmax is normalized along the key axis, each query's attention weights sum to 1 across the keys, and how closely the query matches each key determines how much weight the corresponding value receives.
The output is the attention-weighted sum of the value (V) vectors. Tokens whose keys match the query strongly contribute most of the output, while tokens with near-zero weights are effectively ignored.
def print_out(q, k, v):
  temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
  print('Attention weights are:')
  print(temp_attn)
  print('Output is:')
  print(temp_out)
np.set_printoptions(suppress=True)
temp_k = tf.constant([[10, 0, 0],
[0, 10, 0],
[0, 0, 10],
[0, 0, 10]], dtype=tf.float32) # (4, 3)
temp_v = tf.constant([[1, 0],
[10, 0],
[100, 5],
[1000, 6]], dtype=tf.float32) # (4, 2)
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 0. 0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550. 5.5]], shape=(1, 2), dtype=float32)
# This query aligns equally with the first and second key,
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0. 0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)
Pass all the queries together.
temp_q = tf.constant([[0, 0, 10],
[0, 10, 0],
[10, 10, 0]], dtype=tf.float32) # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0. 0. 0.5 0.5]
 [0. 1. 0. 0. ]
 [0.5 0.5 0. 0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550. 5.5]
 [ 10. 0. ]
 [ 5.5 0. ]], shape=(3, 2), dtype=float32)
Multi-head attention
Multi-head attention consists of four parts:
- Linear layers.
- Scaled dot-product attention.
- Concatenation of heads.
- Final linear layer.
Each multi-head attention block gets three inputs: Q (query), K (key), V (value). These are put through linear (Dense) layers before the multi-head attention function.
In the diagram above, (K, Q, V) are passed through separate linear (Dense) layers for each attention head. For simplicity and efficiency, the code below implements this using a single dense layer with num_heads times as many outputs. The output is rearranged to a shape of (batch, num_heads, ...) before applying the attention function.
The scaled_dot_product_attention function defined above is applied in a single call, broadcast over the heads for efficiency. An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.
Instead of a single attention head, Q, K, and V are split into multiple heads because this allows the model to jointly attend to information from different representation subspaces at different positions. After the split, each head has a reduced dimensionality, so the total computation cost is similar to that of single-head attention with full dimensionality.
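Splitting into heads and concatenating them again are pure reshapes of the last axis, so no information is lost. The plain-Python sketch below mirrors split_heads (reshape plus transpose) and the concatenation step for a single example without the batch axis, using toy sizes; split_heads here is a list-based stand-in and combine_heads is a hypothetical name. With the sample layer below, d_model=512 and num_heads=8 give each head a depth of 64.

```python
def split_heads(x, num_heads):
    """(seq_len, d_model) -> (num_heads, seq_len, depth), like reshape + transpose."""
    d_model = len(x[0])
    assert d_model % num_heads == 0
    depth = d_model // num_heads
    # Head h owns the slice [h * depth, (h + 1) * depth) of every position's vector.
    return [[row[h * depth:(h + 1) * depth] for row in x] for h in range(num_heads)]

def combine_heads(heads):
    """(num_heads, seq_len, depth) -> (seq_len, d_model): concatenate heads per position."""
    seq_len = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(seq_len)]

x = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]  # seq_len=2, d_model=6
heads = split_heads(x, num_heads=3)               # depth=2
print(heads[1])                   # [[2, 3], [8, 9]]
print(combine_heads(heads) == x)  # True
```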
class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, *, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    assert d_model % self.num_heads == 0
    self.depth = d_model // self.num_heads
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    self.dense = tf.keras.layers.Dense(d_model)
  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]
    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)
    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
    return output, attention_weights
Create a MultiHeadAttention layer to try out. At each location in the sequence y, the MultiHeadAttention layer runs all 8 attention heads across all locations in the sequence, returning a new vector of the same length at each location.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))
Point wise feed forward network
The point-wise feed-forward network consists of two fully connected layers with a ReLU activation in between.
def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
TensorShape([64, 50, 512])
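In the original Transformer paper this sublayer is written as FFN(x) = max(0, xW1 + b1)W2 + b2. "Point wise" means the same two layers are applied to every position separately, so positions do not interact in this sublayer. The plain-Python sketch below uses toy sizes and made-up weights to show this; dense, ffn, and pointwise are hypothetical helpers, not Keras APIs.

```python
def dense(x, w, b):
    """Affine map of one vector: x has n_in entries, w is n_in x n_out, b has n_out."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]

def ffn(x, w1, b1, w2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2 for a single position."""
    hidden = [max(0.0, h) for h in dense(x, w1, b1)]
    return dense(hidden, w2, b2)

def pointwise(seq, *params):
    """Apply the same FFN weights to every position independently."""
    return [ffn(x, *params) for x in seq]

# Toy sizes: d_model=2, dff=3 (the sample above uses 512 and 2048).
w1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -1.0]]
b1 = [0.0, 0.0, 0.5]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.0, 0.5]

seq = [[1.0, 2.0], [3.0, -1.0]]
out = pointwise(seq, w1, b1, w2, b2)
print(out)  # [[1.0, 3.5], [6.0, 3.5]]
# Reordering the positions just reorders the outputs.
print(pointwise(seq[::-1], w1, b1, w2, b2) == out[::-1])  # True
```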
Encoder and decoder
A transformer model follows the same general pattern as a standard sequence-to-sequence model with attention.
- The input sentence is passed through N encoder layers that generate an output for each token in the sequence.
- The decoder attends to the encoder's output and its own input (self-attention) to predict the next word.
Encoder layer
Each encoder layer consists of sublayers:
- Multi-head attention (with padding mask)
- Point wise feed forward networks.
Each of these sublayers has a residual connection around it followed by a layer normalization. Residual connections help in avoiding the vanishing gradient problem in deep networks.
The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in a transformer.
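A minimal plain-Python sketch of LayerNorm(x + Sublayer(x)) for one position. It normalizes over the last axis with the same epsilon as the layers below, but it omits the learned scale and offset that tf.keras.layers.LayerNormalization also applies; layer_norm, residual_block, and the stand-in sublayer double are hypothetical.

```python
import math

def layer_norm(x, eps=1e-6):
    """Normalize one position's vector over its last (d_model) axis.

    Omits the learned scale (gamma) and offset (beta) of the Keras layer.
    """
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def residual_block(x, sublayer):
    """LayerNorm(x + Sublayer(x)) for a single position."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def double(v):
    """Stand-in for the attention or feed-forward sublayer."""
    return [2 * t for t in v]

out = residual_block([1.0, 2.0, 3.0, 4.0], double)
print([round(v, 4) for v in out])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```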
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, *, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()
    self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
  def call(self, x, training, mask):
    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
    return out2
sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)
sample_encoder_layer_output = sample_encoder_layer(
tf.random.uniform((64, 43, 512)), False, None)
sample_encoder_layer_output.shape # (batch_size, input_seq_len, d_model)
TensorShape([64, 43, 512])
Decoder layer
Each decoder layer consists of sublayers:
- Masked multi-head attention (with look ahead mask and padding mask)
- Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
- Point wise feed forward networks
Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.
There are N decoder layers in the model.
As Q receives the output from the decoder's first attention block, and K receives the encoder output, the attention weights represent the importance given to the decoder's input based on the encoder's output. In other words, the decoder predicts the next token by looking at the encoder output and self-attending to its own output. See the demonstration above in the scaled dot product attention section.
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, *, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()
    self.mha1 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    self.mha2 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)
  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)
    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)
    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
    return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)
sample_decoder_layer_output, _, _ = sample_decoder_layer(
tf.random.uniform((64, 50, 512)), sample_encoder_layer_output,
False, None, None)
sample_decoder_layer_output.shape # (batch_size, target_seq_len, d_model)
TensorShape([64, 50, 512])
Encoder
The Encoder consists of:
- Input Embedding
- Positional Encoding
- N encoder layers
The input is put through an embedding, scaled by the square root of d_model, and summed with the positional encoding. The output of this summation (after dropout) is the input to the encoder layers. The output of the encoder is the input to the decoder.
class Encoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads, dff, input_vocab_size,
               rate=0.1):
    super(Encoder, self).__init__()
    self.d_model = d_model
    self.num_layers = num_layers
    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(MAX_TOKENS, self.d_model)
    self.enc_layers = [
        EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, rate=rate)
        for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)
  def call(self, x, training, mask):
    seq_len = tf.shape(x)[1]
    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]
    x = self.dropout(x, training=training)
    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)
    return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, input_vocab_size=8500)
temp_input = tf.random.uniform((64, 62), dtype=tf.int64, minval=0, maxval=200)
sample_encoder_output = sample_encoder(temp_input, training=False, mask=None)
print(sample_encoder_output.shape) # (batch_size, input_seq_len, d_model)
(64, 62, 512)
Decoder
The Decoder consists of:
- Output Embedding
- Positional Encoding
- N decoder layers
The target is put through an embedding, scaled by the square root of d_model, and summed with the positional encoding. The output of this summation (after dropout) is the input to the decoder layers. The output of the decoder is the input to the final linear layer.
class Decoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads, dff, target_vocab_size,
               rate=0.1):
    super(Decoder, self).__init__()
    self.d_model = d_model
    self.num_layers = num_layers
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(MAX_TOKENS, d_model)
    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, rate=rate)
        for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)
  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    seq_len = tf.shape(x)[1]
    attention_weights = {}
    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]
    x = self.dropout(x, training=training)
    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)
      attention_weights[f'decoder_layer{i+1}_block1'] = block1
      attention_weights[f'decoder_layer{i+1}_block2'] = block2
    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, target_vocab_size=8000)
temp_input = tf.random.uniform((64, 26), dtype=tf.int64, minval=0, maxval=200)
output, attn = sample_decoder(temp_input,
enc_output=sample_encoder_output,
training=False,
look_ahead_mask=None,
padding_mask=None)
output.shape, attn['decoder_layer2_block2'].shape
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))
Create the transformer model
A transformer consists of the encoder, decoder, and a final linear layer. The output of the decoder is the input to the linear layer and its output is returned.
class Transformer(tf.keras.Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff, input_vocab_size,
               target_vocab_size, rate=0.1):
    super().__init__()
    self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           input_vocab_size=input_vocab_size, rate=rate)
    self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           target_vocab_size=target_vocab_size, rate=rate)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)
  def call(self, inputs, training):
    # Keras models prefer if you pass all your inputs in the first argument
    inp, tar = inputs
    padding_mask, look_ahead_mask = self.create_masks(inp, tar)
    enc_output = self.encoder(inp, training, padding_mask)  # (batch_size, inp_seq_len, d_model)
    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, padding_mask)
    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
    return final_output, attention_weights
  def create_masks(self, inp, tar):
    # Encoder padding mask (Used in the 2nd attention block in the decoder too.)
    padding_mask = create_padding_mask(inp)
    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by
    # the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    look_ahead_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return padding_mask, look_ahead_mask
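create_masks combines the look-ahead mask with the target's padding mask using tf.maximum, so a position is hidden if it is either in the future or padding. The plain-Python sketch below reproduces that combination for one made-up target sequence. The helper names are hypothetical; the 1-means-masked convention matches create_padding_mask and create_look_ahead_mask.

```python
def padding_mask(seq):
    """1 where the token ID is 0 (padding), else 0; mirrors create_padding_mask."""
    return [1 if t == 0 else 0 for t in seq]

def look_ahead_mask(size):
    """1 above the diagonal: query position i may not attend to key positions j > i."""
    return [[1 if j > i else 0 for j in range(size)] for i in range(size)]

def combined_mask(tar):
    """Element-wise max of both masks; the padding mask is broadcast across rows,
    as tf.maximum does in create_masks."""
    pad = padding_mask(tar)
    la = look_ahead_mask(len(tar))
    return [[max(la[i][j], pad[j]) for j in range(len(tar))] for i in range(len(tar))]

for row in combined_mask([2, 57, 9, 0, 0]):
    print(row)
# [0, 1, 1, 1, 1]
# [0, 0, 1, 1, 1]
# [0, 0, 0, 1, 1]
# [0, 0, 0, 1, 1]
# [0, 0, 0, 1, 1]
```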
sample_transformer = Transformer(
num_layers=2, d_model=512, num_heads=8, dff=2048,
input_vocab_size=8500, target_vocab_size=8000)
temp_input = tf.random.uniform((64, 38), dtype=tf.int64, minval=0, maxval=200)
temp_target = tf.random.uniform((64, 36), dtype=tf.int64, minval=0, maxval=200)
fn_out, _ = sample_transformer([temp_input, temp_target], training=False)
fn_out.shape # (batch_size, tar_seq_len, target_vocab_size)
TensorShape([64, 36, 8000])
Set hyperparameters
To keep this example small and relatively fast, the values for num_layers, d_model, and dff have been reduced.
The base model described in the original Transformer paper (Attention Is All You Need) used: num_layers=6, d_model=512, dff=2048.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1
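To get a feel for what these reductions save, the sketch below counts trainable weights in one EncoderLayer as defined above. It assumes Keras Dense layers include biases (their default) and that each LayerNormalization learns a scale and an offset of size d_model; embeddings and the decoder are not counted, and encoder_layer_params is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """A Dense layer has an n_in x n_out kernel plus n_out biases."""
    return n_in * n_out + n_out

def encoder_layer_params(d_model, dff):
    mha = 4 * dense_params(d_model, d_model)                       # wq, wk, wv, and the output dense
    ffn = dense_params(d_model, dff) + dense_params(dff, d_model)  # the two FFN layers
    layer_norms = 2 * (2 * d_model)                                # gamma and beta for two norms
    return mha + ffn + layer_norms

print(encoder_layer_params(128, 512))   # 198272  (this tutorial's settings)
print(encoder_layer_params(512, 2048))  # 3152384 (the paper's base settings)
```

Roughly two thirds of each layer's weights sit in the feed-forward network, because dff is four times d_model in both configurations.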
Optimizer
Use the Adam optimizer with a custom learning rate scheduler according to the formula in the paper.
\[\Large{lrate = d_{model}^{-0.5} \cdot \min(step{\_}num^{-0.5}, step{\_}num \cdot warmup{\_}steps^{-1.5})}\]
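The schedule increases the learning rate linearly for the first warmup_steps steps and then decreases it in proportion to the inverse square root of the step number. The plain-Python sketch below evaluates the formula directly (transformer_lr is a hypothetical name, not a TensorFlow API) and shows that the peak, reached at step_num = warmup_steps, equals (d_model * warmup_steps)^-0.5.

```python
def transformer_lr(step, d_model, warmup_steps=4000):
    """lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Rises during warmup, peaks at step == warmup_steps, then decays like step ** -0.5.
for step in (500, 2000, 4000, 8000, 32000):
    print(step, f"{transformer_lr(step, d_model=128):.7f}")

print(f"{transformer_lr(4000, d_model=128):.7f}")  # 0.0013975
```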
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)
    self.warmup_steps = warmup_steps
  def __call__(self, step):
    # Cast to float so rsqrt also works if the optimizer passes an integer step.
    step = tf.cast(step, dtype=tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)
plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel('Learning Rate')
plt.xlabel('Train Step')
Loss and metrics
Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction='none')
def loss_function(real, pred):
  # Positions whose target id is 0 are padding; exclude them from the loss.
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  # Average over the real (non-padding) tokens only.
  return tf.reduce_sum(loss_)/tf.reduce_sum(mask)
def accuracy_function(real, pred):
  accuracies = tf.equal(real, tf.argmax(pred, axis=2))
  # Count a prediction only where the target is not padding.
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)
  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
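The effect of the padding mask can be checked without TensorFlow. This NumPy sketch reimplements both functions using the tutorial's convention that token id 0 is padding; the helper names are hypothetical. Changing the logits at padded positions leaves the loss and the accuracy unchanged, because those positions are multiplied by a zero mask before averaging.

```python
import numpy as np


def masked_cross_entropy(real, logits):
    """Mean sparse cross-entropy over non-padding positions (id 0 = padding)."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the true token at each position: (batch, seq_len).
    true_log_probs = np.take_along_axis(log_probs, real[..., np.newaxis], axis=-1)[..., 0]
    mask = (real != 0).astype(log_probs.dtype)
    return -(true_log_probs * mask).sum() / mask.sum()


def masked_accuracy(real, logits):
    """Fraction of non-padding positions whose argmax prediction is correct."""
    mask = real != 0
    correct = (np.argmax(logits, axis=-1) == real) & mask
    return correct.sum() / mask.sum()


real = np.array([[2, 3, 0, 0]])        # two real tokens followed by two padding positions
logits = np.zeros((1, 4, 5))           # (batch, seq_len, vocab_size)
logits[0, 0, 2] = 5.0                  # position 0 predicts 2 (correct)
logits[0, 1, 1] = 5.0                  # position 1 predicts 1 (true token is 3)
logits[0, 2, 4] = 5.0                  # padding positions: predictions here are ignored
logits[0, 3, 4] = 5.0

print(masked_accuracy(real, logits))   # 0.5

noisy = logits.copy()
noisy[0, 2:, :] = 50.0 * np.arange(5)  # arbitrary change at the padded positions only
print(np.isclose(masked_cross_entropy(real, logits), masked_cross_entropy(real, noisy)))  # True
```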
Training and checkpointing
transformer = Transformer(
num_layers=num_layers,
d_model=d_model,
num_heads=num_heads,
dff=dff,
input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),
target_vocab_size=tokenizers.en.get_vocab_size().numpy(),
rate=dropout_rate)
Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.
checkpoint_path = './checkpoints/train'
ckpt = tf.train.Checkpoint(transformer=transformer,
optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# If a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')
The target is divided into tar_inp and tar_real. tar_inp is passed as an input to the decoder. tar_real is that same input shifted by 1: at each location in tar_inp, tar_real contains the next token that should be predicted.
For example, sentence = 'SOS A lion in the jungle is sleeping EOS' becomes:
tar_inp = 'SOS A lion in the jungle is sleeping'
tar_real = 'A lion in the jungle is sleeping EOS'
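In the training step, this split is done with two slices of the batch: tar[:, :-1] and tar[:, 1:]. This NumPy sketch applies the same slices to the example sentence and prints which token each decoder position is trained to predict. It uses strings in place of token ids, purely for readability.

```python
import numpy as np

# A batch containing one target sentence; strings stand in for token ids.
tar = np.array([['SOS', 'A', 'lion', 'in', 'the', 'jungle', 'is', 'sleeping', 'EOS']])

tar_inp = tar[:, :-1]   # decoder input: drop the last token
tar_real = tar[:, 1:]   # labels: drop the first token

for seen, target in zip(tar_inp[0], tar_real[0]):
    print(f"input {str(seen):>8} -> should predict {target}")
```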
A transformer is an auto-regressive model: it makes predictions one token at a time, and uses its output so far to decide what to do next.
During training this example uses teacher-forcing (like in the text generation tutorial). Teacher forcing is passing the true output to the next time step regardless of what the model predicts at the current time step.
As the model predicts each token, self-attention allows it to look at the previous tokens in the input sequence to better predict the next token.
To prevent the model from peeking at the expected output, the model uses a look-ahead mask.
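The look-ahead mask is an upper-triangular matrix: row i marks every position j > i as hidden. This NumPy sketch builds that mask and then combines it with a padding mask using an elementwise maximum, mirroring the tf.maximum call in create_masks. The sketch makes two simplifying assumptions:

- It uses NumPy's np.tril rather than TensorFlow ops.
- Its padding mask has shape (batch, 1, seq_len), omitting any extra broadcast axis the TensorFlow version may carry for the attention heads.

```python
import numpy as np


def look_ahead_mask(size):
    """1.0 marks a future position to hide: entry (i, j) is 1 when j > i."""
    # np.tril keeps the lower triangle including the diagonal (what each
    # query may attend to); subtracting from 1 marks everything above it.
    return 1.0 - np.tril(np.ones((size, size)))


def padding_mask(tokens):
    """1.0 where the token id is 0 (padding); shape (batch, 1, seq_len)."""
    return (tokens == 0).astype(np.float64)[:, np.newaxis, :]


tar = np.array([[5, 6, 0]])  # one target sequence whose last position is padding
combined = np.maximum(padding_mask(tar), look_ahead_mask(tar.shape[1]))
print(combined[0])
# Row i lists what query position i may not attend to:
# row 0 hides positions 1 and 2, row 1 hides position 2 (future and padding),
# and row 2 hides position 2 only because it is padding.
```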
EPOCHS = 20
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]
  with tf.GradientTape() as tape:
    predictions, _ = transformer([inp, tar_inp],
                                 training=True)
    loss = loss_function(tar_real, predictions)
  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))
Portuguese is used as the input language and English is the target language.
for epoch in range(EPOCHS):
  start = time.time()
  train_loss.reset_states()
  train_accuracy.reset_states()
  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_batches):
    train_step(inp, tar)
    if batch % 50 == 0:
      print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')
  print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
  print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n')
Epoch 1 Batch 0 Loss 8.8531 Accuracy 0.0000 Epoch 1 Batch 50 Loss 8.7949 Accuracy 0.0032 Epoch 1 Batch 100 Loss 8.7005 Accuracy 0.0231 Epoch 1 Batch 150 Loss 8.5879 Accuracy 0.0324 Epoch 1 Batch 200 Loss 8.4469 Accuracy 0.0393 Epoch 1 Batch 250 Loss 8.2763 Accuracy 0.0459 Epoch 1 Batch 300 Loss 8.0874 Accuracy 0.0540 Epoch 1 Batch 350 Loss 7.8902 Accuracy 0.0612 Epoch 1 Batch 400 Loss 7.7009 Accuracy 0.0680 Epoch 1 Batch 450 Loss 7.5321 Accuracy 0.0747 Epoch 1 Batch 500 Loss 7.3834 Accuracy 0.0813 Epoch 1 Batch 550 Loss 7.2496 Accuracy 0.0876 Epoch 1 Batch 600 Loss 7.1250 Accuracy 0.0947 Epoch 1 Batch 650 Loss 7.0107 Accuracy 0.1017 Epoch 1 Batch 700 Loss 6.9043 Accuracy 0.1079 Epoch 1 Loss 6.9043 Accuracy 0.1079 Time taken for 1 epoch: 51.80 secs
Epoch 2 Batch 0 Loss 5.3989 Accuracy 0.1918 Epoch 2 Batch 50 Loss 5.4114 Accuracy 0.1993 Epoch 2 Batch 100 Loss 5.3776 Accuracy 0.2016 Epoch 2 Batch 150 Loss 5.3417 Accuracy 0.2039 Epoch 2 Batch 200 Loss 5.3033 Accuracy 0.2074 Epoch 2 Batch 250 Loss 5.2686 Accuracy 0.2102 Epoch 2 Batch 300 Loss 5.2359 Accuracy 0.2137 Epoch 2 Batch 350 Loss 5.2048 Accuracy 0.2169 Epoch 2 Batch 400 Loss 5.1783 Accuracy 0.2193 Epoch 2 Batch 450 Loss 5.1516 Accuracy 0.2221 Epoch 2 Batch 500 Loss 5.1249 Accuracy 0.2248 Epoch 2 Batch 550 Loss 5.1024 Accuracy 0.2267 Epoch 2 Batch 600 Loss 5.0822 Accuracy 0.2288 Epoch 2 Batch 650 Loss 5.0623 Accuracy 0.2307 Epoch 2 Batch 700 Loss 5.0416 Accuracy 0.2327 Epoch 2 Loss 5.0407 Accuracy 0.2328 Time taken for 1 epoch: 38.33 secs
Epoch 3 Batch 0 Loss 4.8084 Accuracy 0.2580 Epoch 3 Batch 50 Loss 4.7208 Accuracy 0.2600 Epoch 3 Batch 100 Loss 4.6989 Accuracy 0.2619 Epoch 3 Batch 150 Loss 4.6907 Accuracy 0.2635 Epoch 3 Batch 200 Loss 4.6770 Accuracy 0.2650 Epoch 3 Batch 250 Loss 4.6651 Accuracy 0.2661 Epoch 3 Batch 300 Loss 4.6500 Accuracy 0.2673 Epoch 3 Batch 350 Loss 4.6389 Accuracy 0.2683 Epoch 3 Batch 400 Loss 4.6215 Accuracy 0.2699 Epoch 3 Batch 450 Loss 4.6098 Accuracy 0.2708 Epoch 3 Batch 500 Loss 4.5990 Accuracy 0.2716 Epoch 3 Batch 550 Loss 4.5863 Accuracy 0.2727 Epoch 3 Batch 600 Loss 4.5743 Accuracy 0.2741 Epoch 3 Batch 650 Loss 4.5607 Accuracy 0.2756 Epoch 3 Batch 700 Loss 4.5493 Accuracy 0.2768 Epoch 3 Loss 4.5482 Accuracy 0.2769 Time taken for 1 epoch: 39.29 secs
Epoch 4 Batch 0 Loss 4.4872 Accuracy 0.2858 Epoch 4 Batch 50 Loss 4.3147 Accuracy 0.2991 Epoch 4 Batch 100 Loss 4.2780 Accuracy 0.3028 Epoch 4 Batch 150 Loss 4.2661 Accuracy 0.3049 Epoch 4 Batch 200 Loss 4.2495 Accuracy 0.3076 Epoch 4 Batch 250 Loss 4.2335 Accuracy 0.3102 Epoch 4 Batch 300 Loss 4.2165 Accuracy 0.3125 Epoch 4 Batch 350 Loss 4.1987 Accuracy 0.3148 Epoch 4 Batch 400 Loss 4.1804 Accuracy 0.3174 Epoch 4 Batch 450 Loss 4.1628 Accuracy 0.3195 Epoch 4 Batch 500 Loss 4.1445 Accuracy 0.3220 Epoch 4 Batch 550 Loss 4.1258 Accuracy 0.3242 Epoch 4 Batch 600 Loss 4.1078 Accuracy 0.3263 Epoch 4 Batch 650 Loss 4.0939 Accuracy 0.3281 Epoch 4 Loss 4.0790 Accuracy 0.3298 Time taken for 1 epoch: 39.41 secs
Epoch 5 Batch 0 Loss 3.7135 Accuracy 0.3579 Epoch 5 Batch 50 Loss 3.7548 Accuracy 0.3668 Epoch 5 Batch 100 Loss 3.7626 Accuracy 0.3662 Epoch 5 Batch 150 Loss 3.7450 Accuracy 0.3689 Epoch 5 Batch 200 Loss 3.7320 Accuracy 0.3708 Epoch 5 Batch 250 Loss 3.7225 Accuracy 0.3722 Epoch 5 Batch 300 Loss 3.7118 Accuracy 0.3738 Epoch 5 Batch 350 Loss 3.6993 Accuracy 0.3756 Epoch 5 Batch 400 Loss 3.6874 Accuracy 0.3777 Epoch 5 Batch 450 Loss 3.6730 Accuracy 0.3797 Epoch 5 Batch 500 Loss 3.6610 Accuracy 0.3813 Epoch 5 Batch 550 Loss 3.6504 Accuracy 0.3827 Epoch 5 Batch 600 Loss 3.6404 Accuracy 0.3842 Epoch 5 Batch 650 Loss 3.6284 Accuracy 0.3856 Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-1 Epoch 5 Loss 3.6206 Accuracy 0.3868 Time taken for 1 epoch: 39.17 secs
Epoch 6 Batch 0 Loss 3.2417 Accuracy 0.4175 Epoch 6 Batch 50 Loss 3.3316 Accuracy 0.4162 Epoch 6 Batch 100 Loss 3.3551 Accuracy 0.4148 Epoch 6 Batch 150 Loss 3.3538 Accuracy 0.4159 Epoch 6 Batch 200 Loss 3.3495 Accuracy 0.4161 Epoch 6 Batch 250 Loss 3.3375 Accuracy 0.4180 Epoch 6 Batch 300 Loss 3.3262 Accuracy 0.4202 Epoch 6 Batch 350 Loss 3.3167 Accuracy 0.4217 Epoch 6 Batch 400 Loss 3.3056 Accuracy 0.4234 Epoch 6 Batch 450 Loss 3.2983 Accuracy 0.4246 Epoch 6 Batch 500 Loss 3.2933 Accuracy 0.4252 Epoch 6 Batch 550 Loss 3.2836 Accuracy 0.4262 Epoch 6 Batch 600 Loss 3.2757 Accuracy 0.4276 Epoch 6 Batch 650 Loss 3.2657 Accuracy 0.4290 Epoch 6 Loss 3.2587 Accuracy 0.4302 Time taken for 1 epoch: 38.71 secs
Epoch 7 Batch 0 Loss 3.1061 Accuracy 0.4460 Epoch 7 Batch 50 Loss 3.0378 Accuracy 0.4538 Epoch 7 Batch 100 Loss 3.0175 Accuracy 0.4575 Epoch 7 Batch 150 Loss 2.9996 Accuracy 0.4596 Epoch 7 Batch 200 Loss 2.9934 Accuracy 0.4609 Epoch 7 Batch 250 Loss 2.9823 Accuracy 0.4625 Epoch 7 Batch 300 Loss 2.9793 Accuracy 0.4637 Epoch 7 Batch 350 Loss 2.9726 Accuracy 0.4643 Epoch 7 Batch 400 Loss 2.9654 Accuracy 0.4654 Epoch 7 Batch 450 Loss 2.9562 Accuracy 0.4668 Epoch 7 Batch 500 Loss 2.9461 Accuracy 0.4683 Epoch 7 Batch 550 Loss 2.9369 Accuracy 0.4696 Epoch 7 Batch 600 Loss 2.9327 Accuracy 0.4703 Epoch 7 Batch 650 Loss 2.9274 Accuracy 0.4711 Epoch 7 Loss 2.9200 Accuracy 0.4722 Time taken for 1 epoch: 38.81 secs
Epoch 8 Batch 0 Loss 2.5240 Accuracy 0.5190 Epoch 8 Batch 50 Loss 2.6982 Accuracy 0.4984 Epoch 8 Batch 100 Loss 2.6850 Accuracy 0.5006 Epoch 8 Batch 150 Loss 2.6801 Accuracy 0.5018 Epoch 8 Batch 200 Loss 2.6808 Accuracy 0.5025 Epoch 8 Batch 250 Loss 2.6759 Accuracy 0.5035 Epoch 8 Batch 300 Loss 2.6734 Accuracy 0.5039 Epoch 8 Batch 350 Loss 2.6681 Accuracy 0.5046 Epoch 8 Batch 400 Loss 2.6662 Accuracy 0.5048 Epoch 8 Batch 450 Loss 2.6609 Accuracy 0.5054 Epoch 8 Batch 500 Loss 2.6533 Accuracy 0.5064 Epoch 8 Batch 550 Loss 2.6456 Accuracy 0.5077 Epoch 8 Batch 600 Loss 2.6430 Accuracy 0.5080 Epoch 8 Batch 650 Loss 2.6394 Accuracy 0.5086 Epoch 8 Batch 700 Loss 2.6356 Accuracy 0.5094 Epoch 8 Loss 2.6349 Accuracy 0.5095 Time taken for 1 epoch: 39.05 secs
Epoch 9 Batch 0 Loss 2.4558 Accuracy 0.5307 Epoch 9 Batch 50 Loss 2.4526 Accuracy 0.5332 Epoch 9 Batch 100 Loss 2.4505 Accuracy 0.5325 Epoch 9 Batch 150 Loss 2.4423 Accuracy 0.5347 Epoch 9 Batch 200 Loss 2.4388 Accuracy 0.5351 Epoch 9 Batch 250 Loss 2.4391 Accuracy 0.5347 Epoch 9 Batch 300 Loss 2.4361 Accuracy 0.5354 Epoch 9 Batch 350 Loss 2.4345 Accuracy 0.5357 Epoch 9 Batch 400 Loss 2.4351 Accuracy 0.5358 Epoch 9 Batch 450 Loss 2.4312 Accuracy 0.5365 Epoch 9 Batch 500 Loss 2.4287 Accuracy 0.5367 Epoch 9 Batch 550 Loss 2.4247 Accuracy 0.5374 Epoch 9 Batch 600 Loss 2.4256 Accuracy 0.5372 Epoch 9 Batch 650 Loss 2.4243 Accuracy 0.5375 Epoch 9 Loss 2.4234 Accuracy 0.5378 Time taken for 1 epoch: 38.24 secs
Epoch 10 Batch 0 Loss 2.3818 Accuracy 0.5334 Epoch 10 Batch 50 Loss 2.2105 Accuracy 0.5651 Epoch 10 Batch 100 Loss 2.2428 Accuracy 0.5611 Epoch 10 Batch 150 Loss 2.2566 Accuracy 0.5597 Epoch 10 Batch 200 Loss 2.2575 Accuracy 0.5597 Epoch 10 Batch 250 Loss 2.2565 Accuracy 0.5597 Epoch 10 Batch 300 Loss 2.2574 Accuracy 0.5597 Epoch 10 Batch 350 Loss 2.2590 Accuracy 0.5596 Epoch 10 Batch 400 Loss 2.2587 Accuracy 0.5599 Epoch 10 Batch 450 Loss 2.2545 Accuracy 0.5607 Epoch 10 Batch 500 Loss 2.2550 Accuracy 0.5609 Epoch 10 Batch 550 Loss 2.2548 Accuracy 0.5609 Epoch 10 Batch 600 Loss 2.2558 Accuracy 0.5609 Epoch 10 Batch 650 Loss 2.2562 Accuracy 0.5609 Epoch 10 Batch 700 Loss 2.2566 Accuracy 0.5609 Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-2 Epoch 10 Loss 2.2561 Accuracy 0.5610 Time taken for 1 epoch: 38.66 secs
Epoch 11 Batch 0 Loss 2.0627 Accuracy 0.5842 Epoch 11 Batch 50 Loss 2.1071 Accuracy 0.5802 Epoch 11 Batch 100 Loss 2.0934 Accuracy 0.5819 Epoch 11 Batch 150 Loss 2.1149 Accuracy 0.5791 Epoch 11 Batch 200 Loss 2.1186 Accuracy 0.5793 Epoch 11 Batch 250 Loss 2.1142 Accuracy 0.5802 Epoch 11 Batch 300 Loss 2.1156 Accuracy 0.5800 Epoch 11 Batch 350 Loss 2.1223 Accuracy 0.5791 Epoch 11 Batch 400 Loss 2.1228 Accuracy 0.5791 Epoch 11 Batch 450 Loss 2.1228 Accuracy 0.5791 Epoch 11 Batch 500 Loss 2.1211 Accuracy 0.5795 Epoch 11 Batch 550 Loss 2.1180 Accuracy 0.5802 Epoch 11 Batch 600 Loss 2.1192 Accuracy 0.5802 Epoch 11 Batch 650 Loss 2.1197 Accuracy 0.5801 Epoch 11 Loss 2.1213 Accuracy 0.5800 Time taken for 1 epoch: 38.49 secs
Epoch 12 Batch 0 Loss 2.1750 Accuracy 0.5651 Epoch 12 Batch 50 Loss 2.0071 Accuracy 0.5937 Epoch 12 Batch 100 Loss 2.0088 Accuracy 0.5945 Epoch 12 Batch 150 Loss 2.0052 Accuracy 0.5952 Epoch 12 Batch 200 Loss 2.0037 Accuracy 0.5959 Epoch 12 Batch 250 Loss 2.0079 Accuracy 0.5950 Epoch 12 Batch 300 Loss 2.0078 Accuracy 0.5952 Epoch 12 Batch 350 Loss 2.0104 Accuracy 0.5947 Epoch 12 Batch 400 Loss 2.0109 Accuracy 0.5950 Epoch 12 Batch 450 Loss 2.0131 Accuracy 0.5948 Epoch 12 Batch 500 Loss 2.0128 Accuracy 0.5950 Epoch 12 Batch 550 Loss 2.0125 Accuracy 0.5951 Epoch 12 Batch 600 Loss 2.0141 Accuracy 0.5950 Epoch 12 Batch 650 Loss 2.0147 Accuracy 0.5951 Epoch 12 Loss 2.0151 Accuracy 0.5951 Time taken for 1 epoch: 38.48 secs
Epoch 13 Batch 0 Loss 1.8535 Accuracy 0.6231 Epoch 13 Batch 50 Loss 1.8855 Accuracy 0.6131 Epoch 13 Batch 100 Loss 1.9094 Accuracy 0.6097 Epoch 13 Batch 150 Loss 1.9151 Accuracy 0.6088 Epoch 13 Batch 200 Loss 1.9140 Accuracy 0.6084 Epoch 13 Batch 250 Loss 1.9135 Accuracy 0.6086 Epoch 13 Batch 300 Loss 1.9110 Accuracy 0.6090 Epoch 13 Batch 350 Loss 1.9140 Accuracy 0.6087 Epoch 13 Batch 400 Loss 1.9127 Accuracy 0.6091 Epoch 13 Batch 450 Loss 1.9143 Accuracy 0.6089 Epoch 13 Batch 500 Loss 1.9151 Accuracy 0.6089 Epoch 13 Batch 550 Loss 1.9166 Accuracy 0.6089 Epoch 13 Batch 600 Loss 1.9195 Accuracy 0.6088 Epoch 13 Batch 650 Loss 1.9208 Accuracy 0.6087 Epoch 13 Loss 1.9229 Accuracy 0.6084 Time taken for 1 epoch: 38.02 secs
Epoch 14 Batch 0 Loss 1.7991 Accuracy 0.6328 Epoch 14 Batch 50 Loss 1.8104 Accuracy 0.6226 Epoch 14 Batch 100 Loss 1.8140 Accuracy 0.6226 Epoch 14 Batch 150 Loss 1.8177 Accuracy 0.6217 Epoch 14 Batch 200 Loss 1.8203 Accuracy 0.6218 Epoch 14 Batch 250 Loss 1.8282 Accuracy 0.6208 Epoch 14 Batch 300 Loss 1.8306 Accuracy 0.6207 Epoch 14 Batch 350 Loss 1.8306 Accuracy 0.6211 Epoch 14 Batch 400 Loss 1.8294 Accuracy 0.6214 Epoch 14 Batch 450 Loss 1.8327 Accuracy 0.6210 Epoch 14 Batch 500 Loss 1.8367 Accuracy 0.6204 Epoch 14 Batch 550 Loss 1.8383 Accuracy 0.6201 Epoch 14 Batch 600 Loss 1.8396 Accuracy 0.6200 Epoch 14 Batch 650 Loss 1.8410 Accuracy 0.6199 Epoch 14 Batch 700 Loss 1.8458 Accuracy 0.6193 Epoch 14 Loss 1.8461 Accuracy 0.6192 Time taken for 1 epoch: 38.29 secs
Epoch 15 Batch 0 Loss 1.4491 Accuracy 0.6872 Epoch 15 Batch 50 Loss 1.7225 Accuracy 0.6397 Epoch 15 Batch 100 Loss 1.7232 Accuracy 0.6393 Epoch 15 Batch 150 Loss 1.7282 Accuracy 0.6374 Epoch 15 Batch 200 Loss 1.7416 Accuracy 0.6355 Epoch 15 Batch 250 Loss 1.7470 Accuracy 0.6346 Epoch 15 Batch 300 Loss 1.7537 Accuracy 0.6335 Epoch 15 Batch 350 Loss 1.7612 Accuracy 0.6320 Epoch 15 Batch 400 Loss 1.7613 Accuracy 0.6319 Epoch 15 Batch 450 Loss 1.7618 Accuracy 0.6318 Epoch 15 Batch 500 Loss 1.7660 Accuracy 0.6313 Epoch 15 Batch 550 Loss 1.7709 Accuracy 0.6308 Epoch 15 Batch 600 Loss 1.7735 Accuracy 0.6304 Epoch 15 Batch 650 Loss 1.7753 Accuracy 0.6302 Epoch 15 Batch 700 Loss 1.7784 Accuracy 0.6298 Saving checkpoint for epoch 15 at ./checkpoints/train/ckpt-3 Epoch 15 Loss 1.7783 Accuracy 0.6299 Time taken for 1 epoch: 38.42 secs
Epoch 16 Batch 0 Loss 1.6642 Accuracy 0.6440 Epoch 16 Batch 50 Loss 1.6642 Accuracy 0.6453 Epoch 16 Batch 100 Loss 1.6646 Accuracy 0.6460 Epoch 16 Batch 150 Loss 1.6753 Accuracy 0.6445 Epoch 16 Batch 200 Loss 1.6863 Accuracy 0.6434 Epoch 16 Batch 250 Loss 1.6979 Accuracy 0.6420 Epoch 16 Batch 300 Loss 1.7000 Accuracy 0.6416 Epoch 16 Batch 350 Loss 1.7015 Accuracy 0.6412 Epoch 16 Batch 400 Loss 1.7025 Accuracy 0.6411 Epoch 16 Batch 450 Loss 1.7016 Accuracy 0.6412 Epoch 16 Batch 500 Loss 1.7048 Accuracy 0.6407 Epoch 16 Batch 550 Loss 1.7072 Accuracy 0.6402 Epoch 16 Batch 600 Loss 1.7111 Accuracy 0.6396 Epoch 16 Batch 650 Loss 1.7135 Accuracy 0.6393 Epoch 16 Batch 700 Loss 1.7172 Accuracy 0.6387 Epoch 16 Loss 1.7178 Accuracy 0.6386 Time taken for 1 epoch: 38.71 secs
Epoch 17 Batch 0 Loss 1.6321 Accuracy 0.6582 Epoch 17 Batch 50 Loss 1.6246 Accuracy 0.6498 Epoch 17 Batch 100 Loss 1.6288 Accuracy 0.6508 Epoch 17 Batch 150 Loss 1.6266 Accuracy 0.6521 Epoch 17 Batch 200 Loss 1.6304 Accuracy 0.6516 Epoch 17 Batch 250 Loss 1.6324 Accuracy 0.6516 Epoch 17 Batch 300 Loss 1.6368 Accuracy 0.6505 Epoch 17 Batch 350 Loss 1.6414 Accuracy 0.6498 Epoch 17 Batch 400 Loss 1.6423 Accuracy 0.6498 Epoch 17 Batch 450 Loss 1.6446 Accuracy 0.6494 Epoch 17 Batch 500 Loss 1.6467 Accuracy 0.6490 Epoch 17 Batch 550 Loss 1.6521 Accuracy 0.6481 Epoch 17 Batch 600 Loss 1.6561 Accuracy 0.6477 Epoch 17 Batch 650 Loss 1.6605 Accuracy 0.6472 Epoch 17 Batch 700 Loss 1.6640 Accuracy 0.6467 Epoch 17 Loss 1.6651 Accuracy 0.6464 Time taken for 1 epoch: 38.36 secs
Epoch 18 Batch 0 Loss 1.7109 Accuracy 0.6334 Epoch 18 Batch 50 Loss 1.5447 Accuracy 0.6661 Epoch 18 Batch 100 Loss 1.5643 Accuracy 0.6621 Epoch 18 Batch 150 Loss 1.5763 Accuracy 0.6604 Epoch 18 Batch 200 Loss 1.5792 Accuracy 0.6600 Epoch 18 Batch 250 Loss 1.5860 Accuracy 0.6589 Epoch 18 Batch 300 Loss 1.5886 Accuracy 0.6585 Epoch 18 Batch 350 Loss 1.5887 Accuracy 0.6586 Epoch 18 Batch 400 Loss 1.5916 Accuracy 0.6578 Epoch 18 Batch 450 Loss 1.5932 Accuracy 0.6578 Epoch 18 Batch 500 Loss 1.5976 Accuracy 0.6570 Epoch 18 Batch 550 Loss 1.6019 Accuracy 0.6562 Epoch 18 Batch 600 Loss 1.6049 Accuracy 0.6559 Epoch 18 Batch 650 Loss 1.6076 Accuracy 0.6556 Epoch 18 Loss 1.6116 Accuracy 0.6551 Time taken for 1 epoch: 38.72 secs
Epoch 19 Batch 0 Loss 1.3450 Accuracy 0.7062 Epoch 19 Batch 50 Loss 1.5109 Accuracy 0.6695 Epoch 19 Batch 100 Loss 1.5241 Accuracy 0.6668 Epoch 19 Batch 150 Loss 1.5378 Accuracy 0.6658 Epoch 19 Batch 200 Loss 1.5423 Accuracy 0.6653 Epoch 19 Batch 250 Loss 1.5515 Accuracy 0.6639 Epoch 19 Batch 300 Loss 1.5542 Accuracy 0.6637 Epoch 19 Batch 350 Loss 1.5553 Accuracy 0.6637 Epoch 19 Batch 400 Loss 1.5595 Accuracy 0.6632 Epoch 19 Batch 450 Loss 1.5605 Accuracy 0.6627 Epoch 19 Batch 500 Loss 1.5626 Accuracy 0.6624 Epoch 19 Batch 550 Loss 1.5645 Accuracy 0.6622 Epoch 19 Batch 600 Loss 1.5692 Accuracy 0.6614 Epoch 19 Batch 650 Loss 1.5715 Accuracy 0.6612 Epoch 19 Batch 700 Loss 1.5733 Accuracy 0.6609 Epoch 19 Loss 1.5733 Accuracy 0.6609 Time taken for 1 epoch: 38.43 secs
Epoch 20 Batch 0 Loss 1.5804 Accuracy 0.6560 Epoch 20 Batch 50 Loss 1.4986 Accuracy 0.6732 Epoch 20 Batch 100 Loss 1.4969 Accuracy 0.6729 Epoch 20 Batch 150 Loss 1.5063 Accuracy 0.6710 Epoch 20 Batch 200 Loss 1.5107 Accuracy 0.6704 Epoch 20 Batch 250 Loss 1.5119 Accuracy 0.6701 Epoch 20 Batch 300 Loss 1.5132 Accuracy 0.6700 Epoch 20 Batch 350 Loss 1.5146 Accuracy 0.6699 Epoch 20 Batch 400 Loss 1.5157 Accuracy 0.6696 Epoch 20 Batch 450 Loss 1.5170 Accuracy 0.6694 Epoch 20 Batch 500 Loss 1.5196 Accuracy 0.6690 Epoch 20 Batch 550 Loss 1.5228 Accuracy 0.6684 Epoch 20 Batch 600 Loss 1.5269 Accuracy 0.6679 Epoch 20 Batch 650 Loss 1.5290 Accuracy 0.6674 Epoch 20 Batch 700 Loss 1.5324 Accuracy 0.6670 Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-4 Epoch 20 Loss 1.5323 Accuracy 0.6671 Time taken for 1 epoch: 38.24 secs
Run inference
The following steps are used for inference:
- Encode the input sentence using the Portuguese tokenizer (tokenizers.pt). This is the encoder input.
- The decoder input is initialized to the [START] token.
- Calculate the padding masks and the look-ahead masks.
- The decoder then outputs the predictions by looking at the encoder output and its own output (self-attention).
- Concatenate the predicted token to the decoder input and pass it to the decoder.
- In this approach, the decoder predicts the next token based on the previous tokens it predicted.
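Stripped of TensorFlow, the steps above amount to greedy decoding. This sketch shows only that control flow. `step_fn` is a hypothetical stand-in for running the transformer on the tokens generated so far and returning scores for the next token; the toy function used here simply replays a fixed sequence. At each step the loop appends the highest-scoring token, and it stops at the end token or after max_length steps.

```python
from typing import Callable, List, Sequence


def greedy_decode(step_fn: Callable[[List[int]], Sequence[float]],
                  start_id: int, end_id: int, max_length: int) -> List[int]:
    """Repeatedly append the highest-scoring next token to the output."""
    output = [start_id]                    # the decoder input starts as [START]
    for _ in range(max_length):
        scores = step_fn(output)           # one score per vocabulary id
        next_id = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        output.append(next_id)             # feed the prediction back in
        if next_id == end_id:
            break
    return output


# Hypothetical toy model: always "predicts" the next id of a fixed sequence.
START, END, VOCAB_SIZE = 1, 2, 6
TARGET = [START, 3, 4, 5, END]


def toy_step_fn(prefix: List[int]) -> List[float]:
    scores = [0.0] * VOCAB_SIZE
    scores[TARGET[len(prefix)]] = 1.0
    return scores


print(greedy_decode(toy_step_fn, START, END, max_length=10))  # [1, 3, 4, 5, 2]
print(greedy_decode(toy_step_fn, START, END, max_length=2))   # [1, 3, 4]
```

Because each step keeps only the single best token, this procedure is deterministic.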
class Translator(tf.Module):
  def __init__(self, tokenizers, transformer):
    self.tokenizers = tokenizers
    self.transformer = transformer
  def __call__(self, sentence, max_length=MAX_TOKENS):
    # The input sentence is Portuguese; the Portuguese tokenizer adds the
    # [START] and [END] tokens.
    assert isinstance(sentence, tf.Tensor)
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]
    sentence = self.tokenizers.pt.tokenize(sentence).to_tensor()
    encoder_input = sentence
    # As the output language is English, initialize the output with the
    # English start token.
    start_end = self.tokenizers.en.tokenize([''])[0]
    start = start_end[0][tf.newaxis]
    end = start_end[1][tf.newaxis]
    # `tf.TensorArray` is required here (instead of a Python list) so that the
    # dynamic loop can be traced by `tf.function`.
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)
    for i in tf.range(max_length):
      output = tf.transpose(output_array.stack())
      predictions, _ = self.transformer([encoder_input, output], training=False)
      # Select the last token from the seq_len dimension.
      predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
      predicted_id = tf.argmax(predictions, axis=-1)
      # Concatenate the predicted_id to the output, which is given to the
      # decoder as its input.
      output_array = output_array.write(i+1, predicted_id[0])
      if predicted_id == end:
        break
    output = tf.transpose(output_array.stack())
    # output.shape (1, tokens)
    text = self.tokenizers.en.detokenize(output)[0]  # shape: ()
    tokens = self.tokenizers.en.lookup(output)[0]
    # `tf.function` prevents us from using the attention_weights that were
    # calculated on the last iteration of the loop. So recalculate them outside
    # the loop.
    _, attention_weights = self.transformer([encoder_input, output[:,:-1]], training=False)
    return text, tokens, attention_weights
Create an instance of this Translator class and try it out a few times:
translator = Translator(tokenizers, transformer)
def print_translation(sentence, tokens, ground_truth):
  print(f'{"Input:":15s}: {sentence}')
  print(f'{"Prediction":15s}: {tokens.numpy().decode("utf-8")}')
  print(f'{"Ground truth":15s}: {ground_truth}')
sentence = 'este é um problema que temos que resolver.'
ground_truth = 'this is a problem we have to solve .'
translated_text, translated_tokens, attention_weights = translator(
tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : este é um problema que temos que resolver.
Prediction     : this is a problem that we have to solve .
Ground truth   : this is a problem we have to solve .
sentence = 'os meus vizinhos ouviram sobre esta ideia.'
ground_truth = 'and my neighboring homes heard about this idea .'
translated_text, translated_tokens, attention_weights = translator(
tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : os meus vizinhos ouviram sobre esta ideia.
Prediction     : my neighbors heard about this idea .
Ground truth   : and my neighboring homes heard about this idea .
sentence = 'vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.'
ground_truth = "so i'll just share with you some stories very quickly of some magical things that have happened."
translated_text, translated_tokens, attention_weights = translator(
tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Prediction     : so i ' ll really share with you with some stories of some magical things that happened .
Ground truth   : so i'll just share with you some stories very quickly of some magical things that have happened.
Attention plots
The Translator class returns a dictionary of attention maps you can use to visualize the internal workings of the model:
sentence = 'este é o primeiro livro que eu fiz.'
ground_truth = "this is the first book i've ever done."
translated_text, translated_tokens, attention_weights = translator(
tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : este é o primeiro livro que eu fiz.
Prediction     : this is the first book i did .
Ground truth   : this is the first book i've ever done.
def plot_attention_head(in_tokens, translated_tokens, attention):
  # The plot is of the attention when a token was generated.
  # The model didn't generate `[START]` in the output. Skip it.
  translated_tokens = translated_tokens[1:]
  ax = plt.gca()
  ax.matshow(attention)
  ax.set_xticks(range(len(in_tokens)))
  ax.set_yticks(range(len(translated_tokens)))
  labels = [label.decode('utf-8') for label in in_tokens.numpy()]
  ax.set_xticklabels(
      labels, rotation=90)
  labels = [label.decode('utf-8') for label in translated_tokens.numpy()]
  ax.set_yticklabels(labels)
head = 0
# shape: (batch=1, num_heads, seq_len_q, seq_len_k)
attention_heads = tf.squeeze(
attention_weights['decoder_layer4_block2'], 0)
attention = attention_heads[head]
attention.shape
TensorShape([9, 11])
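The [9, 11] shape follows from the token counts printed below.

- The 11 columns are the Portuguese tokens, including [START] and [END].
- The 9 rows are decoder positions. The Translator recomputes attention on output[:, :-1], which drops the final [END] from the 10 translated tokens.
- Row i is the attention used while predicting token i + 1. This is why plot_attention_head skips [START] when labelling rows.

This NumPy sketch reproduces that shape bookkeeping with made-up random scores. Only the shapes and the softmax normalization are meaningful; the values are not.

```python
import numpy as np

num_heads, seq_len_q, seq_len_k = 8, 9, 11
rng = np.random.default_rng(0)

# Made-up attention logits laid out as (batch=1, num_heads, seq_len_q, seq_len_k).
scores = rng.normal(size=(1, num_heads, seq_len_q, seq_len_k))

# Softmax over the key axis: each query row becomes a distribution over input tokens.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

attention_heads = np.squeeze(weights, axis=0)  # like tf.squeeze(..., 0): (num_heads, 9, 11)
attention = attention_heads[0]                 # head 0: (9, 11)

print(attention.shape)                           # (9, 11)
print(np.allclose(attention.sum(axis=-1), 1.0))  # True
```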
in_tokens = tf.convert_to_tensor([sentence])
in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
in_tokens = tokenizers.pt.lookup(in_tokens)[0]
in_tokens
<tf.Tensor: shape=(11,), dtype=string, numpy= array([b'[START]', b'este', b'e', b'o', b'primeiro', b'livro', b'que', b'eu', b'fiz', b'.', b'[END]'], dtype=object)>
translated_tokens
<tf.Tensor: shape=(10,), dtype=string, numpy= array([b'[START]', b'this', b'is', b'the', b'first', b'book', b'i', b'did', b'.', b'[END]'], dtype=object)>
plot_attention_head(in_tokens, translated_tokens, attention)
def plot_attention_weights(sentence, translated_tokens, attention_heads):
  in_tokens = tf.convert_to_tensor([sentence])
  in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
  in_tokens = tokenizers.pt.lookup(in_tokens)[0]
  fig = plt.figure(figsize=(16, 8))
  for h, head in enumerate(attention_heads):
    ax = fig.add_subplot(2, 4, h+1)
    plot_attention_head(in_tokens, translated_tokens, head)
    ax.set_xlabel(f'Head {h+1}')
  plt.tight_layout()
  plt.show()
plot_attention_weights(sentence, translated_tokens,
attention_weights['decoder_layer4_block2'][0])
The model does okay on unfamiliar words. Neither "triceratops" nor "encyclopedia" is in the input dataset, and the model almost learns to transliterate them, even without a shared vocabulary:
sentence = 'Eu li sobre triceratops na enciclopédia.'
ground_truth = 'I read about triceratops in the encyclopedia.'
translated_text, translated_tokens, attention_weights = translator(
tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
plot_attention_weights(sentence, translated_tokens,
attention_weights['decoder_layer4_block2'][0])
Input:         : Eu li sobre triceratops na enciclopédia.
Prediction     : i read about triopatrinas in esclodies .
Ground truth   : I read about triceratops in the encyclopedia.
Export
That inference model is working, so next you'll export it as a tf.saved_model. To do that, wrap it in yet another tf.Module subclass, this time with a tf.function on the __call__ method:
class ExportTranslator(tf.Module):
  def __init__(self, translator):
    self.translator = translator
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def __call__(self, sentence):
    (result,
     tokens,
     attention_weights) = self.translator(sentence, max_length=MAX_TOKENS)
    return result
In the above tf.function, only the output sentence is returned. Thanks to the non-strict execution in tf.function, any unnecessary values are never computed.
translator = ExportTranslator(translator)
Since the model decodes the predictions using tf.argmax, the predictions are deterministic. The original model and one reloaded from its SavedModel should give identical predictions:
translator('este é o primeiro livro que eu fiz.').numpy()
b'this is the first book i did .'
tf.saved_model.save(translator, export_dir='translator')
2022-05-04 11:21:55.800781: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as embedding_4_layer_call_fn, embedding_4_layer_call_and_return_conditional_losses, dropout_37_layer_call_fn, dropout_37_layer_call_and_return_conditional_losses, embedding_5_layer_call_fn while saving (showing 5 of 224). These functions will not be directly callable after loading.
reloaded = tf.saved_model.load('translator')
reloaded('este é o primeiro livro que eu fiz.').numpy()
b'this is the first book i did .'
Summary
In this tutorial you learned about:
- positional encoding
- multi-head attention
- the importance of masking
- and how to put it all together to build a transformer.
This implementation tried to stay close to the implementation of the original paper. If you want to practice, there are many things you could try with it. For example:
- Use a different dataset to train the transformer.
- Create the "base" or "big" Transformer configurations from the original paper by changing the hyperparameters.
- Use the layers defined here to create an implementation of BERT.
- Implement beam search to get better predictions.