Warm-start embedding layer matrix

This tutorial shows how to "warm-start" training using the tf.keras.utils.warmstart_embedding_matrix API for text sentiment classification when changing vocabulary.

You will begin by training a simple Keras model with a base vocabulary, and then, after updating the vocabulary, continue training the model. This is referred to as "warm-start" training, for which you'll need to remap the text-embedding matrix for the new vocabulary.

Embedding matrix

Embeddings provide a way to use an efficient, dense representation in which similar vocabulary tokens have a similar encoding. They are trainable parameters (weights learned by the model during training, in the same way a model learns weights for a dense layer). It is common to have embeddings that are 8-dimensional for small datasets, and up to 1024-dimensional when working with large datasets. A higher-dimensional embedding can capture fine-grained relationships between words, but takes more data to learn.
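
As a rough illustration of what "embedding matrix" means, the pure-Python sketch below treats the matrix as a table with one row per token ID, so that a lookup is just row indexing. The matrix values and the `embed` helper are made up for this example; in the tutorial the rows are learned by the Embedding layer.

```python
# Toy embedding matrix: one row per vocabulary ID, one column per dimension.
# The values are invented for illustration; a real model learns them.
embedding_matrix = [
    [0.0, 0.0, 0.0],   # ID 0
    [0.1, -0.2, 0.3],  # ID 1
    [0.5, 0.4, -0.1],  # ID 2
    [0.5, 0.5, -0.2],  # ID 3 (close to ID 2, as similar tokens would be)
]


def embed(token_ids, matrix):
    """Return the embedding vector (a matrix row) for each token ID."""
    return [matrix[token_id] for token_id in token_ids]


# A sequence of three token IDs becomes a sequence of three 3-dimensional vectors.
vectors = embed([2, 3, 1], embedding_matrix)
```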

Vocabulary

The set of unique words is referred to as the vocabulary. To build a text model you need to choose a fixed vocabulary. Typically you build the vocabulary from the most common words in a dataset. The vocabulary lets you represent each piece of text as a sequence of IDs that you can look up in the embedding matrix.
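
The idea can be sketched in a few lines of plain Python. The helpers `build_vocabulary` and `to_ids` are hypothetical names for this example only. Reserving index 0 for padding and index 1 for unknown words mirrors the convention the TextVectorization layer uses in integer mode, but this is not the layer's implementation.

```python
from collections import Counter


def build_vocabulary(texts, max_tokens):
    """Keep the most frequent words, after two reserved entries."""
    counts = Counter(word for text in texts for word in text.split())
    most_common = [word for word, _ in counts.most_common(max_tokens - 2)]
    # Index 0 is reserved for padding, index 1 for out-of-vocabulary words.
    return ["", "[UNK]"] + most_common


def to_ids(text, vocabulary):
    """Map each word to its vocabulary index, or to 1 if it is unknown."""
    index = {word: i for i, word in enumerate(vocabulary)}
    return [index.get(word, 1) for word in text.split()]


texts = ["the movie was great", "the movie was bad", "the plot was thin"]
vocab = build_vocabulary(texts, max_tokens=5)
ids = to_ids("the movie was superb", vocab)
```

Here "superb" is not among the most common words, so it maps to the unknown-word index.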

Why warm-start an embedding matrix?

A model is trained with a set of embeddings that represents a given vocabulary. If the model needs to be updated or improved, you can train to convergence significantly faster by reusing weights from a previous run. Reusing the embedding matrix from a previous run is more difficult, because any change to the vocabulary invalidates the word-to-ID mapping.
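
To make the problem concrete, the sketch below uses two small, made-up vocabularies. Inserting a single word shifts the IDs of every word after it, so the rows of the old embedding matrix no longer line up with the words they were trained for.

```python
vocab_old = ["", "[UNK]", "the", "movie", "great"]
# The same vocabulary after the word "film" is inserted.
vocab_changed = ["", "[UNK]", "the", "film", "movie", "great"]

old_index = {word: i for i, word in enumerate(vocab_old)}
new_index = {word: i for i, word in enumerate(vocab_changed)}

# Words whose ID changed: (old ID, new ID).
moved = {
    word: (old_index[word], new_index[word])
    for word in vocab_old
    if old_index[word] != new_index[word]
}

# Row 3 of the old matrix was trained for "movie", but ID 3 now means "film".
word_at_row_3_before = vocab_old[3]
word_at_row_3_after = vocab_changed[3]
```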

tf.keras.utils.warmstart_embedding_matrix solves this problem by creating an embedding matrix for a new vocabulary from the embedding matrix of a base vocabulary. Where a word exists in both vocabularies, the base embedding vector is copied into the correct location in the new embedding matrix. This allows you to warm-start training after any change in the size or order of the vocabulary.
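
The remapping described here can be sketched conceptually in pure Python. This is not the library's implementation: `remap_embeddings` is a hypothetical helper, and the small uniform range used for unseen words is an assumption chosen for illustration.

```python
import random


def remap_embeddings(base_vocab, new_vocab, base_embeddings, seed=0):
    """Build a matrix for new_vocab, reusing rows from base_embeddings."""
    rng = random.Random(seed)
    dim = len(base_embeddings[0])
    base_index = {word: i for i, word in enumerate(base_vocab)}
    new_embeddings = []
    for word in new_vocab:
        if word in base_index:
            # Shared word: copy the learned vector into its new row.
            new_embeddings.append(list(base_embeddings[base_index[word]]))
        else:
            # Unseen word: start from a small random vector (assumed range).
            new_embeddings.append([rng.uniform(-0.05, 0.05) for _ in range(dim)])
    return new_embeddings


base_vocab = ["", "[UNK]", "the", "movie"]
new_vocab = ["", "[UNK]", "movie", "film", "the"]  # reordered, one word added
base_emb = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.3], [0.4, 0.5]]
new_emb = remap_embeddings(base_vocab, new_vocab, base_emb)
```

Even though "movie" and "the" swapped positions, each keeps its learned vector; only "film" starts from a fresh initialization.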

Setup

pip install --pre -U "tensorflow>2.10"  # Requires 2.11
import io
import numpy as np
import os
import re
import shutil
import string
import tensorflow as tf

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from tensorflow.keras.layers import TextVectorization

Load the dataset

The tutorial uses the Large Movie Review Dataset. You will train a sentiment classifier model on this dataset and in the process learn embeddings from scratch. Refer to the Loading text tutorial to learn more.

Download the dataset using Keras file utility and review the directories.

url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

dataset = tf.keras.utils.get_file(
    "aclImdb_v1.tar.gz", url, untar=True, cache_dir=".", cache_subdir=""
)

dataset_dir = os.path.join(os.path.dirname(dataset), "aclImdb")
os.listdir(dataset_dir)
Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84125825/84125825 [==============================] - 3s 0us/step
['imdb.vocab', 'imdbEr.txt', 'README', 'train', 'test']

The train/ directory has pos and neg folders with movie reviews labelled as positive and negative respectively. You will use reviews from pos and neg folders to train a binary classification model.

train_dir = os.path.join(dataset_dir, "train")
os.listdir(train_dir)
['unsup',
 'unsupBow.feat',
 'urls_neg.txt',
 'neg',
 'pos',
 'labeledBow.feat',
 'urls_unsup.txt',
 'urls_pos.txt']

The train directory also contains an additional folder, unsup, which should be removed before creating the training set.

remove_dir = os.path.join(train_dir, "unsup")
shutil.rmtree(remove_dir)

Next, create a tf.data.Dataset using tf.keras.utils.text_dataset_from_directory. You can read more about using this utility in this text classification tutorial.

Use the train directory to create the training and validation sets with a split of 20% for validation.

batch_size = 1024
seed = 123
train_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=seed,
)
val_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=seed,
)
Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
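
The counts printed above follow from simple arithmetic, sketched below. The helper names are hypothetical, and the sketch only reproduces the numbers; it does not show how the utility selects files.

```python
def split_counts(num_files, validation_split):
    """Return (training count, validation count) for a fractional split."""
    num_val = int(validation_split * num_files)
    return num_files - num_val, num_val


def num_batches(num_examples, batch_size):
    """Number of batches per epoch, counting a final partial batch."""
    return -(-num_examples // batch_size)  # ceiling division


train_count, val_count = split_counts(25000, 0.2)
train_steps = num_batches(train_count, 1024)
```

With a batch size of 1024, the 20,000 training reviews give 20 steps per epoch, which is the step count shown in the training logs later in this tutorial.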

Configure the dataset for performance

You can learn more about Dataset.cache and Dataset.prefetch, as well as how to cache data to disk in the data performance guide.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Text preprocessing

Next, define the dataset preprocessing steps required for your sentiment classification model. Initialize a TextVectorization layer with the desired parameters to vectorize movie reviews. You can learn more about using this layer in the Text Classification tutorial.

# Create a custom standardization function to strip HTML break tags '<br />'.
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )


# Vocabulary size and number of words in a sequence.
vocab_size = 10000
sequence_length = 100

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Note that the layer uses the custom standardization defined above.
# Set output_sequence_length because the samples are not all the same length.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call `Dataset.adapt` to build the
# vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)
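
To see what these preprocessing steps do to a single review, here is a pure-Python approximation. It assumes a tiny made-up vocabulary (`toy_vocab`) and hypothetical helpers `standardize` and `vectorize`; it is a sketch of the behavior, not the layer's code. Unknown words map to index 1, and short sequences are padded with index 0 up to the fixed length.

```python
import re
import string


def standardize(text):
    """Lowercase, replace '<br />' with a space, and strip punctuation."""
    lowercase = text.lower()
    stripped_html = lowercase.replace("<br />", " ")
    return re.sub("[%s]" % re.escape(string.punctuation), "", stripped_html)


def vectorize(text, vocabulary, sequence_length):
    """Map words to IDs, then truncate or zero-pad to a fixed length."""
    index = {word: i for i, word in enumerate(vocabulary)}
    ids = [index.get(word, 1) for word in standardize(text).split()]
    ids = ids[:sequence_length]
    return ids + [0] * (sequence_length - len(ids))


toy_vocab = ["", "[UNK]", "the", "movie", "great"]
clean = standardize("Great movie!<br />Loved it.")
ids = vectorize("The movie: GREAT!<br />Truly.", toy_vocab, sequence_length=6)
```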

Create a classification model

Use the Keras Sequential API to define the sentiment classification model.

embedding_dim = 16
text_model_input = tf.keras.layers.Input(dtype=tf.string, shape=(1,))
text_embedding = Embedding(vocab_size, embedding_dim, name="embedding")
text_input = tf.keras.Sequential(
    [vectorize_layer, text_embedding], name="text_input"
)
classifier_head = tf.keras.Sequential(
    [GlobalAveragePooling1D(), Dense(16, activation="relu"), Dense(1)],
    name="classifier_head",
)

model = tf.keras.Sequential([text_input, classifier_head])

Compile and train the model

You will use TensorBoard to visualize metrics including loss and accuracy. Create a tf.keras.callbacks.TensorBoard.

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

Compile and train the model using the Adam optimizer and BinaryCrossentropy loss.

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)
Epoch 1/15
20/20 [==============================] - 8s 211ms/step - loss: 0.6925 - accuracy: 0.5028 - val_loss: 0.6910 - val_accuracy: 0.4886
Epoch 2/15
20/20 [==============================] - 1s 53ms/step - loss: 0.6883 - accuracy: 0.5028 - val_loss: 0.6847 - val_accuracy: 0.4886
Epoch 3/15
20/20 [==============================] - 1s 51ms/step - loss: 0.6799 - accuracy: 0.5028 - val_loss: 0.6742 - val_accuracy: 0.4886
Epoch 4/15
20/20 [==============================] - 1s 51ms/step - loss: 0.6661 - accuracy: 0.5028 - val_loss: 0.6576 - val_accuracy: 0.4886
Epoch 5/15
20/20 [==============================] - 1s 51ms/step - loss: 0.6450 - accuracy: 0.5031 - val_loss: 0.6348 - val_accuracy: 0.4910
Epoch 6/15
20/20 [==============================] - 1s 52ms/step - loss: 0.6178 - accuracy: 0.5389 - val_loss: 0.6075 - val_accuracy: 0.5584
Epoch 7/15
20/20 [==============================] - 1s 52ms/step - loss: 0.5857 - accuracy: 0.6189 - val_loss: 0.5773 - val_accuracy: 0.6292
Epoch 8/15
20/20 [==============================] - 1s 53ms/step - loss: 0.5508 - accuracy: 0.6900 - val_loss: 0.5466 - val_accuracy: 0.6768
Epoch 9/15
20/20 [==============================] - 1s 52ms/step - loss: 0.5156 - accuracy: 0.7380 - val_loss: 0.5176 - val_accuracy: 0.7158
Epoch 10/15
20/20 [==============================] - 1s 52ms/step - loss: 0.4822 - accuracy: 0.7685 - val_loss: 0.4916 - val_accuracy: 0.7414
Epoch 11/15
20/20 [==============================] - 1s 51ms/step - loss: 0.4518 - accuracy: 0.7919 - val_loss: 0.4694 - val_accuracy: 0.7568
Epoch 12/15
20/20 [==============================] - 1s 51ms/step - loss: 0.4249 - accuracy: 0.8098 - val_loss: 0.4508 - val_accuracy: 0.7668
Epoch 13/15
20/20 [==============================] - 1s 52ms/step - loss: 0.4011 - accuracy: 0.8219 - val_loss: 0.4354 - val_accuracy: 0.7772
Epoch 14/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3802 - accuracy: 0.8357 - val_loss: 0.4227 - val_accuracy: 0.7866
Epoch 15/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3617 - accuracy: 0.8443 - val_loss: 0.4123 - val_accuracy: 0.7942
<keras.callbacks.History at 0x7ff6573bc6d0>

With this approach the model reaches a training accuracy of around 84% and a validation accuracy of around 79% after 15 epochs.
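
The loss values in the log can be put in context with a small sketch of binary cross-entropy computed from logits. It uses a common numerically stable formulation and hypothetical helper names; it is not the Keras implementation. A model whose outputs are all near zero has a loss of about ln 2 ≈ 0.693, which is close to the loss reported in the first epoch.

```python
import math


def bce_from_logits(label, logit):
    """Binary cross-entropy for one example, given a raw logit."""
    # Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_bce(labels, logits):
    """Average the per-example losses over a batch."""
    losses = [bce_from_logits(y, x) for y, x in zip(labels, logits)]
    return sum(losses) / len(losses)


# Logits of zero correspond to a predicted probability of 0.5.
untrained = mean_bce([1, 0], [0.0, 0.0])
# Confident, correct logits give a much smaller loss.
confident = mean_bce([1, 0], [4.0, -4.0])
```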

You can look into the model summary to learn more about each layer of the model.

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_input (Sequential)     (None, 100, 16)           160000    
                                                                 
 classifier_head (Sequential  (None, 1)                289       
 )                                                               
                                                                 
=================================================================
Total params: 160,289
Trainable params: 160,289
Non-trainable params: 0
_________________________________________________________________
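
The parameter counts in the summary can be checked by hand. The sketch below uses hypothetical helper names: the embedding layer has one weight per (token, dimension) pair, and each Dense layer has a weight per input-output pair plus one bias per unit.

```python
def embedding_params(vocab_size, embedding_dim):
    """One trainable value per vocabulary entry and embedding dimension."""
    return vocab_size * embedding_dim


def dense_params(inputs, units):
    """A weight for every input-unit pair, plus one bias per unit."""
    return inputs * units + units


text_input_params = embedding_params(10000, 16)
# Dense(16) on the 16-dimensional pooled embedding, then Dense(1).
head_params = dense_params(16, 16) + dense_params(16, 1)
total_params = text_input_params + head_params
```

Almost all of the model's parameters sit in the embedding matrix, which is why reusing it matters when the vocabulary changes.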

Visualize the model metrics in TensorBoard.

# docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir logs

Vocabulary remapping

Now you're going to update the vocabulary and continue with warm-started training.

First, get the base vocabulary and embedding matrix.

embedding_weights_base = (
    model.get_layer("text_input").get_layer("embedding").get_weights()[0]
)
vocab_base = vectorize_layer.get_vocabulary()

Define a new vectorization layer to generate a new, larger vocabulary.

# Vocabulary size and number of words in a sequence.
vocab_size_new = 10200
sequence_length = 100

vectorize_layer_new = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size_new,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call adapt to build the vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer_new.adapt(text_ds)

# Get the new vocabulary
vocab_new = vectorize_layer_new.get_vocabulary()
# View the new vocabulary tokens that weren't in `vocab_base`
set(vocab_new) - set(vocab_base)
{'bullying',
 'bumps',
 'canvas',
 'carole',
 'chains',
 'chairman',
 'checks',
 'coarse',
 'competitive',
 'component',
 'compound',
 'confirm',
 'contemplate',
 'coping',
 'corporations',
 'costuming',
 'counterpart',
 'crop',
 'custody',
 'cyborgs',
 'daft',
 'danced',
 'daphne',
 'darkest',
 'davids',
 'december',
 'declared',
 'defence',
 'delve',
 'demonstration',
 'dense',
 'denver',
 'devilish',
 'devious',
 'dickinson',
 'digs',
 'directorwriter',
 'download',
 'effortless',
 'electricity',
 'elliot',
 'enlightenment',
 'erratic',
 'exceedingly',
 'eyeballs',
 'fearless',
 'fenton',
 'fiennes',
 'filter',
 'fireworks',
 'flipping',
 'float',
 'foggy',
 'forgivable',
 'framework',
 'fulllength',
 'funds',
 'gamut',
 'geeks',
 'glee',
 'goo',
 'gripe',
 'hardest',
 'harmony',
 'henchman',
 'heritage',
 'hg',
 'hi',
 'hightech',
 'homework',
 'houston',
 'howards',
 'hunger',
 'imho',
 'immigrants',
 'improvised',
 'impulse',
 'inch',
 'interpret',
 'intimidating',
 'iowa',
 'jaffar',
 'jeep',
 'jock',
 'kriemhild',
 'kristofferson',
 'lassie',
 'laughoutloud',
 'lennon',
 'librarian',
 'liza',
 'locker',
 'lommel',
 'loren',
 'lowered',
 'marital',
 'martins',
 'mastroianni',
 'megan',
 'melt',
 'mischievous',
 'monstrosity',
 'monumental',
 'morse',
 'mostel',
 'muddy',
 'noah',
 'noirs',
 'nostril',
 'numbing',
 'occupation',
 'oceans',
 'onesided',
 'opus',
 'organ',
 'osullivan',
 'otoole',
 'overnight',
 'parisian',
 'partial',
 'patriotism',
 'pbs',
 'penchant',
 'penguin',
 'plotted',
 'powerfully',
 'pows',
 'practicing',
 'prehistoric',
 'prestigious',
 'prevalent',
 'prevents',
 'profits',
 'promotion',
 'puke',
 'pulse',
 'punchline',
 'quarters',
 'rainer',
 'ranting',
 'rapists',
 'rapture',
 'rarity',
 'rays',
 'recommending',
 'redeemed',
 'refuge',
 'refugee',
 'relates',
 'religions',
 'remaking',
 'renee',
 'reply',
 'restoration',
 'resurrection',
 'retreat',
 'retro',
 'rockets',
 'romano',
 'rooker',
 'rooted',
 'runtime',
 'sap',
 'scarred',
 'secluded',
 'selfabsorbed',
 'separation',
 'shattered',
 'shenanigans',
 'shootings',
 'shue',
 'silk',
 'sm',
 'soooo',
 'spoton',
 'sr',
 'staple',
 'stepfather',
 'stoic',
 'stud',
 'suite',
 'swanson',
 'sweetness',
 'sybil',
 'tease',
 'technological',
 'tensions',
 'theft',
 'therapist',
 'threats',
 'tin',
 'towel',
 'transform',
 'travelling',
 'troupe',
 'unremarkable',
 'unsatisfied',
 'untrue',
 'vertigo',
 'vic'}
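
When comparing two vocabularies as sets, note that the difference (`-`) and symmetric difference (`^`) operators agree only when every base word is also in the new vocabulary. The toy sets below are made up to show where the two diverge.

```python
base = {"the", "movie", "great"}
new = {"the", "movie", "great", "film", "plot"}

# Every base word is kept, so both operators list only the added words.
added = new - base
symmetric = base ^ new

# If the new vocabulary drops a base word, the results differ.
new_without_great = {"the", "movie", "film", "plot"}
added_only = new_without_great - base   # words that are new
either_side = base ^ new_without_great  # new words plus dropped words
```

In this tutorial the larger vocabulary is built from the same data, so the listing above contains only added words: 200 of them, matching the increase from 10,000 to 10,200 tokens.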

Generate updated embeddings using the tf.keras.utils.warmstart_embedding_matrix utility.

# Generate the updated embedding matrix
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=vocab_base,
    new_vocabulary=vocab_new,
    base_embeddings=embedding_weights_base,
    new_embeddings_initializer="uniform",
)
# Update the model variable
updated_embedding_variable = tf.Variable(updated_embedding)

OR

If you have an embedding matrix that you would like to use to initialize the new embedding matrix, pass tf.keras.initializers.Constant as the new_embeddings_initializer argument. Copy the following block to a code cell to try this out. This is helpful when you have a better initialization for the new words in the vocabulary.

# Generate the updated embedding matrix
new_embedding = np.random.rand(len(vocab_new), 16)
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=vocab_base,
    new_vocabulary=vocab_new,
    base_embeddings=embedding_weights_base,
    new_embeddings_initializer=tf.keras.initializers.Constant(new_embedding),
)
# Update the model variable
updated_embedding_variable = tf.Variable(updated_embedding)

Verify that the embedding matrix's shape has changed to reflect the new vocabulary.

updated_embedding_variable.shape
TensorShape([10200, 16])

Now that you have the updated embedding matrix, the next step is to update the layer weights.

text_embedding_layer_new = Embedding(
    vectorize_layer_new.vocabulary_size(), embedding_dim, name="embedding"
)
text_embedding_layer_new.build(input_shape=[None])
text_embedding_layer_new.embeddings.assign(updated_embedding)
text_input_new = tf.keras.Sequential(
    [vectorize_layer_new, text_embedding_layer_new], name="text_input_new"
)
text_input_new.summary()

# Verify the shape of updated weights
# The new weights shape should reflect the new vocabulary size
text_input_new.get_layer("embedding").get_weights()[0].shape
Model: "text_input_new"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_vectorization_1 (TextV  (None, 100)              0         
 ectorization)                                                   
                                                                 
 embedding (Embedding)       (None, 100, 16)           163200    
                                                                 
=================================================================
Total params: 163,200
Trainable params: 163,200
Non-trainable params: 0
_________________________________________________________________
(10200, 16)

Modify the model architecture to use the new text vectorization layer.

You can also load the model from a checkpoint and update the model architecture as shown below.

warm_started_model = tf.keras.Sequential([text_input_new, classifier_head])
warm_started_model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 text_input_new (Sequential)  (None, 100, 16)          163200    
                                                                 
 classifier_head (Sequential  (None, 1)                289       
 )                                                               
                                                                 
=================================================================
Total params: 163,489
Trainable params: 163,489
Non-trainable params: 0
_________________________________________________________________

You have successfully updated the model to accept a new vocabulary. The embedding layer is updated to map old vocabulary words to old embeddings and initialize embeddings for new vocabulary words to be learnt. The learned weights of the rest of the model will remain the same. The model is warm-started to continue to train from where it left off previously.

You can now verify that the remapping worked. Get the index of the word "the", which is present in both the base and new vocabularies, and compare the embedding values. They should be equal.

# Look up "the" in both vocabularies and compare its embedding vectors.
base_vocab_index = vectorize_layer("the")[0]
new_vocab_index = vectorize_layer_new("the")[0]
print(
    warm_started_model.get_layer("text_input_new").get_layer("embedding")(
        new_vocab_index
    )
    == embedding_weights_base[base_vocab_index]
)
tf.Tensor(
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True], shape=(16,), dtype=bool)
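
The same check can be extended from one word to every shared word. The sketch below is a pure-Python version with a hypothetical helper, `mismatched_words`, run on made-up toy data. It assumes the vocabularies are lists of strings and the embeddings are indexable row by row.

```python
def mismatched_words(base_vocab, new_vocab, base_embeddings, new_embeddings):
    """Return shared words whose new row differs from their base row."""
    base_index = {word: i for i, word in enumerate(base_vocab)}
    return [
        word
        for new_id, word in enumerate(new_vocab)
        if word in base_index
        and list(new_embeddings[new_id]) != list(base_embeddings[base_index[word]])
    ]


base_vocab = ["", "[UNK]", "the", "movie"]
new_vocab = ["", "[UNK]", "movie", "film", "the"]
base_emb = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.3], [0.4, 0.5]]

# Correctly remapped: "movie" and "the" carry their base vectors.
good_emb = [[0.0, 0.0], [0.1, 0.1], [0.4, 0.5], [0.9, 0.9], [0.2, 0.3]]
# Incorrectly remapped: rows copied by position instead of by word.
bad_emb = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.3], [0.4, 0.5], [0.9, 0.9]]

ok = mismatched_words(base_vocab, new_vocab, base_emb, good_emb)
broken = mismatched_words(base_vocab, new_vocab, base_emb, bad_emb)
```

An empty result means every shared word kept its learned vector.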

Continue with warm-started training

Notice how the training is warm-started. The training accuracy in the first epoch is around 85%, close to where the previous training run ended.

# Compile and train the warm-started model, which uses the new vocabulary.
warm_started_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
warm_started_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)
Epoch 1/15
20/20 [==============================] - 4s 149ms/step - loss: 0.3470 - accuracy: 0.8503 - val_loss: 0.4043 - val_accuracy: 0.8002
Epoch 2/15
20/20 [==============================] - 1s 52ms/step - loss: 0.3328 - accuracy: 0.8582 - val_loss: 0.3978 - val_accuracy: 0.8040
Epoch 3/15
20/20 [==============================] - 1s 51ms/step - loss: 0.3206 - accuracy: 0.8629 - val_loss: 0.3923 - val_accuracy: 0.8062
Epoch 4/15
20/20 [==============================] - 1s 51ms/step - loss: 0.3093 - accuracy: 0.8685 - val_loss: 0.3877 - val_accuracy: 0.8100
Epoch 5/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2988 - accuracy: 0.8738 - val_loss: 0.3841 - val_accuracy: 0.8122
Epoch 6/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2890 - accuracy: 0.8784 - val_loss: 0.3813 - val_accuracy: 0.8160
Epoch 7/15
20/20 [==============================] - 1s 51ms/step - loss: 0.2798 - accuracy: 0.8830 - val_loss: 0.3792 - val_accuracy: 0.8182
Epoch 8/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2712 - accuracy: 0.8871 - val_loss: 0.3778 - val_accuracy: 0.8198
Epoch 9/15
20/20 [==============================] - 1s 51ms/step - loss: 0.2630 - accuracy: 0.8905 - val_loss: 0.3769 - val_accuracy: 0.8206
Epoch 10/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2552 - accuracy: 0.8949 - val_loss: 0.3766 - val_accuracy: 0.8226
Epoch 11/15
20/20 [==============================] - 1s 51ms/step - loss: 0.2478 - accuracy: 0.8982 - val_loss: 0.3768 - val_accuracy: 0.8234
Epoch 12/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2408 - accuracy: 0.9007 - val_loss: 0.3775 - val_accuracy: 0.8244
Epoch 13/15
20/20 [==============================] - 1s 51ms/step - loss: 0.2340 - accuracy: 0.9042 - val_loss: 0.3786 - val_accuracy: 0.8232
Epoch 14/15
20/20 [==============================] - 1s 51ms/step - loss: 0.2276 - accuracy: 0.9077 - val_loss: 0.3801 - val_accuracy: 0.8228
Epoch 15/15
20/20 [==============================] - 1s 52ms/step - loss: 0.2214 - accuracy: 0.9110 - val_loss: 0.3820 - val_accuracy: 0.8236
<keras.callbacks.History at 0x7ff6572b64c0>

Visualize warm-started training

# docs_infra: no_execute
%reload_ext tensorboard
%tensorboard --logdir logs

Next steps

In this tutorial you learned how to:

  • Train a sentiment classification model from scratch on a small vocabulary dataset.
  • Update the model architecture and warm start the embedding matrix when the vocabulary size changes.
  • Continuously improve model accuracy with expanding datasets.

To learn more about embeddings check out the Word2Vec and Transformer model for language understanding tutorials.