Warm-start embedding layer matrix


This tutorial shows how to "warm-start" training using the tf.keras.utils.warmstart_embedding_matrix API for text sentiment classification when the vocabulary changes.

You will begin by training a simple Keras model with a base vocabulary, and then, after updating the vocabulary, continue training the model. This is referred to as "warm-start" training, for which you'll need to remap the text-embedding matrix for the new vocabulary.

Embedding matrix

Embeddings provide a way to use an efficient, dense representation in which similar vocabulary tokens have a similar encoding. They are trainable parameters (weights learned by the model during training, in the same way a model learns weights for a dense layer). It is common to use 8-dimensional embeddings for small datasets and up to 1024-dimensional embeddings for large datasets. A higher-dimensional embedding can capture fine-grained relationships between words, but takes more data to learn.
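
For instance, here is a minimal sketch (with toy numbers, not the configuration used later in this tutorial) of an embedding layer acting as a trainable lookup table:

import tensorflow as tf

# A toy embedding layer: 1,000 tokens, each represented by an 8-dimensional vector.
embedding_layer = tf.keras.layers.Embedding(input_dim=1000, output_dim=8)

# Looking up a batch of token IDs returns the corresponding rows.
result = embedding_layer(tf.constant([[3, 17, 42]]))
print(result.shape)  # (1, 3, 8)

# The whole matrix is a trainable weight, updated by backpropagation.
print(embedding_layer.embeddings.shape)  # (1000, 8)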

Vocabulary

The set of unique words in a dataset is referred to as the vocabulary. To build a text model you need to choose a fixed vocabulary, typically built from the most common words in the dataset. The vocabulary lets you represent each piece of text as a sequence of IDs that you can look up in the embedding matrix.
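
For example, here is a minimal sketch (using a toy vocabulary, not the one built later in this tutorial) of how a vocabulary maps tokens to IDs, and how those IDs index rows of an embedding matrix:

import tensorflow as tf

# Toy vocabulary; index 0 is reserved for out-of-vocabulary tokens by default.
toy_vocab = ["the", "movie", "was", "great"]
lookup = tf.keras.layers.StringLookup(vocabulary=toy_vocab)

ids = lookup(["the", "movie", "was", "great", "banana"])
print(ids)  # e.g. [1 2 3 4 0] -- "banana" maps to the out-of-vocabulary index

# Each ID selects one row of the embedding matrix.
embedding = tf.keras.layers.Embedding(input_dim=len(toy_vocab) + 1, output_dim=4)
print(embedding(ids).shape)  # (5, 4)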

Why warm-start an embedding matrix?

A model is trained with a set of embeddings that represents a given vocabulary. If the model needs to be updated or improved, you can train to convergence significantly faster by reusing weights from a previous run. Reusing the embedding matrix from a previous run is more difficult, because any change to the vocabulary invalidates the word-to-ID mapping.
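
To make the problem concrete, here is a toy illustration (not part of this tutorial's dataset) of how growing a vocabulary shifts the IDs of existing words:

base_vocab = ["the", "movie", "great"]
new_vocab = ["the", "film", "movie", "great"]  # "film" was added

base_ids = {word: i for i, word in enumerate(base_vocab)}
new_ids = {word: i for i, word in enumerate(new_vocab)}

# "movie" moves from row 1 to row 2, so its old embedding row no longer lines up.
print(base_ids["movie"], new_ids["movie"])  # 1 2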

The tf.keras.utils.warmstart_embedding_matrix API solves this problem by creating an embedding matrix for a new vocabulary from an embedding matrix trained on a base vocabulary. Where a word exists in both vocabularies, the base embedding vector is copied into the correct location in the new embedding matrix. This allows you to warm-start training after any change in the size or order of the vocabulary.
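
Continuing the toy example above, here is a minimal sketch of the remapping (it assumes TensorFlow 2.11+, where this utility is available; the matrix values are made up):

import numpy as np
import tensorflow as tf

base_vocab = ["the", "movie", "great"]
new_vocab = ["the", "film", "movie", "great"]

# A toy base embedding matrix: one 4-dimensional row per base-vocabulary word.
base_embeddings = np.arange(12, dtype="float32").reshape(3, 4)

remapped = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=base_vocab,
    new_vocabulary=new_vocab,
    base_embeddings=base_embeddings,
    new_embeddings_initializer="uniform",  # rows for new words start fresh
)
print(remapped.shape)  # (4, 4)

# The row for "movie" is copied from the base matrix into its new position.
print(np.allclose(remapped[2], base_embeddings[1]))  # True (expected)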

Setup

pip install --pre -U "tensorflow>2.10"  # Requires 2.11
import io
import numpy as np
import os
import re
import shutil
import string
import tensorflow as tf

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from tensorflow.keras.layers import TextVectorization

Load the dataset

The tutorial uses the Large Movie Review Dataset. You will train a sentiment classifier model on this dataset and in the process learn embeddings from scratch. Refer to the Loading text tutorial to learn more.

Download the dataset using Keras file utility and review the directories.

url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

dataset = tf.keras.utils.get_file(
    "aclImdb_v1.tar.gz", url, untar=True, cache_dir=".", cache_subdir=""
)

dataset_dir = os.path.join(os.path.dirname(dataset), "aclImdb")
os.listdir(dataset_dir)
Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84125825/84125825 [==============================] - 4s 0us/step
['train', 'README', 'imdb.vocab', 'test', 'imdbEr.txt']

The train/ directory has pos and neg folders with movie reviews labeled as positive and negative respectively. You will use reviews from pos and neg folders to train a binary classification model.

train_dir = os.path.join(dataset_dir, "train")
os.listdir(train_dir)
['urls_unsup.txt',
 'unsupBow.feat',
 'unsup',
 'pos',
 'labeledBow.feat',
 'neg',
 'urls_pos.txt',
 'urls_neg.txt']

The train directory also contains additional folders which should be removed before creating the training set.

remove_dir = os.path.join(train_dir, "unsup")
shutil.rmtree(remove_dir)

Next, create a tf.data.Dataset using tf.keras.utils.text_dataset_from_directory. You can read more about using this utility in this text classification tutorial.

Use the train directory to create the training and validation sets with a split of 20% for validation.

batch_size = 1024
seed = 123
train_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=seed,
)
val_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=seed,
)
Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.

Configure the dataset for performance

You can learn more about Dataset.cache and Dataset.prefetch, as well as how to cache data to disk in the data performance guide.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Text preprocessing

Next, define the dataset preprocessing steps required for your sentiment classification model. Initialize a layers.TextVectorization layer with the desired parameters to vectorize movie reviews. You can learn more about using this layer in the Text Classification tutorial.

# Create a custom standardization function to strip HTML break tags '<br />'.
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )


# Vocabulary size and number of words in a sequence.
vocab_size = 10000
sequence_length = 100

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Note that the layer uses the custom standardization defined above.
# Set the output sequence length, as the samples are not all the same length.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call `Dataset.adapt` to build the
# vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)

Create a classification model

Use the Keras Sequential API to define the sentiment classification model.

embedding_dim = 16
text_embedding = Embedding(vocab_size, embedding_dim, name="embedding")
text_input = tf.keras.Sequential(
    [vectorize_layer, text_embedding], name="text_input"
)
classifier_head = tf.keras.Sequential(
    [GlobalAveragePooling1D(), Dense(16, activation="relu"), Dense(1)],
    name="classifier_head",
)

model = tf.keras.Sequential([text_input, classifier_head])

Compile and train the model

You will use TensorBoard to visualize metrics including loss and accuracy. Create a tf.keras.callbacks.TensorBoard callback.

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

Compile and train the model using the Adam optimizer and BinaryCrossentropy loss.

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)
Epoch 1/15
20/20 [==============================] - 6s 199ms/step - loss: 0.6918 - accuracy: 0.5028 - val_loss: 0.6899 - val_accuracy: 0.4886
Epoch 2/15
20/20 [==============================] - 1s 52ms/step - loss: 0.6868 - accuracy: 0.5028 - val_loss: 0.6834 - val_accuracy: 0.4886
Epoch 3/15
20/20 [==============================] - 1s 54ms/step - loss: 0.6786 - accuracy: 0.5028 - val_loss: 0.6736 - val_accuracy: 0.4886
Epoch 4/15
20/20 [==============================] - 1s 53ms/step - loss: 0.6658 - accuracy: 0.5028 - val_loss: 0.6581 - val_accuracy: 0.4886
Epoch 5/15
20/20 [==============================] - 1s 53ms/step - loss: 0.6465 - accuracy: 0.5028 - val_loss: 0.6382 - val_accuracy: 0.4886
Epoch 6/15
20/20 [==============================] - 1s 52ms/step - loss: 0.6231 - accuracy: 0.5150 - val_loss: 0.6150 - val_accuracy: 0.5310
Epoch 7/15
20/20 [==============================] - 1s 52ms/step - loss: 0.5957 - accuracy: 0.5874 - val_loss: 0.5888 - val_accuracy: 0.5966
Epoch 8/15
20/20 [==============================] - 1s 55ms/step - loss: 0.5654 - accuracy: 0.6575 - val_loss: 0.5616 - val_accuracy: 0.6518
Epoch 9/15
20/20 [==============================] - 1s 52ms/step - loss: 0.5340 - accuracy: 0.7124 - val_loss: 0.5347 - val_accuracy: 0.6902
Epoch 10/15
20/20 [==============================] - 1s 53ms/step - loss: 0.5031 - accuracy: 0.7491 - val_loss: 0.5096 - val_accuracy: 0.7236
Epoch 11/15
20/20 [==============================] - 1s 54ms/step - loss: 0.4740 - accuracy: 0.7740 - val_loss: 0.4872 - val_accuracy: 0.7432
Epoch 12/15
20/20 [==============================] - 1s 52ms/step - loss: 0.4474 - accuracy: 0.7945 - val_loss: 0.4677 - val_accuracy: 0.7572
Epoch 13/15
20/20 [==============================] - 1s 52ms/step - loss: 0.4234 - accuracy: 0.8110 - val_loss: 0.4510 - val_accuracy: 0.7664
Epoch 14/15
20/20 [==============================] - 1s 52ms/step - loss: 0.4020 - accuracy: 0.8227 - val_loss: 0.4370 - val_accuracy: 0.7748
Epoch 15/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3829 - accuracy: 0.8339 - val_loss: 0.4251 - val_accuracy: 0.7858
<keras.src.callbacks.History at 0x7f4c8d82ebe0>

With this approach the model reaches a validation accuracy of around 79% after 15 epochs (your results may vary slightly depending on how the weights were randomly initialized).

You can look into the model summary to learn more about each layer of the model.

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_input (Sequential)     (None, 100, 16)           160000    
                                                                 
 classifier_head (Sequentia  (None, 1)                 289       
 l)                                                              
                                                                 
=================================================================
Total params: 160289 (626.13 KB)
Trainable params: 160289 (626.13 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Visualize the model metrics in TensorBoard.

# docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir logs

Vocabulary remapping

Now you're going to update the vocabulary and continue with warm-started training.

First, get the base vocabulary and embedding matrix.

embedding_weights_base = (
    model.get_layer("text_input").get_layer("embedding").embeddings
)
vocab_base = vectorize_layer.get_vocabulary()

Define a new vectorization layer to generate a new, larger vocabulary.

# Vocabulary size and number of words in a sequence.
vocab_size_new = 10200
sequence_length = 100

vectorize_layer_new = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size_new,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call adapt to build the vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer_new.adapt(text_ds)

# Get the new vocabulary
vocab_new = vectorize_layer_new.get_vocabulary()
# View the tokens that differ between the two vocabularies (the newly added words)
set(vocab_base) ^ set(vocab_new)
{'bullying',
 'bumps',
 'canvas',
 'carole',
 'chains',
 'chairman',
 'checks',
 'coarse',
 'competitive',
 'component',
 'compound',
 'confirm',
 'contemplate',
 'coping',
 'corporations',
 'costuming',
 'counterpart',
 'crop',
 'custody',
 'cyborgs',
 'daft',
 'danced',
 'daphne',
 'darkest',
 'davids',
 'december',
 'declared',
 'defence',
 'delve',
 'demonstration',
 'dense',
 'denver',
 'devilish',
 'devious',
 'dickinson',
 'digs',
 'directorwriter',
 'download',
 'effortless',
 'electricity',
 'elliot',
 'enlightenment',
 'erratic',
 'exceedingly',
 'eyeballs',
 'fearless',
 'fenton',
 'fiennes',
 'filter',
 'fireworks',
 'flipping',
 'float',
 'foggy',
 'forgivable',
 'framework',
 'fulllength',
 'funds',
 'gamut',
 'geeks',
 'glee',
 'goo',
 'gripe',
 'hardest',
 'harmony',
 'henchman',
 'heritage',
 'hg',
 'hi',
 'hightech',
 'homework',
 'houston',
 'howards',
 'hunger',
 'imho',
 'immigrants',
 'improvised',
 'impulse',
 'inch',
 'interpret',
 'intimidating',
 'iowa',
 'jaffar',
 'jeep',
 'jock',
 'kriemhild',
 'kristofferson',
 'lassie',
 'laughoutloud',
 'lennon',
 'librarian',
 'liza',
 'locker',
 'lommel',
 'loren',
 'lowered',
 'marital',
 'martins',
 'mastroianni',
 'megan',
 'melt',
 'mischievous',
 'monstrosity',
 'monumental',
 'morse',
 'mostel',
 'muddy',
 'noah',
 'noirs',
 'nostril',
 'numbing',
 'occupation',
 'oceans',
 'onesided',
 'opus',
 'organ',
 'osullivan',
 'otoole',
 'overnight',
 'parisian',
 'partial',
 'patriotism',
 'pbs',
 'penchant',
 'penguin',
 'plotted',
 'powerfully',
 'pows',
 'practicing',
 'prehistoric',
 'prestigious',
 'prevalent',
 'prevents',
 'profits',
 'promotion',
 'puke',
 'pulse',
 'punchline',
 'quarters',
 'rainer',
 'ranting',
 'rapists',
 'rapture',
 'rarity',
 'rays',
 'recommending',
 'redeemed',
 'refuge',
 'refugee',
 'relates',
 'religions',
 'remaking',
 'renee',
 'reply',
 'restoration',
 'resurrection',
 'retreat',
 'retro',
 'rockets',
 'romano',
 'rooker',
 'rooted',
 'runtime',
 'sap',
 'scarred',
 'secluded',
 'selfabsorbed',
 'separation',
 'shattered',
 'shenanigans',
 'shootings',
 'shue',
 'silk',
 'sm',
 'soooo',
 'spoton',
 'sr',
 'staple',
 'stepfather',
 'stoic',
 'stud',
 'suite',
 'swanson',
 'sweetness',
 'sybil',
 'tease',
 'technological',
 'tensions',
 'theft',
 'therapist',
 'threats',
 'tin',
 'towel',
 'transform',
 'travelling',
 'troupe',
 'unremarkable',
 'unsatisfied',
 'untrue',
 'vertigo',
 'vic'}

Generate the updated embeddings using the tf.keras.utils.warmstart_embedding_matrix utility.

# Generate the updated embedding matrix
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=vocab_base,
    new_vocabulary=vocab_new,
    base_embeddings=embedding_weights_base,
    new_embeddings_initializer="uniform",
)
# Update the model variable
updated_embedding_variable = tf.Variable(updated_embedding)

Alternatively:

If you have an embedding matrix that you would like to use to initialize the new embedding matrix, pass tf.keras.initializers.Constant as the new_embeddings_initializer. Copy the following block into a code cell to try it out. This is helpful when you have a better initialization for the new words in the vocabulary.

# Generate the updated embedding matrix
new_embedding = np.random.rand(len(vocab_new), 16)
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=vocab_base,
    new_vocabulary=vocab_new,
    base_embeddings=embedding_weights_base,
    new_embeddings_initializer=tf.keras.initializers.Constant(new_embedding),
)
# Update the model variable
updated_embedding_variable = tf.Variable(updated_embedding)

Verify that the embedding matrix's shape has changed to reflect the new vocabulary.

updated_embedding_variable.shape
TensorShape([10200, 16])

Now that you have the updated embedding matrix, the next step is to update the layer weights.

text_embedding_layer_new = Embedding(
    vectorize_layer_new.vocabulary_size(), embedding_dim, name="embedding"
)
text_embedding_layer_new.build(input_shape=[None])
text_embedding_layer_new.embeddings.assign(updated_embedding)
text_input_new = tf.keras.Sequential(
    [vectorize_layer_new, text_embedding_layer_new], name="text_input_new"
)
text_input_new.summary()

# Verify the shape of updated weights
# The new weights shape should reflect the new vocabulary size
text_input_new.get_layer("embedding").embeddings.shape
Model: "text_input_new"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_vectorization_1 (Text  (None, 100)               0         
 Vectorization)                                                  
                                                                 
 embedding (Embedding)       (None, 100, 16)           163200    
                                                                 
=================================================================
Total params: 163200 (637.50 KB)
Trainable params: 163200 (637.50 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
TensorShape([10200, 16])

Modify the model architecture to use the new text vectorization layer.

You can also load the model from a checkpoint and update the model architecture as shown below.

warm_started_model = tf.keras.Sequential([text_input_new, classifier_head])
warm_started_model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_input_new (Sequential  (None, 100, 16)           163200    
 )                                                               
                                                                 
 classifier_head (Sequentia  (None, 1)                 289       
 l)                                                              
                                                                 
=================================================================
Total params: 163489 (638.63 KB)
Trainable params: 163489 (638.63 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
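
If the base model was saved to disk rather than kept in memory, the checkpoint route mentioned above might look like the following sketch (the paths are hypothetical and this block is not executed in this tutorial):

# Save the base model at the end of the first training run.
model.save_weights("checkpoints/base_model")

# Later: rebuild the same base architecture and restore its weights.
restored_model = tf.keras.Sequential([text_input, classifier_head])
restored_model.load_weights("checkpoints/base_model")

# Reuse the restored classifier head with the new text-input stack.
warm_started_from_checkpoint = tf.keras.Sequential(
    [text_input_new, restored_model.get_layer("classifier_head")]
)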

You have successfully updated the model to accept a new vocabulary. The embedding layer now maps old vocabulary words to their old embeddings and initializes embeddings for new vocabulary words to be learned. The learned weights of the rest of the model remain the same. The model is warm-started and will continue training from where it left off previously.

You can now verify that the remapping worked. Get the index of the word "the", which is present in both the base and new vocabularies, and compare its embedding values. They should be equal.

# Look up the index of "the" in both vocabularies
base_vocab_index = vectorize_layer("the")[0]
new_vocab_index = vectorize_layer_new("the")[0]
print(
    warm_started_model.get_layer("text_input_new").get_layer("embedding")(
        new_vocab_index
    )
    == embedding_weights_base[base_vocab_index]
)
tf.Tensor(
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True], shape=(16,), dtype=bool)

Continue with warm-started training

Notice how the training is warm-started: the accuracy of the first epoch is around 84%, which is close to the accuracy where the previous training ended.

warm_started_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
warm_started_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)
Epoch 1/15
20/20 [==============================] - 4s 137ms/step - loss: 0.3677 - accuracy: 0.8413 - val_loss: 0.4162 - val_accuracy: 0.7950
Epoch 2/15
20/20 [==============================] - 1s 53ms/step - loss: 0.3535 - accuracy: 0.8479 - val_loss: 0.4086 - val_accuracy: 0.7996
Epoch 3/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3411 - accuracy: 0.8536 - val_loss: 0.4019 - val_accuracy: 0.8008
Epoch 4/15
20/20 [==============================] - 1s 53ms/step - loss: 0.3295 - accuracy: 0.8590 - val_loss: 0.3963 - val_accuracy: 0.8058
Epoch 5/15
20/20 [==============================] - 1s 54ms/step - loss: 0.3188 - accuracy: 0.8637 - val_loss: 0.3915 - val_accuracy: 0.8080
Epoch 6/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3087 - accuracy: 0.8686 - val_loss: 0.3875 - val_accuracy: 0.8110
Epoch 7/15
20/20 [==============================] - 1s 53ms/step - loss: 0.2992 - accuracy: 0.8741 - val_loss: 0.3842 - val_accuracy: 0.8136
Epoch 8/15
20/20 [==============================] - 1s 53ms/step - loss: 0.2903 - accuracy: 0.8781 - val_loss: 0.3816 - val_accuracy: 0.8164
Epoch 9/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2819 - accuracy: 0.8819 - val_loss: 0.3796 - val_accuracy: 0.8176
Epoch 10/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2739 - accuracy: 0.8858 - val_loss: 0.3781 - val_accuracy: 0.8190
Epoch 11/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2663 - accuracy: 0.8899 - val_loss: 0.3771 - val_accuracy: 0.8214
Epoch 12/15
20/20 [==============================] - 1s 53ms/step - loss: 0.2591 - accuracy: 0.8934 - val_loss: 0.3766 - val_accuracy: 0.8218
Epoch 13/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2522 - accuracy: 0.8965 - val_loss: 0.3764 - val_accuracy: 0.8218
Epoch 14/15
20/20 [==============================] - 1s 53ms/step - loss: 0.2456 - accuracy: 0.8992 - val_loss: 0.3767 - val_accuracy: 0.8226
Epoch 15/15
20/20 [==============================] - 1s 54ms/step - loss: 0.2392 - accuracy: 0.9024 - val_loss: 0.3774 - val_accuracy: 0.8238
<keras.src.callbacks.History at 0x7f4c8d7c9700>

Visualize warm-started training

# docs_infra: no_execute
%reload_ext tensorboard
%tensorboard --logdir logs

Next steps

In this tutorial you learned how to:

  • Train a sentiment classification model from scratch on a small vocabulary dataset.
  • Update the model architecture and warm start the embedding matrix when the vocabulary size changes.
  • Continuously improve model accuracy as the dataset and vocabulary expand.

To learn more about embeddings, check out the Word2Vec and Transformer model for language understanding tutorials.