
Transformer model for language understanding


This tutorial trains a transformer model to translate Portuguese into English, using a Portuguese-English translation dataset.

This is an advanced example that assumes knowledge of text generation and attention.

This tutorial demonstrates how to build a transformer model and most of its components from scratch using low-level TensorFlow and Keras functionality. Some of this code could be replaced by built-in APIs like tf.keras.layers.MultiHeadAttention.

The core idea behind a transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. A transformer stacks self-attention layers, which are explained below in the sections Scaled dot product attention and Multi-head attention.

A transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:

  • It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, StarCraft units).
  • Layer outputs can be calculated in parallel, instead of sequentially as in an RNN.
  • Distant items can affect each other's output without passing through many RNN steps or convolution layers (see Scene Memory Transformer for an example).
  • It can learn long-range dependencies. This is a challenge in many sequence tasks.

The downsides of this architecture are:

  • For a time-series, the output for a time-step is calculated from the entire history instead of only the inputs and current hidden-state. This may be less efficient.
  • If the input does have a temporal/spatial relationship, like text, some positional encoding must be added or the model will effectively see a bag of words.

After training the model in this notebook, you will be able to input a Portuguese sentence and get back its English translation.

Attention heatmap

Setup

pip install tensorflow_datasets
pip install -U 'tensorflow-text==2.8.*'
import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

# Import tf_text to load the ops used by the tokenizer saved model
import tensorflow_text  # pylint: disable=unused-import
logging.getLogger('tensorflow').setLevel(logging.ERROR)  # suppress warnings

Download the dataset

Use TensorFlow Datasets to load the Portuguese-English translation dataset from the TED Talks Open Translation Project.

This dataset contains approximately 50,000 training examples, 1,100 validation examples, and 2,000 test examples.

examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
                               as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']

The tf.data.Dataset object returned by TensorFlow Datasets yields pairs of text examples:

for pt_examples, en_examples in train_examples.batch(3).take(1):
  for pt in pt_examples.numpy():
    print(pt.decode('utf-8'))

  print()

  for en in en_examples.numpy():
    print(en.decode('utf-8'))
e quando melhoramos a procura , tiramos a única vantagem da impressão , que é a serendipidade .
mas e se estes fatores fossem ativos ?
mas eles não tinham a curiosidade de me testar .

and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n't test for curiosity .

Text tokenization & detokenization

You can't train a model directly on text. The text needs to be converted to some numeric representation first. Typically, you convert the text to sequences of token IDs, which are used as indices into an embedding.
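To make "indices into an embedding" concrete, think of an embedding layer as a trainable matrix with one row per token ID. Looking up an ID selects that row. The NumPy sketch below uses a tiny random table. The sizes and IDs are arbitrary and are not taken from the real tokenizer. It checks that row indexing gives the same result as multiplying a one-hot vector by the table.

```python
import numpy as np

vocab_size, d_model = 10, 4  # toy sizes, chosen for illustration only
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, d_model))  # one row per token ID

token_ids = np.array([2, 7, 7, 3])  # a toy "sentence" of token IDs

# Lookup is plain row indexing: the result has shape (seq_len, d_model).
embedded = embedding_table[token_ids]

# Equivalent (but wasteful) formulation: one-hot rows times the table.
one_hot = np.eye(vocab_size)[token_ids]  # (seq_len, vocab_size)
embedded_via_matmul = one_hot @ embedding_table

print(embedded.shape)  # (4, 4)
```

Repeated IDs (the two 7s) map to identical vectors. That is why the model needs extra signals, such as the positional encoding introduced later, to distinguish occurrences.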

One popular approach is subword tokenization. The Subword tokenizer tutorial builds subword tokenizers (text.BertTokenizer) optimized for this dataset and exports them in a saved_model.

Download, unzip, and import the saved_model:

model_name = 'ted_hrlr_translate_pt_en_converter'
tf.keras.utils.get_file(
    f'{model_name}.zip',
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',
    cache_dir='.', cache_subdir='', extract=True
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/ted_hrlr_translate_pt_en_converter.zip
188416/184801 [==============================] - 0s 0us/step
196608/184801 [===============================] - 0s 0us/step
'./ted_hrlr_translate_pt_en_converter.zip'
tokenizers = tf.saved_model.load(model_name)

The tf.saved_model contains two text tokenizers, one for English and one for Portuguese. Both have the same methods:

[item for item in dir(tokenizers.en) if not item.startswith('_')]
['detokenize',
 'get_reserved_tokens',
 'get_vocab_path',
 'get_vocab_size',
 'lookup',
 'tokenize',
 'tokenizer',
 'vocab']

The tokenize method converts a batch of strings to a ragged batch of token IDs (a tf.RaggedTensor with one row per string). This method splits punctuation, lowercases, and Unicode-normalizes the input before tokenizing. That standardization is not visible here because the input data is already standardized.

for en in en_examples.numpy():
  print(en.decode('utf-8'))
and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n't test for curiosity .
encoded = tokenizers.en.tokenize(en_examples)

for row in encoded.to_list():
  print(row)
[2, 72, 117, 79, 1259, 1491, 2362, 13, 79, 150, 184, 311, 71, 103, 2308, 74, 2679, 13, 148, 80, 55, 4840, 1434, 2423, 540, 15, 3]
[2, 87, 90, 107, 76, 129, 1852, 30, 3]
[2, 87, 83, 149, 50, 9, 56, 664, 85, 2512, 15, 3]

The detokenize method attempts to convert these token IDs back to human readable text:

round_trip = tokenizers.en.detokenize(encoded)
for line in round_trip.numpy():
  print(line.decode('utf-8'))
and when you improve searchability , you actually take away the one advantage of print , which is serendipity .
but what if it were active ?
but they did n ' t test for curiosity .

The lower level lookup method converts from token-IDs to token text:

tokens = tokenizers.en.lookup(encoded)
tokens
<tf.RaggedTensor [[b'[START]', b'and', b'when', b'you', b'improve', b'search', b'##ability',
  b',', b'you', b'actually', b'take', b'away', b'the', b'one', b'advantage',
  b'of', b'print', b',', b'which', b'is', b's', b'##ere', b'##nd', b'##ip',
  b'##ity', b'.', b'[END]']                                                 ,
 [b'[START]', b'but', b'what', b'if', b'it', b'were', b'active', b'?',
  b'[END]']                                                           ,
 [b'[START]', b'but', b'they', b'did', b'n', b"'", b't', b'test', b'for',
  b'curiosity', b'.', b'[END]']                                          ]>

Here you can see the "subword" aspect of the tokenizers: the word "searchability" is decomposed into "search ##ability" and the word "serendipity" into "s ##ere ##nd ##ip ##ity".
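The "##" prefix marks a piece that continues the previous word. The helper below, join_wordpieces, is hypothetical and not part of the saved tokenizer. It is a simplified sketch of how such pieces can be glued back into words. The real detokenize method handles details this sketch ignores, beyond simply skipping [START] and [END].

```python
def join_wordpieces(tokens):
    """Join WordPiece-style tokens, attaching '##' continuations to the previous word.

    Simplified illustration only: reserved tokens like [START]/[END] are dropped,
    and no other normalization is performed.
    """
    words = []
    for tok in tokens:
        if tok in ("[START]", "[END]"):
            continue  # skip reserved tokens in this sketch
        if tok.startswith("##") and words:
            words[-1] += tok[2:]  # continuation piece: glue onto the previous word
        else:
            words.append(tok)  # start of a new word
    return " ".join(words)


print(join_wordpieces(["[START]", "improve", "search", "##ability", "[END]"]))
# improve searchability
```

This also explains why the round trip above turned "n't" into "n ' t". The tokens "n", "'" and "t" carry no "##" prefix, so they are treated as three separate words and joined with spaces.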

Now take a minute to investigate the distribution of tokens per example in the dataset:

lengths = []

for pt_examples, en_examples in train_examples.batch(1024):
  pt_tokens = tokenizers.pt.tokenize(pt_examples)
  lengths.append(pt_tokens.row_lengths())

  en_tokens = tokenizers.en.tokenize(en_examples)
  lengths.append(en_tokens.row_lengths())
  print('.', end='', flush=True)
...................................................
all_lengths = np.concatenate(lengths)

plt.hist(all_lengths, np.linspace(0, 500, 101))
plt.ylim(plt.ylim())
max_length = max(all_lengths)
plt.plot([max_length, max_length], plt.ylim())
plt.title(f'Max tokens per example: {max_length}');

Figure: histogram of the number of tokens per example; the vertical line marks the maximum.

MAX_TOKENS = 128

Setup input pipeline

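To build an input pipeline suitable for training, define some functions to transform the dataset.

Define a function to filter out long inputs. It runs on batches that have already been tokenized and padded, so tf.shape(...)[1] is the padded length of the whole batch. Any batch whose longest Portuguese or English sequence has MAX_TOKENS or more tokens is dropped: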

def filter_max_tokens(pt, en):
  num_tokens = tf.maximum(tf.shape(pt)[1],tf.shape(en)[1])
  return num_tokens < MAX_TOKENS

Define a function that tokenizes the batches of raw text:

def tokenize_pairs(pt, en):
    pt = tokenizers.pt.tokenize(pt)
    # Convert from ragged to dense, padding with zeros.
    pt = pt.to_tensor()

    en = tokenizers.en.tokenize(en)
    # Convert from ragged to dense, padding with zeros.
    en = en.to_tensor()
    return pt, en

Here's a simple input pipeline that processes, shuffles and batches the data:

BUFFER_SIZE = 20000
BATCH_SIZE = 64
def make_batches(ds):
  return (
      ds
      .cache()
      .shuffle(BUFFER_SIZE)
      .batch(BATCH_SIZE)
      .map(tokenize_pairs, num_parallel_calls=tf.data.AUTOTUNE)
      .filter(filter_max_tokens)
      .prefetch(tf.data.AUTOTUNE))


train_batches = make_batches(train_examples)
val_batches = make_batches(val_examples)
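In make_batches, tokenize_pairs pads each batch to its longest row before filter_max_tokens runs. As a result, one long sentence is enough to drop its entire batch. The NumPy sketch below mimics both steps on toy token-ID lists. pad_with_zeros and keep_batch are hypothetical helpers, not TensorFlow APIs.

```python
import numpy as np

MAX_TOKENS = 128  # same threshold as the tutorial


def pad_with_zeros(rows):
    """Pad a list of token-ID lists to a dense (batch, max_len) array, like RaggedTensor.to_tensor()."""
    max_len = max(len(r) for r in rows)
    dense = np.zeros((len(rows), max_len), dtype=np.int64)
    for i, r in enumerate(rows):
        dense[i, :len(r)] = r
    return dense


def keep_batch(pt, en):
    """Mirror filter_max_tokens: compare the padded width of the batch to MAX_TOKENS."""
    return max(pt.shape[1], en.shape[1]) < MAX_TOKENS


short_batch = pad_with_zeros([[2, 5, 3], [2, 3]])
print(short_batch)
# [[2 5 3]
#  [2 3 0]]

# One 200-token row makes the whole padded batch 200 wide, so the batch is dropped.
long_batch = pad_with_zeros([[2, 3], [2] + [7] * 198 + [3]])
print(keep_batch(short_batch, short_batch), keep_batch(long_batch, short_batch))
# True False
```

Filtering before batching would instead drop individual examples. The trade-off is that each example would need to be tokenized on its own.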

Positional encoding

Attention layers see their input as a set of vectors, with no sequential order. This model also doesn't contain any recurrent or convolutional layers. Because of this, a "positional encoding" is added to give the model some information about the relative position of the tokens in the sentence.
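To see concretely why order information is missing, consider the small NumPy sketch below. It is not the tutorial's code. It runs unmasked self-attention with no positional encoding and no learned projections. Shuffling the input tokens only shuffles the output rows in the same way, so the layer cannot tell which order the tokens came in.

```python
import numpy as np


def self_attention(x):
    """Unmasked scaled dot-product self-attention with Q = K = V = x (no learned weights)."""
    logits = x @ x.T / np.sqrt(x.shape[-1])
    logits -= logits.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 toy token vectors of depth 8
perm = np.array([3, 0, 4, 1, 2])  # an arbitrary reordering of the tokens

out = self_attention(x)
out_permuted = self_attention(x[perm])

# Permuting the inputs only permutes the outputs: the result is order-agnostic.
print(np.allclose(out_permuted, out[perm]))  # True
```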

The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning are closer to each other, but they do not encode the position of tokens in a sentence. After the positional encoding is added, tokens are close to each other in the d-dimensional space based on both the similarity of their meaning and their position in the sentence.

The formula for calculating the positional encoding is as follows:

\[\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model} })} \]

\[\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model} })} \]

def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates
def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)
n, d = 2048, 512
pos_encoding = positional_encoding(n, d)
print(pos_encoding.shape)
pos_encoding = pos_encoding[0]

# Juggle the dimensions for the plot
pos_encoding = tf.reshape(pos_encoding, (n, d//2, 2))
pos_encoding = tf.transpose(pos_encoding, (2, 1, 0))
pos_encoding = tf.reshape(pos_encoding, (d, n))

plt.pcolormesh(pos_encoding, cmap='RdBu')
plt.ylabel('Depth')
plt.xlabel('Position')
plt.colorbar()
plt.show()
(1, 2048, 512)

Figure: positional encoding values, with depth on the y-axis and position on the x-axis.
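As a check on the formula, the sketch below evaluates it directly, one (pos, i) pair at a time, for a tiny d_model of 4. It then compares the result with the vectorized trick used in get_angles, where the exponent 2 * (i//2) makes each even/odd column pair share a frequency. The direct loop and the toy sizes are assumptions made for illustration.

```python
import math

import numpy as np


def pe_direct(pos, d_model):
    """Evaluate PE(pos, 2i) = sin(...) and PE(pos, 2i+1) = cos(...) term by term."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        pe.extend([math.sin(angle), math.cos(angle)])  # columns 2i and 2i+1
    return np.array(pe)


def pe_vectorized(num_positions, d_model):
    """Same idea as the tutorial's get_angles/positional_encoding, without the batch axis."""
    pos = np.arange(num_positions)[:, np.newaxis]
    i = np.arange(d_model)[np.newaxis, :]
    angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even columns: sin
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd columns: cos
    return angles


# For pos = 1, d_model = 4 the two angles are 1 and 1/100:
print(pe_direct(1, 4))  # approximately [0.8415, 0.5403, 0.0100, 0.99995]
```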

Masking

Mask all the pad tokens in the batch of sequences. This ensures that the model does not treat padding as input. The mask indicates where the pad value 0 is present: it outputs a 1 at those locations, and a 0 otherwise.

def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
<tf.Tensor: shape=(3, 1, 1, 5), dtype=float32, numpy=
array([[[[0., 0., 1., 1., 0.]]],


       [[[0., 0., 0., 1., 1.]]],


       [[[1., 1., 1., 0., 0.]]]], dtype=float32)>
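The two size-1 axes let the mask broadcast against attention logits of shape (batch_size, num_heads, seq_len_q, seq_len_k). The same key-padding mask is reused for every head and every query position. A NumPy sketch of that broadcast, using the same x as above and an arbitrary head count of 2:

```python
import numpy as np

x = np.array([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])

# Same construction as create_padding_mask: 1.0 where the token ID is 0 (padding).
mask = (x == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]  # (3, 1, 1, 5)

batch_size, num_heads, seq_len = x.shape[0], 2, x.shape[1]
logits = np.zeros((batch_size, num_heads, seq_len, seq_len), dtype=np.float32)

# Broadcasting repeats the mask over the head and query axes.
masked_logits = logits + mask * -1e9
print(masked_logits.shape)  # (3, 2, 5, 5)
```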

The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.

This means that to predict the third token, only the first and second tokens will be used. Similarly, to predict the fourth token, only the first, second, and third tokens will be used, and so on.

def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32)>
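An equivalent way to build the look-ahead mask, shown here in NumPy for comparison, is to take the strictly upper-triangular part of a matrix of ones. Entry (i, j) is 1 exactly when j > i, that is, when key position j lies in the future of query position i. np.tril plays the role of tf.linalg.band_part(..., -1, 0), which keeps the lower triangle including the diagonal.

```python
import numpy as np


def look_ahead_mask_np(size):
    """1.0 above the diagonal (future positions to hide), 0.0 on and below it."""
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)


print(look_ahead_mask_np(3))
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
```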

Scaled dot product attention

Figure: scaled dot-product attention

The attention function used by a transformer takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:

\[\Large{Attention(Q, K, V) = softmax_k\left(\frac{QK^T}{\sqrt{d_k} }\right) V} \]

The dot-product attention is scaled by a factor of the square root of the depth. This is done because, for large depths, the dot products grow large in magnitude. That pushes the softmax into regions where it has extremely small gradients, resulting in a very "hard" softmax.

For example, suppose the components of Q and K are independent, with a mean of 0 and a variance of 1. Their dot products (the entries of QK^T) then have a mean of 0 and a variance of dk. Dividing by the square root of dk therefore keeps the variance consistent regardless of the value of dk. If the variance is too low, the output may be too flat to optimize effectively; if it is too high, the softmax may saturate at initialization, making it difficult to learn.
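A quick numerical check of this argument uses random NumPy vectors; the sample size and seed are arbitrary. Dot products of unit-variance vectors have a variance close to dk, and dividing by sqrt(dk) brings it back to about 1.

```python
import numpy as np

rng = np.random.default_rng(42)
num_samples = 10_000  # arbitrary; enough for a stable variance estimate

results = {}
for dk in (4, 64, 512):
    q = rng.normal(size=(num_samples, dk))  # components with mean 0, variance 1
    k = rng.normal(size=(num_samples, dk))
    dots = np.sum(q * k, axis=-1)  # one q.k dot product per sample
    scaled = dots / np.sqrt(dk)
    results[dk] = (dots.var(), scaled.var())
    print(f"dk={dk:4d}  var(q.k)={dots.var():8.2f}  var(q.k/sqrt(dk))={scaled.var():.3f}")
```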

The mask is multiplied by -1e9 (close to negative infinity). The mask is added to the scaled product of Q and K immediately before the softmax, so masked cells receive very large negative logits. Because large negative inputs to the softmax produce outputs near zero, those positions get essentially zero attention weight.
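For instance, take made-up logits for three keys where the last key is masked (mask value 1). Adding mask * -1e9 drives that key's softmax weight to zero, and the remaining weight is shared among the unmasked keys.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


logits = np.array([2.0, 1.0, 3.0])  # toy scaled attention logits
mask = np.array([0.0, 0.0, 1.0])  # 1 marks the key to hide

weights = softmax(logits + mask * -1e9)
print(weights.round(4))  # approximately [0.7311, 0.2689, 0.0]
```

Note that the masked key had the largest raw logit, yet it ends up with no weight.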

def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

Because the softmax is normalized over the key axis (seq_len_k), the weights for each query sum to 1. They determine how much importance each key, and therefore each value, receives for that query.

The output is the attention-weighted sum of the V (value) vectors. Tokens with high attention weights dominate the output, while tokens with near-zero weights contribute almost nothing.

def print_out(q, k, v):
  temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
  print('Attention weights are:')
  print(temp_attn)
  print('Output is:')
  print(temp_out)
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[1, 0],
                      [10, 0],
                      [100, 5],
                      [1000, 6]], dtype=tf.float32)  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)
# This query aligns equally with the first and second key,
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0.  0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)

Pass all the queries together.

temp_q = tf.constant([[0, 0, 10],
                      [0, 10, 0],
                      [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)

Multi-head attention

Figure: multi-head attention

Multi-head attention consists of four parts:

  • Linear layers, with the results split into heads.
  • Scaled dot-product attention.
  • Concatenation of heads.
  • Final linear layer.

Each multi-head attention block gets three inputs: Q (query), K (key), and V (value). These are put through linear (Dense) layers before the multi-head attention function.

In the diagram above, Q, K, and V are passed through separate linear (Dense) layers for each attention head. For simplicity and efficiency, the code below implements this using a single dense layer per input with num_heads times as many outputs (d_model = num_heads * depth). The output is rearranged to a shape of (batch, num_heads, ...) before applying the attention function.

The scaled_dot_product_attention function defined above is applied in a single call, broadcast across the heads for efficiency. An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.

Instead of one single attention head, Q, K, and V are split into multiple heads because this allows the model to jointly attend to information from different representation subspaces at different positions. After the split, each head has a reduced dimensionality, so the total computation cost is the same as single-head attention with full dimensionality.
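The reshaping described above can be sketched in NumPy. The sketch uses toy sizes and random weights, not the layer's actual variables. It checks two claims. First, slicing the output columns of one (d_model, d_model) projection gives the same per-head queries as separate (d_model, depth) projections. Second, split_heads followed by the inverse transpose/reshape is lossless.

```python
import numpy as np

batch_size, seq_len, d_model, num_heads = 2, 5, 12, 3
depth = d_model // num_heads  # 4 features per head

rng = np.random.default_rng(0)
x = rng.normal(size=(batch_size, seq_len, d_model))
w_q = rng.normal(size=(d_model, d_model))  # one fused projection for all heads


def split_heads(t):
    """(batch, seq, d_model) -> (batch, num_heads, seq, depth), as in MultiHeadAttention.split_heads."""
    t = t.reshape(batch_size, -1, num_heads, depth)
    return t.transpose(0, 2, 1, 3)


def merge_heads(t):
    """Inverse of split_heads: (batch, num_heads, seq, depth) -> (batch, seq, d_model)."""
    return t.transpose(0, 2, 1, 3).reshape(batch_size, -1, d_model)


q_heads = split_heads(x @ w_q)  # (2, 3, 5, 4)

# Head h uses columns [h*depth, (h+1)*depth) of the fused weight matrix.
for h in range(num_heads):
    w_head = w_q[:, h * depth:(h + 1) * depth]  # equivalent separate per-head projection
    assert np.allclose(q_heads[:, h], x @ w_head)

print(q_heads.shape, np.allclose(merge_heads(q_heads), x @ w_q))  # (2, 3, 5, 4) True
```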

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self,*, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights

Create a MultiHeadAttention layer to try out. At each location in the sequence y, the MultiHeadAttention layer runs all 8 attention heads across all locations in the sequence, returning a new vector of the same length at each location.

temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))

Point wise feed forward network

The point-wise feed-forward network consists of two fully-connected layers with a ReLU activation in between.

def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
TensorShape([64, 50, 512])
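"Point wise" means the same two-layer network is applied to every position independently. The NumPy sketch below uses random weights and biases at toy sizes. Running it on a whole (seq_len, d_model) sequence gives the same result as running it one position at a time.

```python
import numpy as np

seq_len, d_model, dff = 6, 8, 32  # toy sizes
rng = np.random.default_rng(1)
w1, b1 = rng.normal(size=(d_model, dff)), rng.normal(size=dff)
w2, b2 = rng.normal(size=(dff, d_model)), rng.normal(size=d_model)


def ffn(x):
    """Dense(dff, relu) followed by Dense(d_model), acting on the last axis only."""
    hidden = np.maximum(0.0, x @ w1 + b1)
    return hidden @ w2 + b2


x = rng.normal(size=(seq_len, d_model))
whole_sequence = ffn(x)
one_position_at_a_time = np.stack([ffn(x[t]) for t in range(seq_len)])

print(whole_sequence.shape, np.allclose(whole_sequence, one_position_at_a_time))
# (6, 8) True
```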

Encoder and decoder

transformer

A transformer model follows the same general pattern as a standard sequence to sequence with attention model.

  • The input sentence is passed through N encoder layers that generate an output for each token in the sequence.
  • The decoder attends to the encoder's output and its own input (self-attention) to predict the next word.

Encoder layer

Each encoder layer consists of sublayers:

  1. Multi-head attention (with padding mask)
  2. Point wise feed forward networks.

Each of these sublayers has a residual connection around it followed by a layer normalization. Residual connections help in avoiding the vanishing gradient problem in deep networks.

The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in a transformer.
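Here is a minimal NumPy sketch of LayerNorm(x + Sublayer(x)). It assumes a stand-in sublayer (a fixed random linear map) and omits LayerNormalization's learnable scale and offset, which Keras initializes to 1 and 0. After the residual sum, each position is normalized over its d_model features to roughly zero mean and unit variance.

```python
import numpy as np


def layer_norm(x, epsilon=1e-6):
    """Normalize over the last (d_model) axis; learnable scale/offset omitted."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


rng = np.random.default_rng(2)
batch_size, seq_len, d_model = 2, 3, 16
x = rng.normal(size=(batch_size, seq_len, d_model))
w = rng.normal(size=(d_model, d_model))  # stand-in for a sublayer such as attention or the FFN


def sublayer(t):
    return t @ w


out = layer_norm(x + sublayer(x))  # residual connection, then normalization
print(out.mean(axis=-1).round(6))  # ~0 at every (batch, position)
print(out.var(axis=-1).round(3))  # ~1 at every (batch, position)
```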

class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self,*, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2
sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)

sample_encoder_layer_output = sample_encoder_layer(
    tf.random.uniform((64, 43, 512)), False, None)

sample_encoder_layer_output.shape  # (batch_size, input_seq_len, d_model)
TensorShape([64, 43, 512])

Decoder layer

Each decoder layer consists of sublayers:

  1. Masked multi-head attention (with look ahead mask and padding mask)
  2. Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
  3. Point wise feed forward networks

Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.

There are N decoder layers in the model.

As Q receives the output from the decoder's first attention block, and K receives the encoder output, the attention weights represent how much importance each decoder position gives to each encoder position. In other words, the decoder predicts the next token by looking at the encoder output and self-attending to its own output. See the demonstration above in the Scaled dot product attention section.

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self,*, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    self.mha2 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)

    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

    return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)

sample_decoder_layer_output, _, _ = sample_decoder_layer(
    tf.random.uniform((64, 50, 512)), sample_encoder_layer_output,
    False, None, None)

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)
TensorShape([64, 50, 512])

Encoder

The Encoder consists of:

  1. Input Embedding
  2. Positional Encoding
  3. N encoder layers

The input is put through an embedding, scaled by the square root of d_model, and summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.

class Encoder(tf.keras.layers.Layer):
  def __init__(self,*, num_layers, d_model, num_heads, dff, input_vocab_size,
               rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(MAX_TOKENS, self.d_model)

    self.enc_layers = [
        EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, rate=rate)
        for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8,
                         dff=2048, input_vocab_size=8500)
temp_input = tf.random.uniform((64, 62), dtype=tf.int64, minval=0, maxval=200)

sample_encoder_output = sample_encoder(temp_input, training=False, mask=None)

print(sample_encoder_output.shape)  # (batch_size, input_seq_len, d_model)
(64, 62, 512)

Decoder

The Decoder consists of:

  1. Output Embedding
  2. Positional Encoding
  3. N decoder layers

The target is put through an embedding, scaled by the square root of d_model, and summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

class Decoder(tf.keras.layers.Layer):
  def __init__(self,*, num_layers, d_model, num_heads, dff, target_vocab_size,
               rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(MAX_TOKENS, d_model)

    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, rate=rate)
        for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}

    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)

      attention_weights[f'decoder_layer{i+1}_block1'] = block1
      attention_weights[f'decoder_layer{i+1}_block2'] = block2

    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8,
                         dff=2048, target_vocab_size=8000)
temp_input = tf.random.uniform((64, 26), dtype=tf.int64, minval=0, maxval=200)

output, attn = sample_decoder(temp_input,
                              enc_output=sample_encoder_output,
                              training=False,
                              look_ahead_mask=None,
                              padding_mask=None)

output.shape, attn['decoder_layer2_block2'].shape
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))

Create the transformer model

A transformer consists of the encoder, decoder, and a final linear layer. The output of the decoder is the input to the linear layer and its output is returned.

class Transformer(tf.keras.Model):
  def __init__(self,*, num_layers, d_model, num_heads, dff, input_vocab_size,
               target_vocab_size, rate=0.1):
    super().__init__()
    self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           input_vocab_size=input_vocab_size, rate=rate)

    self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           target_vocab_size=target_vocab_size, rate=rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs, training):
    # Keras models prefer if you pass all your inputs in the first argument
    inp, tar = inputs

    padding_mask, look_ahead_mask = self.create_masks(inp, tar)

    enc_output = self.encoder(inp, training, padding_mask)  # (batch_size, inp_seq_len, d_model)

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, padding_mask)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output, attention_weights

  def create_masks(self, inp, tar):
    # Encoder padding mask (Used in the 2nd attention block in the decoder too.)
    padding_mask = create_padding_mask(inp)

    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by
    # the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    look_ahead_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return padding_mask, look_ahead_mask
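The last line of create_masks merges the two decoder masks element-wise with a maximum. The NumPy sketch below reproduces that step under two assumptions: a mask value of 1 marks a position to hide, and token ID 0 is padding. The helpers padding_mask and look_ahead_mask here are illustrative re-implementations, not the tutorial's own functions.

```python
import numpy as np

def padding_mask(seq):
    # 1.0 where the token is padding (ID 0), shaped to broadcast over heads and query positions
    return (seq == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]

def look_ahead_mask(size):
    # 1.0 strictly above the diagonal: query position i may not see key positions j > i
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)

tar = np.array([[5, 7, 9, 0, 0]])  # one target sequence ending in two padding tokens
combined = np.maximum(padding_mask(tar), look_ahead_mask(tar.shape[1]))
print(combined.shape)
print(combined[0, 0])
```

Each row of combined[0, 0] is a query position. The padding columns (3 and 4) are hidden from every row, and the upper triangle hides future tokens, so one mask enforces both rules.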
sample_transformer = Transformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048,
    input_vocab_size=8500, target_vocab_size=8000)

temp_input = tf.random.uniform((64, 38), dtype=tf.int64, minval=0, maxval=200)
temp_target = tf.random.uniform((64, 36), dtype=tf.int64, minval=0, maxval=200)

fn_out, _ = sample_transformer([temp_input, temp_target], training=False)

fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)
TensorShape([64, 36, 8000])

Set hyperparameters

To keep this example small and relatively fast, the values for num_layers, d_model, and dff have been reduced.

The base model described in the paper used num_layers=6, d_model=512, dff=2048, and num_heads=8.
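A rough count shows how much smaller the reduced configuration is. The function below is a sketch that assumes a standard encoder layer (Q, K, V, and output projections with biases, a two-layer feed-forward network, and two layer normalizations) and ignores the embeddings and the decoder, so treat the numbers as approximate.

```python
def encoder_layer_params(d_model, dff):
    # Four d_model x d_model projections (Q, K, V, output), each with a bias
    attention = 4 * (d_model * d_model + d_model)
    # Two-layer point-wise feed-forward network: d_model -> dff -> d_model
    ffn = (d_model * dff + dff) + (dff * d_model + d_model)
    # Two LayerNormalization layers, each with a scale and an offset vector
    layer_norms = 2 * 2 * d_model
    return attention + ffn + layer_norms

base = 6 * encoder_layer_params(512, 2048)   # paper's base configuration
small = 4 * encoder_layer_params(128, 512)   # this tutorial's configuration
print(base, small, round(base / small, 1))
```

Under these assumptions the encoder stack of the base model has about 24 times as many parameters as the one trained here.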

num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

Optimizer

Use the Adam optimizer with a custom learning rate scheduler according to the formula in the paper.

\[\Large{lrate = d_{model}^{-0.5} * \min(step{\_}num^{-0.5}, step{\_}num \cdot warmup{\_}steps^{-1.5})}\]
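The formula can be checked with a plain-Python sketch. transformer_lrate is a hypothetical helper, not part of the tutorial, and its defaults assume d_model=128 (from the hyperparameters above) and warmup_steps=4000. The learning rate grows linearly for the first warmup_steps steps, peaks where the two arguments of min are equal (step_num = warmup_steps), and then decays proportionally to step_num^-0.5. step_num is assumed to start at 1, since 0 ** -0.5 is undefined in Python.

```python
def transformer_lrate(step_num, d_model=128, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step_num^-0.5, step_num * warmup_steps^-1.5)
    return d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)

peak = transformer_lrate(4000)          # end of warm-up
print(f"{peak:.6f}")
print(transformer_lrate(1000) < peak)   # still in the linear warm-up phase
print(transformer_lrate(16000) < peak)  # decaying as step_num^-0.5
```

Quadrupling the step count after warm-up halves the learning rate, which follows directly from the step_num^-0.5 term.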

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super().__init__()

    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # The optimizer may pass an integer step counter; rsqrt needs a float.
    step = tf.cast(step, dtype=tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)

plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel('Learning Rate')
plt.xlabel('Train Step')
Text(0.5, 0, 'Train Step')


Loss and metrics

Since the target sequences are padded with token ID 0, apply a padding mask when calculating the loss (and the accuracy) so that padding positions don't contribute.

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_sum(loss_)/tf.reduce_sum(mask)


def accuracy_function(real, pred):
  accuracies = tf.equal(real, tf.argmax(pred, axis=2))

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)

  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
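A small NumPy example shows why the mask matters. The per-token losses and predictions below are made-up numbers; the masking arithmetic mirrors loss_function and accuracy_function. Without the mask, the (made-up) losses at the two padding positions drag the average loss up, and because a prediction of 0 at a padding position counts as a match, the unmasked accuracy is inflated.

```python
import numpy as np

real = np.array([[4, 2, 9, 0, 0]])                       # target IDs; 0 is padding
per_token_loss = np.array([[0.5, 1.0, 1.5, 7.0, 7.0]])   # hypothetical cross-entropy values
pred_ids = np.array([[4, 3, 9, 0, 0]])                   # hypothetical argmax predictions

mask = (real != 0).astype(np.float32)

# Masked versions: only real tokens count, as in loss_function / accuracy_function
masked_loss = (per_token_loss * mask).sum() / mask.sum()
masked_acc = ((pred_ids == real) & (real != 0)).sum() / mask.sum()

# Unmasked versions for comparison
naive_loss = per_token_loss.mean()
naive_acc = (pred_ids == real).mean()

print(masked_loss, naive_loss)
print(masked_acc, naive_acc)
```

Here the masked loss is 1.0 versus 3.4 unmasked, and the masked accuracy is 2/3 versus 0.8 unmasked.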

Training and checkpointing

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),
    target_vocab_size=tokenizers.en.get_vocab_size().numpy(),
    rate=dropout_rate)

Create the checkpoint path and the checkpoint manager. The training loop uses the manager to save a checkpoint every 5 epochs, and the manager keeps only the 5 most recent checkpoints (max_to_keep=5).

checkpoint_path = './checkpoints/train'

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')

The target is split into tar_inp and tar_real. tar_inp is passed as input to the decoder. tar_real is the same sequence shifted by one: at each position in tar_inp, tar_real contains the next token that should be predicted.

For example, sentence = 'SOS A lion in the jungle is sleeping EOS' becomes:

  • tar_inp = 'SOS A lion in the jungle is sleeping'
  • tar_real = 'A lion in the jungle is sleeping EOS'
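The same split in plain Python, using list slicing that mirrors tar[:, :-1] and tar[:, 1:] in train_step:

```python
sentence = ['SOS', 'A', 'lion', 'in', 'the', 'jungle', 'is', 'sleeping', 'EOS']

tar_inp = sentence[:-1]   # what the decoder sees (drop the last token)
tar_real = sentence[1:]   # what it should predict at each position (drop the first token)

for seen, target in zip(tar_inp, tar_real):
    print(f"{seen:>8} -> {target}")
```

Each printed pair is one prediction target: after seeing 'SOS' the model should predict 'A', after 'A' it should predict 'lion', and so on.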

A transformer is an autoregressive model: it makes predictions one token at a time and uses its output so far to decide what to predict next.

During training, this example uses teacher forcing (as in the text generation tutorial). Teacher forcing means passing the true output to the next time step regardless of what the model predicts at the current time step.

As the model predicts each token, self-attention allows it to look at the previous tokens in the target sequence to better predict the next token.

To prevent the model from peeking at the expected output, the decoder uses a look-ahead mask.
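The effect of a look-ahead mask on attention weights can be sketched in NumPy. This assumes that a mask value of 1 hides a position and that the mask is applied by adding a large negative number to the attention logits before the softmax; masked_softmax is an illustrative helper. With equal raw scores, each row spreads its weight evenly over the current and earlier positions only.

```python
import numpy as np

def masked_softmax(logits, mask):
    # Hidden positions (mask == 1) get a large negative logit, so their weight becomes ~0
    logits = logits + mask * -1e9
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)

seq_len = 4
scores = np.zeros((seq_len, seq_len))             # equal raw scores everywhere
mask = np.triu(np.ones((seq_len, seq_len)), k=1)  # hide future positions
weights = masked_softmax(scores, mask)
print(np.round(weights, 3))
```

Row 0 puts all of its weight on position 0, row 1 splits it between positions 0 and 1, and so on; no row assigns weight to a later position.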

EPOCHS = 20
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]


@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  with tf.GradientTape() as tape:
    predictions, _ = transformer([inp, tar_inp], training=True)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))

Portuguese is used as the input language and English is the target language.

for epoch in range(EPOCHS):
  start = time.time()

  train_loss.reset_states()
  train_accuracy.reset_states()

  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_batches):
    train_step(inp, tar)

    if batch % 50 == 0:
      print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')

  print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

  print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n')
Epoch 1 Batch 0 Loss 8.8531 Accuracy 0.0000
Epoch 1 Batch 50 Loss 8.7949 Accuracy 0.0032
Epoch 1 Batch 100 Loss 8.7005 Accuracy 0.0231
Epoch 1 Batch 150 Loss 8.5879 Accuracy 0.0324
Epoch 1 Batch 200 Loss 8.4469 Accuracy 0.0393
Epoch 1 Batch 250 Loss 8.2763 Accuracy 0.0459
Epoch 1 Batch 300 Loss 8.0874 Accuracy 0.0540
Epoch 1 Batch 350 Loss 7.8902 Accuracy 0.0612
Epoch 1 Batch 400 Loss 7.7009 Accuracy 0.0680
Epoch 1 Batch 450 Loss 7.5321 Accuracy 0.0747
Epoch 1 Batch 500 Loss 7.3834 Accuracy 0.0813
Epoch 1 Batch 550 Loss 7.2496 Accuracy 0.0876
Epoch 1 Batch 600 Loss 7.1250 Accuracy 0.0947
Epoch 1 Batch 650 Loss 7.0107 Accuracy 0.1017
Epoch 1 Batch 700 Loss 6.9043 Accuracy 0.1079
Epoch 1 Loss 6.9043 Accuracy 0.1079
Time taken for 1 epoch: 51.80 secs

Epoch 2 Batch 0 Loss 5.3989 Accuracy 0.1918
Epoch 2 Batch 50 Loss 5.4114 Accuracy 0.1993
Epoch 2 Batch 100 Loss 5.3776 Accuracy 0.2016
Epoch 2 Batch 150 Loss 5.3417 Accuracy 0.2039
Epoch 2 Batch 200 Loss 5.3033 Accuracy 0.2074
Epoch 2 Batch 250 Loss 5.2686 Accuracy 0.2102
Epoch 2 Batch 300 Loss 5.2359 Accuracy 0.2137
Epoch 2 Batch 350 Loss 5.2048 Accuracy 0.2169
Epoch 2 Batch 400 Loss 5.1783 Accuracy 0.2193
Epoch 2 Batch 450 Loss 5.1516 Accuracy 0.2221
Epoch 2 Batch 500 Loss 5.1249 Accuracy 0.2248
Epoch 2 Batch 550 Loss 5.1024 Accuracy 0.2267
Epoch 2 Batch 600 Loss 5.0822 Accuracy 0.2288
Epoch 2 Batch 650 Loss 5.0623 Accuracy 0.2307
Epoch 2 Batch 700 Loss 5.0416 Accuracy 0.2327
Epoch 2 Loss 5.0407 Accuracy 0.2328
Time taken for 1 epoch: 38.33 secs

Epoch 3 Batch 0 Loss 4.8084 Accuracy 0.2580
Epoch 3 Batch 50 Loss 4.7208 Accuracy 0.2600
Epoch 3 Batch 100 Loss 4.6989 Accuracy 0.2619
Epoch 3 Batch 150 Loss 4.6907 Accuracy 0.2635
Epoch 3 Batch 200 Loss 4.6770 Accuracy 0.2650
Epoch 3 Batch 250 Loss 4.6651 Accuracy 0.2661
Epoch 3 Batch 300 Loss 4.6500 Accuracy 0.2673
Epoch 3 Batch 350 Loss 4.6389 Accuracy 0.2683
Epoch 3 Batch 400 Loss 4.6215 Accuracy 0.2699
Epoch 3 Batch 450 Loss 4.6098 Accuracy 0.2708
Epoch 3 Batch 500 Loss 4.5990 Accuracy 0.2716
Epoch 3 Batch 550 Loss 4.5863 Accuracy 0.2727
Epoch 3 Batch 600 Loss 4.5743 Accuracy 0.2741
Epoch 3 Batch 650 Loss 4.5607 Accuracy 0.2756
Epoch 3 Batch 700 Loss 4.5493 Accuracy 0.2768
Epoch 3 Loss 4.5482 Accuracy 0.2769
Time taken for 1 epoch: 39.29 secs

Epoch 4 Batch 0 Loss 4.4872 Accuracy 0.2858
Epoch 4 Batch 50 Loss 4.3147 Accuracy 0.2991
Epoch 4 Batch 100 Loss 4.2780 Accuracy 0.3028
Epoch 4 Batch 150 Loss 4.2661 Accuracy 0.3049
Epoch 4 Batch 200 Loss 4.2495 Accuracy 0.3076
Epoch 4 Batch 250 Loss 4.2335 Accuracy 0.3102
Epoch 4 Batch 300 Loss 4.2165 Accuracy 0.3125
Epoch 4 Batch 350 Loss 4.1987 Accuracy 0.3148
Epoch 4 Batch 400 Loss 4.1804 Accuracy 0.3174
Epoch 4 Batch 450 Loss 4.1628 Accuracy 0.3195
Epoch 4 Batch 500 Loss 4.1445 Accuracy 0.3220
Epoch 4 Batch 550 Loss 4.1258 Accuracy 0.3242
Epoch 4 Batch 600 Loss 4.1078 Accuracy 0.3263
Epoch 4 Batch 650 Loss 4.0939 Accuracy 0.3281
Epoch 4 Loss 4.0790 Accuracy 0.3298
Time taken for 1 epoch: 39.41 secs

Epoch 5 Batch 0 Loss 3.7135 Accuracy 0.3579
Epoch 5 Batch 50 Loss 3.7548 Accuracy 0.3668
Epoch 5 Batch 100 Loss 3.7626 Accuracy 0.3662
Epoch 5 Batch 150 Loss 3.7450 Accuracy 0.3689
Epoch 5 Batch 200 Loss 3.7320 Accuracy 0.3708
Epoch 5 Batch 250 Loss 3.7225 Accuracy 0.3722
Epoch 5 Batch 300 Loss 3.7118 Accuracy 0.3738
Epoch 5 Batch 350 Loss 3.6993 Accuracy 0.3756
Epoch 5 Batch 400 Loss 3.6874 Accuracy 0.3777
Epoch 5 Batch 450 Loss 3.6730 Accuracy 0.3797
Epoch 5 Batch 500 Loss 3.6610 Accuracy 0.3813
Epoch 5 Batch 550 Loss 3.6504 Accuracy 0.3827
Epoch 5 Batch 600 Loss 3.6404 Accuracy 0.3842
Epoch 5 Batch 650 Loss 3.6284 Accuracy 0.3856
Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-1
Epoch 5 Loss 3.6206 Accuracy 0.3868
Time taken for 1 epoch: 39.17 secs

Epoch 6 Batch 0 Loss 3.2417 Accuracy 0.4175
Epoch 6 Batch 50 Loss 3.3316 Accuracy 0.4162
Epoch 6 Batch 100 Loss 3.3551 Accuracy 0.4148
Epoch 6 Batch 150 Loss 3.3538 Accuracy 0.4159
Epoch 6 Batch 200 Loss 3.3495 Accuracy 0.4161
Epoch 6 Batch 250 Loss 3.3375 Accuracy 0.4180
Epoch 6 Batch 300 Loss 3.3262 Accuracy 0.4202
Epoch 6 Batch 350 Loss 3.3167 Accuracy 0.4217
Epoch 6 Batch 400 Loss 3.3056 Accuracy 0.4234
Epoch 6 Batch 450 Loss 3.2983 Accuracy 0.4246
Epoch 6 Batch 500 Loss 3.2933 Accuracy 0.4252
Epoch 6 Batch 550 Loss 3.2836 Accuracy 0.4262
Epoch 6 Batch 600 Loss 3.2757 Accuracy 0.4276
Epoch 6 Batch 650 Loss 3.2657 Accuracy 0.4290
Epoch 6 Loss 3.2587 Accuracy 0.4302
Time taken for 1 epoch: 38.71 secs

Epoch 7 Batch 0 Loss 3.1061 Accuracy 0.4460
Epoch 7 Batch 50 Loss 3.0378 Accuracy 0.4538
Epoch 7 Batch 100 Loss 3.0175 Accuracy 0.4575
Epoch 7 Batch 150 Loss 2.9996 Accuracy 0.4596
Epoch 7 Batch 200 Loss 2.9934 Accuracy 0.4609
Epoch 7 Batch 250 Loss 2.9823 Accuracy 0.4625
Epoch 7 Batch 300 Loss 2.9793 Accuracy 0.4637
Epoch 7 Batch 350 Loss 2.9726 Accuracy 0.4643
Epoch 7 Batch 400 Loss 2.9654 Accuracy 0.4654
Epoch 7 Batch 450 Loss 2.9562 Accuracy 0.4668
Epoch 7 Batch 500 Loss 2.9461 Accuracy 0.4683
Epoch 7 Batch 550 Loss 2.9369 Accuracy 0.4696
Epoch 7 Batch 600 Loss 2.9327 Accuracy 0.4703
Epoch 7 Batch 650 Loss 2.9274 Accuracy 0.4711
Epoch 7 Loss 2.9200 Accuracy 0.4722
Time taken for 1 epoch: 38.81 secs

Epoch 8 Batch 0 Loss 2.5240 Accuracy 0.5190
Epoch 8 Batch 50 Loss 2.6982 Accuracy 0.4984
Epoch 8 Batch 100 Loss 2.6850 Accuracy 0.5006
Epoch 8 Batch 150 Loss 2.6801 Accuracy 0.5018
Epoch 8 Batch 200 Loss 2.6808 Accuracy 0.5025
Epoch 8 Batch 250 Loss 2.6759 Accuracy 0.5035
Epoch 8 Batch 300 Loss 2.6734 Accuracy 0.5039
Epoch 8 Batch 350 Loss 2.6681 Accuracy 0.5046
Epoch 8 Batch 400 Loss 2.6662 Accuracy 0.5048
Epoch 8 Batch 450 Loss 2.6609 Accuracy 0.5054
Epoch 8 Batch 500 Loss 2.6533 Accuracy 0.5064
Epoch 8 Batch 550 Loss 2.6456 Accuracy 0.5077
Epoch 8 Batch 600 Loss 2.6430 Accuracy 0.5080
Epoch 8 Batch 650 Loss 2.6394 Accuracy 0.5086
Epoch 8 Batch 700 Loss 2.6356 Accuracy 0.5094
Epoch 8 Loss 2.6349 Accuracy 0.5095
Time taken for 1 epoch: 39.05 secs

Epoch 9 Batch 0 Loss 2.4558 Accuracy 0.5307
Epoch 9 Batch 50 Loss 2.4526 Accuracy 0.5332
Epoch 9 Batch 100 Loss 2.4505 Accuracy 0.5325
Epoch 9 Batch 150 Loss 2.4423 Accuracy 0.5347
Epoch 9 Batch 200 Loss 2.4388 Accuracy 0.5351
Epoch 9 Batch 250 Loss 2.4391 Accuracy 0.5347
Epoch 9 Batch 300 Loss 2.4361 Accuracy 0.5354
Epoch 9 Batch 350 Loss 2.4345 Accuracy 0.5357
Epoch 9 Batch 400 Loss 2.4351 Accuracy 0.5358
Epoch 9 Batch 450 Loss 2.4312 Accuracy 0.5365
Epoch 9 Batch 500 Loss 2.4287 Accuracy 0.5367
Epoch 9 Batch 550 Loss 2.4247 Accuracy 0.5374
Epoch 9 Batch 600 Loss 2.4256 Accuracy 0.5372
Epoch 9 Batch 650 Loss 2.4243 Accuracy 0.5375
Epoch 9 Loss 2.4234 Accuracy 0.5378
Time taken for 1 epoch: 38.24 secs

Epoch 10 Batch 0 Loss 2.3818 Accuracy 0.5334
Epoch 10 Batch 50 Loss 2.2105 Accuracy 0.5651
Epoch 10 Batch 100 Loss 2.2428 Accuracy 0.5611
Epoch 10 Batch 150 Loss 2.2566 Accuracy 0.5597
Epoch 10 Batch 200 Loss 2.2575 Accuracy 0.5597
Epoch 10 Batch 250 Loss 2.2565 Accuracy 0.5597
Epoch 10 Batch 300 Loss 2.2574 Accuracy 0.5597
Epoch 10 Batch 350 Loss 2.2590 Accuracy 0.5596
Epoch 10 Batch 400 Loss 2.2587 Accuracy 0.5599
Epoch 10 Batch 450 Loss 2.2545 Accuracy 0.5607
Epoch 10 Batch 500 Loss 2.2550 Accuracy 0.5609
Epoch 10 Batch 550 Loss 2.2548 Accuracy 0.5609
Epoch 10 Batch 600 Loss 2.2558 Accuracy 0.5609
Epoch 10 Batch 650 Loss 2.2562 Accuracy 0.5609
Epoch 10 Batch 700 Loss 2.2566 Accuracy 0.5609
Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-2
Epoch 10 Loss 2.2561 Accuracy 0.5610
Time taken for 1 epoch: 38.66 secs

Epoch 11 Batch 0 Loss 2.0627 Accuracy 0.5842
Epoch 11 Batch 50 Loss 2.1071 Accuracy 0.5802
Epoch 11 Batch 100 Loss 2.0934 Accuracy 0.5819
Epoch 11 Batch 150 Loss 2.1149 Accuracy 0.5791
Epoch 11 Batch 200 Loss 2.1186 Accuracy 0.5793
Epoch 11 Batch 250 Loss 2.1142 Accuracy 0.5802
Epoch 11 Batch 300 Loss 2.1156 Accuracy 0.5800
Epoch 11 Batch 350 Loss 2.1223 Accuracy 0.5791
Epoch 11 Batch 400 Loss 2.1228 Accuracy 0.5791
Epoch 11 Batch 450 Loss 2.1228 Accuracy 0.5791
Epoch 11 Batch 500 Loss 2.1211 Accuracy 0.5795
Epoch 11 Batch 550 Loss 2.1180 Accuracy 0.5802
Epoch 11 Batch 600 Loss 2.1192 Accuracy 0.5802
Epoch 11 Batch 650 Loss 2.1197 Accuracy 0.5801
Epoch 11 Loss 2.1213 Accuracy 0.5800
Time taken for 1 epoch: 38.49 secs

Epoch 12 Batch 0 Loss 2.1750 Accuracy 0.5651
Epoch 12 Batch 50 Loss 2.0071 Accuracy 0.5937
Epoch 12 Batch 100 Loss 2.0088 Accuracy 0.5945
Epoch 12 Batch 150 Loss 2.0052 Accuracy 0.5952
Epoch 12 Batch 200 Loss 2.0037 Accuracy 0.5959
Epoch 12 Batch 250 Loss 2.0079 Accuracy 0.5950
Epoch 12 Batch 300 Loss 2.0078 Accuracy 0.5952
Epoch 12 Batch 350 Loss 2.0104 Accuracy 0.5947
Epoch 12 Batch 400 Loss 2.0109 Accuracy 0.5950
Epoch 12 Batch 450 Loss 2.0131 Accuracy 0.5948
Epoch 12 Batch 500 Loss 2.0128 Accuracy 0.5950
Epoch 12 Batch 550 Loss 2.0125 Accuracy 0.5951
Epoch 12 Batch 600 Loss 2.0141 Accuracy 0.5950
Epoch 12 Batch 650 Loss 2.0147 Accuracy 0.5951
Epoch 12 Loss 2.0151 Accuracy 0.5951
Time taken for 1 epoch: 38.48 secs

Epoch 13 Batch 0 Loss 1.8535 Accuracy 0.6231
Epoch 13 Batch 50 Loss 1.8855 Accuracy 0.6131
Epoch 13 Batch 100 Loss 1.9094 Accuracy 0.6097
Epoch 13 Batch 150 Loss 1.9151 Accuracy 0.6088
Epoch 13 Batch 200 Loss 1.9140 Accuracy 0.6084
Epoch 13 Batch 250 Loss 1.9135 Accuracy 0.6086
Epoch 13 Batch 300 Loss 1.9110 Accuracy 0.6090
Epoch 13 Batch 350 Loss 1.9140 Accuracy 0.6087
Epoch 13 Batch 400 Loss 1.9127 Accuracy 0.6091
Epoch 13 Batch 450 Loss 1.9143 Accuracy 0.6089
Epoch 13 Batch 500 Loss 1.9151 Accuracy 0.6089
Epoch 13 Batch 550 Loss 1.9166 Accuracy 0.6089
Epoch 13 Batch 600 Loss 1.9195 Accuracy 0.6088
Epoch 13 Batch 650 Loss 1.9208 Accuracy 0.6087
Epoch 13 Loss 1.9229 Accuracy 0.6084
Time taken for 1 epoch: 38.02 secs

Epoch 14 Batch 0 Loss 1.7991 Accuracy 0.6328
Epoch 14 Batch 50 Loss 1.8104 Accuracy 0.6226
Epoch 14 Batch 100 Loss 1.8140 Accuracy 0.6226
Epoch 14 Batch 150 Loss 1.8177 Accuracy 0.6217
Epoch 14 Batch 200 Loss 1.8203 Accuracy 0.6218
Epoch 14 Batch 250 Loss 1.8282 Accuracy 0.6208
Epoch 14 Batch 300 Loss 1.8306 Accuracy 0.6207
Epoch 14 Batch 350 Loss 1.8306 Accuracy 0.6211
Epoch 14 Batch 400 Loss 1.8294 Accuracy 0.6214
Epoch 14 Batch 450 Loss 1.8327 Accuracy 0.6210
Epoch 14 Batch 500 Loss 1.8367 Accuracy 0.6204
Epoch 14 Batch 550 Loss 1.8383 Accuracy 0.6201
Epoch 14 Batch 600 Loss 1.8396 Accuracy 0.6200
Epoch 14 Batch 650 Loss 1.8410 Accuracy 0.6199
Epoch 14 Batch 700 Loss 1.8458 Accuracy 0.6193
Epoch 14 Loss 1.8461 Accuracy 0.6192
Time taken for 1 epoch: 38.29 secs

Epoch 15 Batch 0 Loss 1.4491 Accuracy 0.6872
Epoch 15 Batch 50 Loss 1.7225 Accuracy 0.6397
Epoch 15 Batch 100 Loss 1.7232 Accuracy 0.6393
Epoch 15 Batch 150 Loss 1.7282 Accuracy 0.6374
Epoch 15 Batch 200 Loss 1.7416 Accuracy 0.6355
Epoch 15 Batch 250 Loss 1.7470 Accuracy 0.6346
Epoch 15 Batch 300 Loss 1.7537 Accuracy 0.6335
Epoch 15 Batch 350 Loss 1.7612 Accuracy 0.6320
Epoch 15 Batch 400 Loss 1.7613 Accuracy 0.6319
Epoch 15 Batch 450 Loss 1.7618 Accuracy 0.6318
Epoch 15 Batch 500 Loss 1.7660 Accuracy 0.6313
Epoch 15 Batch 550 Loss 1.7709 Accuracy 0.6308
Epoch 15 Batch 600 Loss 1.7735 Accuracy 0.6304
Epoch 15 Batch 650 Loss 1.7753 Accuracy 0.6302
Epoch 15 Batch 700 Loss 1.7784 Accuracy 0.6298
Saving checkpoint for epoch 15 at ./checkpoints/train/ckpt-3
Epoch 15 Loss 1.7783 Accuracy 0.6299
Time taken for 1 epoch: 38.42 secs

Epoch 16 Batch 0 Loss 1.6642 Accuracy 0.6440
Epoch 16 Batch 50 Loss 1.6642 Accuracy 0.6453
Epoch 16 Batch 100 Loss 1.6646 Accuracy 0.6460
Epoch 16 Batch 150 Loss 1.6753 Accuracy 0.6445
Epoch 16 Batch 200 Loss 1.6863 Accuracy 0.6434
Epoch 16 Batch 250 Loss 1.6979 Accuracy 0.6420
Epoch 16 Batch 300 Loss 1.7000 Accuracy 0.6416
Epoch 16 Batch 350 Loss 1.7015 Accuracy 0.6412
Epoch 16 Batch 400 Loss 1.7025 Accuracy 0.6411
Epoch 16 Batch 450 Loss 1.7016 Accuracy 0.6412
Epoch 16 Batch 500 Loss 1.7048 Accuracy 0.6407
Epoch 16 Batch 550 Loss 1.7072 Accuracy 0.6402
Epoch 16 Batch 600 Loss 1.7111 Accuracy 0.6396
Epoch 16 Batch 650 Loss 1.7135 Accuracy 0.6393
Epoch 16 Batch 700 Loss 1.7172 Accuracy 0.6387
Epoch 16 Loss 1.7178 Accuracy 0.6386
Time taken for 1 epoch: 38.71 secs

Epoch 17 Batch 0 Loss 1.6321 Accuracy 0.6582
Epoch 17 Batch 50 Loss 1.6246 Accuracy 0.6498
Epoch 17 Batch 100 Loss 1.6288 Accuracy 0.6508
Epoch 17 Batch 150 Loss 1.6266 Accuracy 0.6521
Epoch 17 Batch 200 Loss 1.6304 Accuracy 0.6516
Epoch 17 Batch 250 Loss 1.6324 Accuracy 0.6516
Epoch 17 Batch 300 Loss 1.6368 Accuracy 0.6505
Epoch 17 Batch 350 Loss 1.6414 Accuracy 0.6498
Epoch 17 Batch 400 Loss 1.6423 Accuracy 0.6498
Epoch 17 Batch 450 Loss 1.6446 Accuracy 0.6494
Epoch 17 Batch 500 Loss 1.6467 Accuracy 0.6490
Epoch 17 Batch 550 Loss 1.6521 Accuracy 0.6481
Epoch 17 Batch 600 Loss 1.6561 Accuracy 0.6477
Epoch 17 Batch 650 Loss 1.6605 Accuracy 0.6472
Epoch 17 Batch 700 Loss 1.6640 Accuracy 0.6467
Epoch 17 Loss 1.6651 Accuracy 0.6464
Time taken for 1 epoch: 38.36 secs

Epoch 18 Batch 0 Loss 1.7109 Accuracy 0.6334
Epoch 18 Batch 50 Loss 1.5447 Accuracy 0.6661
Epoch 18 Batch 100 Loss 1.5643 Accuracy 0.6621
Epoch 18 Batch 150 Loss 1.5763 Accuracy 0.6604
Epoch 18 Batch 200 Loss 1.5792 Accuracy 0.6600
Epoch 18 Batch 250 Loss 1.5860 Accuracy 0.6589
Epoch 18 Batch 300 Loss 1.5886 Accuracy 0.6585
Epoch 18 Batch 350 Loss 1.5887 Accuracy 0.6586
Epoch 18 Batch 400 Loss 1.5916 Accuracy 0.6578
Epoch 18 Batch 450 Loss 1.5932 Accuracy 0.6578
Epoch 18 Batch 500 Loss 1.5976 Accuracy 0.6570
Epoch 18 Batch 550 Loss 1.6019 Accuracy 0.6562
Epoch 18 Batch 600 Loss 1.6049 Accuracy 0.6559
Epoch 18 Batch 650 Loss 1.6076 Accuracy 0.6556
Epoch 18 Loss 1.6116 Accuracy 0.6551
Time taken for 1 epoch: 38.72 secs

Epoch 19 Batch 0 Loss 1.3450 Accuracy 0.7062
Epoch 19 Batch 50 Loss 1.5109 Accuracy 0.6695
Epoch 19 Batch 100 Loss 1.5241 Accuracy 0.6668
Epoch 19 Batch 150 Loss 1.5378 Accuracy 0.6658
Epoch 19 Batch 200 Loss 1.5423 Accuracy 0.6653
Epoch 19 Batch 250 Loss 1.5515 Accuracy 0.6639
Epoch 19 Batch 300 Loss 1.5542 Accuracy 0.6637
Epoch 19 Batch 350 Loss 1.5553 Accuracy 0.6637
Epoch 19 Batch 400 Loss 1.5595 Accuracy 0.6632
Epoch 19 Batch 450 Loss 1.5605 Accuracy 0.6627
Epoch 19 Batch 500 Loss 1.5626 Accuracy 0.6624
Epoch 19 Batch 550 Loss 1.5645 Accuracy 0.6622
Epoch 19 Batch 600 Loss 1.5692 Accuracy 0.6614
Epoch 19 Batch 650 Loss 1.5715 Accuracy 0.6612
Epoch 19 Batch 700 Loss 1.5733 Accuracy 0.6609
Epoch 19 Loss 1.5733 Accuracy 0.6609
Time taken for 1 epoch: 38.43 secs

Epoch 20 Batch 0 Loss 1.5804 Accuracy 0.6560
Epoch 20 Batch 50 Loss 1.4986 Accuracy 0.6732
Epoch 20 Batch 100 Loss 1.4969 Accuracy 0.6729
Epoch 20 Batch 150 Loss 1.5063 Accuracy 0.6710
Epoch 20 Batch 200 Loss 1.5107 Accuracy 0.6704
Epoch 20 Batch 250 Loss 1.5119 Accuracy 0.6701
Epoch 20 Batch 300 Loss 1.5132 Accuracy 0.6700
Epoch 20 Batch 350 Loss 1.5146 Accuracy 0.6699
Epoch 20 Batch 400 Loss 1.5157 Accuracy 0.6696
Epoch 20 Batch 450 Loss 1.5170 Accuracy 0.6694
Epoch 20 Batch 500 Loss 1.5196 Accuracy 0.6690
Epoch 20 Batch 550 Loss 1.5228 Accuracy 0.6684
Epoch 20 Batch 600 Loss 1.5269 Accuracy 0.6679
Epoch 20 Batch 650 Loss 1.5290 Accuracy 0.6674
Epoch 20 Batch 700 Loss 1.5324 Accuracy 0.6670
Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-4
Epoch 20 Loss 1.5323 Accuracy 0.6671
Time taken for 1 epoch: 38.24 secs

Run inference

The following steps are used for inference:

  • Encode the input sentence using the Portuguese tokenizer (tokenizers.pt). This is the encoder input.
  • Initialize the decoder input with the [START] token.
  • Calculate the padding masks and the look-ahead masks (the Transformer does this internally in create_masks).
  • The decoder then outputs the predictions by looking at the encoder output and its own output (self-attention).
  • Concatenate the predicted token to the decoder input and pass it to the decoder again.
  • Repeat until the model predicts the [END] token or max_length tokens have been generated. In this approach, the decoder predicts the next token based on the tokens it has already predicted.

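The steps above amount to greedy decoding. The plain-Python sketch below shows only the control flow: toy_next_token is a hypothetical stand-in for the transformer (it echoes the source tokens and then predicts END), and the IDs chosen for START and END are arbitrary. The Translator class that follows implements the same loop in TensorFlow, using a tf.TensorArray so that the loop can be traced by tf.function.

```python
START, END = 1, 2  # arbitrary illustrative token IDs

def toy_next_token(encoder_input, decoder_input):
    # Hypothetical model: copy the source token at this position, then emit END
    position = len(decoder_input) - 1
    return encoder_input[position] if position < len(encoder_input) else END

def greedy_decode(encoder_input, next_token_fn, max_length=10):
    output = [START]                 # the decoder input starts with [START]
    for _ in range(max_length):
        predicted = next_token_fn(encoder_input, output)
        output.append(predicted)     # feed the prediction back in as decoder input
        if predicted == END:         # stop once [END] is produced
            break
    return output

print(greedy_decode([7, 8, 9], toy_next_token))
print(greedy_decode([7, 8, 9], toy_next_token, max_length=2))
```

The second call shows the max_length cutoff: the loop stops after two predictions even though END was never produced.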
class Translator(tf.Module):
  def __init__(self, tokenizers, transformer):
    self.tokenizers = tokenizers
    self.transformer = transformer

  def __call__(self, sentence, max_length=MAX_TOKENS):
    # The input sentence is Portuguese; the tokenizer adds the [START] and [END] tokens.
    assert isinstance(sentence, tf.Tensor)
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]

    sentence = self.tokenizers.pt.tokenize(sentence).to_tensor()

    encoder_input = sentence

    # As the output language is english, initialize the output with the
    # english start token.
    start_end = self.tokenizers.en.tokenize([''])[0]
    start = start_end[0][tf.newaxis]
    end = start_end[1][tf.newaxis]

    # `tf.TensorArray` is required here (instead of a python list) so that the
    # dynamic-loop can be traced by `tf.function`.
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)

    for i in tf.range(max_length):
      output = tf.transpose(output_array.stack())
      predictions, _ = self.transformer([encoder_input, output], training=False)

      # select the last token from the seq_len dimension
      predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)

      predicted_id = tf.argmax(predictions, axis=-1)

      # Concatenate the predicted_id to the output, which is given to the
      # decoder as its input.
      output_array = output_array.write(i+1, predicted_id[0])

      if predicted_id == end:
        break

    output = tf.transpose(output_array.stack())
    # output.shape (1, tokens)
    text = self.tokenizers.en.detokenize(output)[0]  # shape: ()

    tokens = self.tokenizers.en.lookup(output)[0]

    # `tf.function` prevents us from using the attention_weights that were
    # calculated on the last iteration of the loop. So recalculate them outside
    # the loop.
    _, attention_weights = self.transformer([encoder_input, output[:,:-1]], training=False)

    return text, tokens, attention_weights

Create an instance of this Translator class, and try it out a few times:

translator = Translator(tokenizers, transformer)
def print_translation(sentence, translation, ground_truth):
  print(f'{"Input:":15s}: {sentence}')
  print(f'{"Prediction":15s}: {translation.numpy().decode("utf-8")}')
  print(f'{"Ground truth":15s}: {ground_truth}')
sentence = 'este é um problema que temos que resolver.'
ground_truth = 'this is a problem we have to solve .'

translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : este é um problema que temos que resolver.
Prediction     : this is a problem that we have to solve .
Ground truth   : this is a problem we have to solve .
sentence = 'os meus vizinhos ouviram sobre esta ideia.'
ground_truth = 'and my neighboring homes heard about this idea .'

translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : os meus vizinhos ouviram sobre esta ideia.
Prediction     : my neighbors heard about this idea .
Ground truth   : and my neighboring homes heard about this idea .
sentence = 'vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.'
ground_truth = "so i'll just share with you some stories very quickly of some magical things that have happened."

translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Prediction     : so i ' ll really share with you with some stories of some magical things that happened .
Ground truth   : so i'll just share with you some stories very quickly of some magical things that have happened.

Attention plots

The Translator class also returns a dictionary of attention maps, which you can use to visualize the internal workings of the model:

sentence = 'este é o primeiro livro que eu fiz.'
ground_truth = "this is the first book i've ever done."

translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)
Input:         : este é o primeiro livro que eu fiz.
Prediction     : this is the first book i did .
Ground truth   : this is the first book i've ever done.
def plot_attention_head(in_tokens, translated_tokens, attention):
  # The plot shows the attention used when each token was generated.
  # The model didn't generate `[START]` in the output, so skip it.
  translated_tokens = translated_tokens[1:]

  ax = plt.gca()
  ax.matshow(attention)
  ax.set_xticks(range(len(in_tokens)))
  ax.set_yticks(range(len(translated_tokens)))

  labels = [label.decode('utf-8') for label in in_tokens.numpy()]
  ax.set_xticklabels(
      labels, rotation=90)

  labels = [label.decode('utf-8') for label in translated_tokens.numpy()]
  ax.set_yticklabels(labels)
head = 0
# shape: (batch=1, num_heads, seq_len_q, seq_len_k)
attention_heads = tf.squeeze(
  attention_weights['decoder_layer4_block2'], 0)
attention = attention_heads[head]
attention.shape
TensorShape([9, 11])
in_tokens = tf.convert_to_tensor([sentence])
in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
in_tokens = tokenizers.pt.lookup(in_tokens)[0]
in_tokens
<tf.Tensor: shape=(11,), dtype=string, numpy=
array([b'[START]', b'este', b'e', b'o', b'primeiro', b'livro', b'que',
       b'eu', b'fiz', b'.', b'[END]'], dtype=object)>
translated_tokens
<tf.Tensor: shape=(10,), dtype=string, numpy=
array([b'[START]', b'this', b'is', b'the', b'first', b'book', b'i',
       b'did', b'.', b'[END]'], dtype=object)>
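The shapes line up as follows. The attention weights were recomputed with output[:, :-1], so there is one query row per translated token except the final [END], and one column per input token. Because row i holds the attention used while predicting the next token, plot_attention_head labels the rows with translated_tokens[1:]. A plain-Python sketch of that bookkeeping, with the token lists copied from the outputs above:

```python
in_tokens = ['[START]', 'este', 'e', 'o', 'primeiro', 'livro', 'que', 'eu', 'fiz', '.', '[END]']
translated = ['[START]', 'this', 'is', 'the', 'first', 'book', 'i', 'did', '.', '[END]']

# The decoder input was output[:, :-1]: every translated token except the final [END]
queries = translated[:-1]
# Row i is the attention used while predicting the *next* token
row_labels = translated[1:]

print(len(queries), len(in_tokens))  # rows x columns of one attention head
for q, label in zip(queries, row_labels):
    print(f"query {q!r:>11} predicts {label!r}")
```

This accounts for attention.shape being (9, 11): 9 decoder positions attending over 11 input tokens.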
plot_attention_head(in_tokens, translated_tokens, attention)


def plot_attention_weights(sentence, translated_tokens, attention_heads):
  in_tokens = tf.convert_to_tensor([sentence])
  in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
  in_tokens = tokenizers.pt.lookup(in_tokens)[0]

  fig = plt.figure(figsize=(16, 8))

  for h, head in enumerate(attention_heads):
    ax = fig.add_subplot(2, 4, h+1)

    plot_attention_head(in_tokens, translated_tokens, head)

    ax.set_xlabel(f'Head {h+1}')

  plt.tight_layout()
  plt.show()
plot_attention_weights(sentence, translated_tokens,
                       attention_weights['decoder_layer4_block2'][0])


The model handles unfamiliar words reasonably. Neither "triceratops" nor "encyclopedia" is in the input dataset, yet the model attempts to transliterate them, even without a shared vocabulary:

sentence = 'Eu li sobre triceratops na enciclopédia.'
ground_truth = 'I read about triceratops in the encyclopedia.'

translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))
print_translation(sentence, translated_text, ground_truth)

plot_attention_weights(sentence, translated_tokens,
                       attention_weights['decoder_layer4_block2'][0])
Input:         : Eu li sobre triceratops na enciclopédia.
Prediction     : i read about triopatrinas in esclodies .
Ground truth   : I read about triceratops in the encyclopedia.


Export

That inference model is working, so next you'll export it as a SavedModel with tf.saved_model.

To do that, wrap it in yet another tf.Module subclass, this time with a tf.function on the __call__ method:

class ExportTranslator(tf.Module):
  def __init__(self, translator):
    self.translator = translator

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def __call__(self, sentence):
    (result,
     tokens,
     attention_weights) = self.translator(sentence, max_length=MAX_TOKENS)

    return result

The tf.function above returns only the output sentence. Because execution inside a tf.function is non-strict, values that the output does not depend on (here, the tokens and attention weights) are never computed.

translator = ExportTranslator(translator)

Since the model decodes the predictions using tf.argmax, the predictions are deterministic. The original model and one reloaded from its SavedModel should therefore give identical predictions:

translator('este é o primeiro livro que eu fiz.').numpy()
b'this is the first book i did .'
tf.saved_model.save(translator, export_dir='translator')
2022-05-04 11:21:55.800781: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as embedding_4_layer_call_fn, embedding_4_layer_call_and_return_conditional_losses, dropout_37_layer_call_fn, dropout_37_layer_call_and_return_conditional_losses, embedding_5_layer_call_fn while saving (showing 5 of 224). These functions will not be directly callable after loading.
reloaded = tf.saved_model.load('translator')
reloaded('este é o primeiro livro que eu fiz.').numpy()
b'this is the first book i did .'

Summary

In this tutorial you learned about:

  • positional encoding
  • multi-head attention
  • the importance of masking
  • how to put it all together to build a transformer.

This implementation tried to stay close to the implementation in the original paper. If you want to practice, there are many things you could try with it. For example:

  • Use a different dataset to train the transformer.
  • Create the "base" or "big" Transformer configurations from the original paper by changing the hyperparameters.
  • Use the layers defined here to create an implementation of BERT.
  • Implement beam search to get better predictions.
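Beam search keeps the beam_width most probable partial translations at each step instead of only the single argmax. The plain-Python sketch below uses a hypothetical toy_log_probs table in place of the transformer; with the Translator you would instead take a log-softmax of the logits at the last position. It omits batching and the length normalization that is often applied in practice.

```python
import math

# Hypothetical toy model: next-token probabilities for a few prefixes.
# Any prefix not listed is assumed to end with the token 'E' (standing in for [END]).
TOY_MODEL = {
    ('S',): {'a': 0.6, 'b': 0.4},
    ('S', 'a'): {'c': 0.5, 'd': 0.5},
    ('S', 'b'): {'e': 0.9, 'f': 0.1},
}

def toy_log_probs(seq):
    probs = TOY_MODEL.get(tuple(seq), {'E': 1.0})
    return {token: math.log(p) for token, p in probs.items()}

def beam_search(next_log_probs, start, end, beam_width=2, max_length=5):
    beams = [([start], 0.0)]  # (sequence, cumulative log-probability)
    finished = []
    for _ in range(max_length):
        candidates = [
            (seq + [token], score + logp)
            for seq, score in beams
            for token, logp in next_log_probs(seq).items()
        ]
        # Keep only the beam_width highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            (finished if seq[-1] == end else beams).append((seq, score))
        if not beams:  # every surviving hypothesis has ended
            break
    finished.extend(beams)  # include unfinished hypotheses if max_length was hit
    return max(finished, key=lambda c: c[1])

best_seq, best_score = beam_search(toy_log_probs, 'S', 'E', beam_width=2)
greedy_seq, greedy_score = beam_search(toy_log_probs, 'S', 'E', beam_width=1)
print(best_seq, round(math.exp(best_score), 2))
print(greedy_seq, round(math.exp(greedy_score), 2))
```

With these toy numbers, beam_width=1 behaves like greedy decoding: it commits to 'a' and ends with probability 0.3, while a beam of width 2 finds the sequence through 'b' with probability 0.36.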