Exploring the TF-Hub CORD-19 Swivel Embeddings


The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/3) was built to support researchers analyzing natural language text related to COVID-19. These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this colab we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Setup

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange

Analyze the embeddings

Let's start off by analyzing the embedding by calculating and plotting a correlation matrix between different terms. If the embedding learned to successfully capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms.

# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
  corr = np.inner(features, features)
  corr /= np.max(corr)
  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)

# Generate embeddings for some terms
queries = [
  # Related viruses
  'coronavirus', 'SARS', 'MERS',
  # Regions
  'Italy', 'Spain', 'Europe',
  # Symptoms
  'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)
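`plot_correlation` divides the raw inner products by the single largest entry, so vector lengths still affect the colors. A common alternative, shown here as a sketch rather than the notebook's method, is cosine similarity, which L2-normalizes each vector first so that every term has similarity 1 with itself. The helper names `cosine_similarity_matrix` and `max_scaled_inner_product` are hypothetical, and the toy vectors stand in for rows of `embeddings` (in the notebook you would pass `embeddings.numpy()`).

```python
import numpy as np

def cosine_similarity_matrix(features):
  """Pairwise cosine similarity between the rows of a 2-D array."""
  features = np.asarray(features, dtype=np.float64)
  norms = np.linalg.norm(features, axis=1, keepdims=True)
  # Guard against all-zero rows to avoid division by zero.
  unit = features / np.maximum(norms, 1e-12)
  return unit @ unit.T

def max_scaled_inner_product(features):
  """The scaling used by plot_correlation: inner products divided by their max."""
  corr = np.inner(features, features)
  return corr / np.max(corr)

# Hypothetical 3-term, 2-dimensional "embeddings" for illustration only.
toy = np.array([[3.0, 0.0],   # a long vector
                [1.0, 0.0],   # same direction, but shorter
                [0.0, 1.0]])  # orthogonal direction

print(np.round(cosine_similarity_matrix(toy), 3))
print(np.round(max_scaled_inner_product(toy), 3))
```

With max scaling, the short second vector scores only about 0.11 even against itself, while cosine similarity treats the first two terms as pointing in the same direction. Which measure is preferable depends on whether vector length carries meaning in the embedding.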


We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (e.g. "coronavirus" correlates highly with "SARS" and "MERS"), while words differ from those in other clusters (e.g. the similarity between "SARS" and "Spain" is close to 0).
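The heatmap compares every pair of terms at once. To ask which terms are closest to one particular query, we can rank the others by cosine similarity. This is a minimal sketch: `nearest_terms` is a hypothetical helper, it assumes `terms` and `vectors` are aligned (in the notebook, `queries` and `embeddings.numpy()`), and the 2-D vectors below are made up.

```python
import numpy as np

def nearest_terms(query, terms, vectors, k=2):
  """Return the k terms most cosine-similar to `query`, excluding the query itself."""
  vectors = np.asarray(vectors, dtype=np.float64)
  unit = vectors / np.maximum(np.linalg.norm(vectors, axis=1, keepdims=True), 1e-12)
  q = unit[terms.index(query)]
  scores = unit @ q
  # Sort by descending similarity and drop the query term itself.
  order = [i for i in np.argsort(-scores) if terms[i] != query]
  return [(terms[i], float(scores[i])) for i in order[:k]]

# Hypothetical 2-D vectors: two "virus" terms and two "region" terms.
terms = ['coronavirus', 'SARS', 'Italy', 'Spain']
vectors = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]
print(nearest_terms('coronavirus', terms, vectors, k=1))
print(nearest_terms('Italy', terms, vectors, k=1))
```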

Now let's see how we can use these embeddings to solve a specific task.

SciCite: Citation Intent Classification

This section shows how the embedding can be used for downstream tasks such as text classification. We'll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers: given a sentence containing a citation from an academic paper, the task is to decide whether the main intent of the citation is to provide background information, to refer to the use of methods, or to compare results.

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)

Let's take a look at a few labeled examples from the training set.
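The cell that displays these examples is not included in this export. A minimal sketch of how such a table can be built with pandas, assuming each batch has already been converted to Python strings and integer labels (for example with `ex.numpy().decode('utf8')`, as done later in this notebook). The helper `examples_table`, its column defaults, and the `label_names` list are hypothetical stand-ins; in the notebook the names would presumably come from the dataset metadata (`builder.info`).

```python
import pandas as pd

def examples_table(texts, labels, label_names,
                   text_column='string', label_column='label'):
  """Tabulate (text, integer label) pairs with human-readable label names."""
  return pd.DataFrame({
      text_column: list(texts),
      label_column: [label_names[int(x)] for x in labels],
  })

# Hypothetical stand-ins for a batch from train_data and its label names.
label_names = ['method', 'background', 'result']
texts = ['We follow the protocol of [1].',
         'Prior work has studied this problem [2].']
labels = [0, 1]
print(examples_table(texts, labels, label_names))
```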

Training a citation intent classifier

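We'll train a classifier on the SciCite dataset using Keras. Let's build a model that uses the CORD-19 embeddings with a classification layer on top. As the model summary shows, the embedding layer is frozen (17,301,632 non-trainable parameters), so only the 387 parameters of the 3-unit dense output layer are trained.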

Hyperparameters

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 128)               17301632  
                                                                 
 dense (Dense)               (None, 3)                 387       
                                                                 
=================================================================
Total params: 17,302,019
Trainable params: 387
Non-trainable params: 17,301,632
_________________________________________________________________
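The trainable part of this model is small enough to write out by hand. The Dense layer maps each 128-dimensional embedding to 3 logits, one per citation intent, which gives 128 × 3 weights plus 3 biases = 387 parameters, matching the summary. The frozen layer's 17,301,632 parameters equal 135,169 × 128, consistent with a lookup table of 135,169 rows of 128 values. Below is a NumPy sketch of what the head computes; the weights are random hypothetical values standing in for the trained ones, and the batch is made up rather than real hub-layer output.

```python
import numpy as np

EMBEDDING_DIM = 128   # output width of the CORD-19 Swivel layer (from the summary)
NUM_CLASSES = 3       # citation intents in SciCite

rng = np.random.default_rng(0)
# Hypothetical weights; Keras learns the real ones during model.fit.
W = rng.normal(scale=0.01, size=(EMBEDDING_DIM, NUM_CLASSES))
b = np.zeros(NUM_CLASSES)

def dense_head(embeddings):
  """Affine map from sentence embeddings to unnormalized class logits."""
  return embeddings @ W + b

def softmax(logits):
  """Row-wise softmax, shifted by the row max for numerical stability."""
  z = logits - logits.max(axis=-1, keepdims=True)
  e = np.exp(z)
  return e / e.sum(axis=-1, keepdims=True)

# A batch of 4 made-up embeddings in place of real hub-layer outputs.
batch = rng.normal(size=(4, EMBEDDING_DIM))
probs = softmax(dense_head(batch))
print(probs.shape, W.size + b.size)  # (4, 3) 387
```

Since `argmax` is unchanged by softmax, the predicted class can be read directly from the logits, which is what the prediction cell later in this notebook does with `np.argmax(model.predict(...), axis=-1)`.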

Train and evaluate the model

Let's train the model and see how it performs on the SciCite task, tracking accuracy on the validation split after each epoch.

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
Epoch 1/35
257/257 [==============================] - 3s 5ms/step - loss: 0.8885 - accuracy: 0.6181 - val_loss: 0.7775 - val_accuracy: 0.6834
Epoch 2/35
257/257 [==============================] - 1s 4ms/step - loss: 0.6950 - accuracy: 0.7210 - val_loss: 0.6685 - val_accuracy: 0.7434
Epoch 3/35
257/257 [==============================] - 1s 4ms/step - loss: 0.6220 - accuracy: 0.7576 - val_loss: 0.6227 - val_accuracy: 0.7566
Epoch 4/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5865 - accuracy: 0.7700 - val_loss: 0.5983 - val_accuracy: 0.7544
Epoch 5/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5668 - accuracy: 0.7781 - val_loss: 0.5838 - val_accuracy: 0.7631
Epoch 6/35
257/257 [==============================] - 2s 4ms/step - loss: 0.5549 - accuracy: 0.7824 - val_loss: 0.5775 - val_accuracy: 0.7609
Epoch 7/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5452 - accuracy: 0.7857 - val_loss: 0.5696 - val_accuracy: 0.7729
Epoch 8/35
257/257 [==============================] - 2s 4ms/step - loss: 0.5383 - accuracy: 0.7891 - val_loss: 0.5636 - val_accuracy: 0.7751
Epoch 9/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5339 - accuracy: 0.7918 - val_loss: 0.5617 - val_accuracy: 0.7740
Epoch 10/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5291 - accuracy: 0.7919 - val_loss: 0.5595 - val_accuracy: 0.7697
Epoch 11/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5260 - accuracy: 0.7928 - val_loss: 0.5577 - val_accuracy: 0.7740
Epoch 12/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5231 - accuracy: 0.7944 - val_loss: 0.5566 - val_accuracy: 0.7740
Epoch 13/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5207 - accuracy: 0.7955 - val_loss: 0.5545 - val_accuracy: 0.7773
Epoch 14/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5181 - accuracy: 0.7973 - val_loss: 0.5526 - val_accuracy: 0.7817
Epoch 15/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5162 - accuracy: 0.7956 - val_loss: 0.5541 - val_accuracy: 0.7697
Epoch 16/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5145 - accuracy: 0.7984 - val_loss: 0.5497 - val_accuracy: 0.7828
Epoch 17/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5129 - accuracy: 0.7975 - val_loss: 0.5506 - val_accuracy: 0.7828
Epoch 18/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5117 - accuracy: 0.7973 - val_loss: 0.5497 - val_accuracy: 0.7795
Epoch 19/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5106 - accuracy: 0.7977 - val_loss: 0.5463 - val_accuracy: 0.7806
Epoch 20/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5089 - accuracy: 0.7996 - val_loss: 0.5481 - val_accuracy: 0.7838
Epoch 21/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5081 - accuracy: 0.7985 - val_loss: 0.5510 - val_accuracy: 0.7806
Epoch 22/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5073 - accuracy: 0.7989 - val_loss: 0.5476 - val_accuracy: 0.7817
Epoch 23/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5059 - accuracy: 0.7981 - val_loss: 0.5467 - val_accuracy: 0.7860
Epoch 24/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5055 - accuracy: 0.7995 - val_loss: 0.5464 - val_accuracy: 0.7838
Epoch 25/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5048 - accuracy: 0.8006 - val_loss: 0.5464 - val_accuracy: 0.7849
Epoch 26/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5041 - accuracy: 0.8002 - val_loss: 0.5441 - val_accuracy: 0.7828
Epoch 27/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5033 - accuracy: 0.8006 - val_loss: 0.5452 - val_accuracy: 0.7838
Epoch 28/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5024 - accuracy: 0.8003 - val_loss: 0.5435 - val_accuracy: 0.7871
Epoch 29/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5021 - accuracy: 0.8003 - val_loss: 0.5441 - val_accuracy: 0.7806
Epoch 30/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5014 - accuracy: 0.8016 - val_loss: 0.5434 - val_accuracy: 0.7871
Epoch 31/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5008 - accuracy: 0.8012 - val_loss: 0.5460 - val_accuracy: 0.7849
Epoch 32/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5007 - accuracy: 0.8027 - val_loss: 0.5444 - val_accuracy: 0.7893
Epoch 33/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5003 - accuracy: 0.8019 - val_loss: 0.5434 - val_accuracy: 0.7882
Epoch 34/35
257/257 [==============================] - 1s 4ms/step - loss: 0.5000 - accuracy: 0.8019 - val_loss: 0.5442 - val_accuracy: 0.7882
Epoch 35/35
257/257 [==============================] - 1s 4ms/step - loss: 0.4994 - accuracy: 0.8018 - val_loss: 0.5459 - val_accuracy: 0.7828
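Each epoch above reports 257 steps. Keras runs one step per batch, and `Dataset.batch` keeps the final partial batch unless `drop_remainder=True`, so the step count is ceil(N / BATCH_SIZE). Working backwards, 257 steps at a batch size of 32 means the training split holds between 8,193 and 8,224 examples. The two helpers below are illustrative only.

```python
import math

def steps_per_epoch(num_examples, batch_size):
  """Number of batches when the final, possibly partial, batch is kept."""
  return math.ceil(num_examples / batch_size)

def example_count_range(steps, batch_size):
  """Inclusive range of dataset sizes that yield exactly `steps` batches."""
  return (steps - 1) * batch_size + 1, steps * batch_size

print(example_count_range(257, 32))   # training run above
print(example_count_range(4, 512))    # test evaluation (reported as 4/4 steps)
```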
from matplotlib import pyplot as plt

def display_training_curves(training, validation, title, subplot):
  if subplot % 10 == 1:  # set up the figure on the first call
    # plt.figure (unlike plt.subplots) creates a figure without axes, so the
    # plt.subplot call below does not overlap an existing axes.
    plt.figure(figsize=(10, 10), facecolor='#F0F0F0')
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model ' + title)
  ax.set_ylabel(title)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])
  plt.tight_layout()

display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)

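Rather than reading the best epoch off the plot, we can query the `History` object directly. A sketch, assuming `history.history` has the `'val_accuracy'` and `'val_loss'` keys used above; `best_epoch` is a hypothetical helper and the short history below is made up. In the run logged above, this picks epoch 32 (val_accuracy 0.7893).

```python
import numpy as np

def best_epoch(history_dict, metric='val_accuracy', mode='max'):
  """Return (1-based epoch, value) of the best score recorded for `metric`."""
  values = np.asarray(history_dict[metric])
  idx = int(np.argmax(values) if mode == 'max' else np.argmin(values))
  return idx + 1, float(values[idx])

# Hypothetical history; in the notebook, pass history.history instead.
toy_history = {'val_accuracy': [0.68, 0.74, 0.77, 0.76],
               'val_loss': [0.78, 0.67, 0.60, 0.61]}
print(best_epoch(toy_history))
print(best_epoch(toy_history, 'val_loss', mode='min'))
```

Keras's `tf.keras.callbacks.EarlyStopping` callback (with `monitor='val_accuracy'` and `restore_best_weights=True`) automates a similar choice during training.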

Evaluate the model

Let's see how the model performs on the test set. Two values are returned: the loss (a number representing the model's error; lower is better) and the accuracy.

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print('%s: %.3f' % (name, value))
4/4 - 0s - loss: 0.5363 - accuracy: 0.7897 - 349ms/epoch - 87ms/step
loss: 0.536
accuracy: 0.790
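Both numbers can be reproduced from raw predictions. This sketch assumes the model was compiled with sparse categorical cross-entropy on logits: the compile call is not shown in this export, but the 3-unit Dense output and integer labels fit that setup. Under that assumption, the loss is the mean negative log-probability of the true class, and the accuracy is the fraction of argmax predictions that match the label. The helpers and logits below are hypothetical.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
  """Mean negative log-softmax probability of the true class."""
  logits = np.asarray(logits, dtype=np.float64)
  z = logits - logits.max(axis=1, keepdims=True)
  log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
  return float(-log_probs[np.arange(len(labels)), labels].mean())

def accuracy(logits, labels):
  """Fraction of rows whose highest logit is the true class."""
  return float((np.argmax(logits, axis=1) == np.asarray(labels)).mean())

# Hypothetical logits for 3 examples and 3 classes.
logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, -1.0]])  # wrong: class 0 wins, true class is 2
labels = np.array([0, 1, 2])
print(sparse_ce_from_logits(logits, labels), accuracy(logits, labels))
```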

The training curves show that the loss drops quickly while the accuracy rises rapidly during the first few epochs, before both level off. Let's look at some examples to check how the predictions relate to the true labels:

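# Readable label helpers built from the dataset metadata
# (`builder` is the SciCite builder created above).
TEXT_FEATURE_NAME, LABEL_NAME = builder.info.supervised_keys
LABEL_NAMES = builder.info.features[LABEL_NAME].names

def label2str(numeric_label):
  return LABEL_NAMES[int(numeric_label)]

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]

pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
})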
1/1 [==============================] - 0s 145ms/step

We can see that for these 20 test examples the model predicts the correct label most of the time, suggesting that the embeddings capture useful information about scientific sentences.
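Counting agreements gives a more precise picture than eyeballing the table, and a confusion matrix also shows which intents get mixed up. A sketch, assuming the inputs are lists of label strings like `prediction_labels` and `predictions` above; `confusion_table` is a hypothetical helper, the label lists are made up, and in the notebook `classes` would be the dataset's label names.

```python
import numpy as np
import pandas as pd

def confusion_table(true_labels, predicted_labels, classes):
  """Rows: true class, columns: predicted class, cells: counts."""
  index = {c: i for i, c in enumerate(classes)}
  counts = np.zeros((len(classes), len(classes)), dtype=int)
  for t, p in zip(true_labels, predicted_labels):
    counts[index[t], index[p]] += 1
  return pd.DataFrame(counts, index=classes, columns=classes)

# Hypothetical labels for illustration.
classes = ['background', 'method', 'result']
true_labels = ['background', 'method', 'result', 'background']
predicted = ['background', 'background', 'result', 'background']
table = confusion_table(true_labels, predicted, classes)
print(table)
print('agreement:', np.trace(table.to_numpy()) / len(true_labels))
```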

What's next?

Now that you've gotten to know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition and help gain scientific insights from COVID-19-related academic texts.