This Colab demonstrates how to:
- Load BERT models from TensorFlow Hub that have been trained on different tasks, including MNLI, SQuAD, and PubMed
- Use a matching preprocessing model to tokenize raw text and convert it to IDs
- Generate the pooled and sequence output from the token input IDs using the loaded model
- Look at the semantic similarity of the pooled outputs of different sentences
Note: this Colab should be run with a GPU runtime.
Setup and imports
!pip install --quiet "tensorflow-text==2.8.*"
import seaborn as sns
from sklearn.metrics import pairwise
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # Imports TF ops for preprocessing.
Configure the model
BERT_MODEL = "https://tfhub.dev/google/experts/bert/wiki_books/2" # @param {type: "string"} ["https://tfhub.dev/google/experts/bert/wiki_books/2", "https://tfhub.dev/google/experts/bert/wiki_books/mnli/2", "https://tfhub.dev/google/experts/bert/wiki_books/qnli/2", "https://tfhub.dev/google/experts/bert/wiki_books/qqp/2", "https://tfhub.dev/google/experts/bert/wiki_books/squad2/2", "https://tfhub.dev/google/experts/bert/wiki_books/sst2/2", "https://tfhub.dev/google/experts/bert/pubmed/2", "https://tfhub.dev/google/experts/bert/pubmed/squad2/2"]
# Preprocessing must match the model, but all the above use the same.
PREPROCESS_MODEL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
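As an aside, the model page for this preprocessing SavedModel documents a tokenize sub-object that splits text into wordpiece IDs without the padding and packing that the full preprocess() call performs. A minimal sketch for peeking at tokenization, assuming that sub-object is present as documented:

# Load the preprocessing model once here just to inspect tokenization;
# the same handle is loaded again in the "Run the model" section below.
peek_preprocess = hub.load(PREPROCESS_MODEL)

# .tokenize returns a RaggedTensor of wordpiece token IDs per sentence.
print(peek_preprocess.tokenize(
    ["A solar eclipse occurs when the Moon passes between Earth and the Sun."]))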
Sentences
Here are some sentences from Wikipedia to run through the model.
sentences = [
"Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
"The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
"Among the singles released from the album were the songs \"Be My Lover\" and \"Hard To Stay Awake\".",
"Riccardo Zegna is an Italian jazz musician.",
"Rajko Maksimović is a composer, writer, and music pedagogue.",
"One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.",
"Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
"A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
"A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]
Run the model
We'll load the BERT model from TF Hub, tokenize our sentences using the matching preprocessing model from TF Hub, then feed the tokenized sentences into the model. To keep this Colab fast and simple, we recommend running on GPU.
Go to Runtime → Change runtime type to make sure that GPU is selected.
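To double-check that TensorFlow can actually see an accelerator, you can list the visible devices (a quick check using the tf import from the setup cell):

# An empty list here means the notebook is running on CPU.
print("GPUs available:", tf.config.list_physical_devices("GPU"))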
preprocess = hub.load(PREPROCESS_MODEL)
bert = hub.load(BERT_MODEL)
inputs = preprocess(sentences)
outputs = bert(inputs)
print("Sentences:")
print(sentences)
print("\nBERT inputs:")
print(inputs)
print("\nPooled embeddings:")
print(outputs["pooled_output"])
print("\nPer token embeddings:")
print(outputs["sequence_output"])
Sentences: ["Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.", 'The album went straight to number one on the Norwegian album chart, and sold to double platinum.', 'Among the singles released from the album were the songs "Be My Lover" and "Hard To Stay Awake".', 'Riccardo Zegna is an Italian jazz musician.', 'Rajko Maksimović is a composer, writer, and music pedagogue.', 'One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.', 'Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum', 'A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.', "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth."] BERT inputs: {'input_word_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy= array([[ 101, 2182, 2057, ..., 0, 0, 0], [ 101, 1996, 2201, ..., 0, 0, 0], [ 101, 2426, 1996, ..., 0, 0, 0], ..., [ 101, 16447, 6714, ..., 0, 0, 0], [ 101, 1037, 5943, ..., 0, 0, 0], [ 101, 1037, 7704, ..., 0, 0, 0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(9, 128), dtype=int32, numpy= array([[1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], ..., [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0], [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy= array([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>} Pooled embeddings: tf.Tensor( [[ 0.7975983 -0.4858047 0.49781665 ... -0.34488207 0.3972758 -0.20639578] [ 0.5712035 -0.41205317 0.70489097 ... -0.35185057 0.19032365 -0.40419084] [-0.6993837 0.1586686 0.06569945 ... -0.06232291 -0.81550217 -0.07923597] ... [-0.3572722 0.77089787 0.15756367 ... 0.44185576 -0.86448324 0.04504809] [ 0.9107701 0.41501644 0.5606342 ... -0.49263844 0.3964056 -0.05036103] [ 0.90502876 -0.15505227 0.72672117 ... -0.34734455 0.50526446 -0.19542967]], shape=(9, 768), dtype=float32) Per token embeddings: tf.Tensor( [[[ 1.0919763e+00 -5.3055435e-01 5.4639924e-01 ... -3.5962319e-01 4.2041004e-01 -2.0940384e-01] [ 1.0143832e+00 7.8078997e-01 8.5375911e-01 ... 5.5282390e-01 -1.1245768e+00 5.6027830e-01] [ 7.8862834e-01 7.7776447e-02 9.5150828e-01 ... -1.9075394e-01 5.9206229e-01 6.1910677e-01] ... [-3.2203096e-01 -4.2521316e-01 -1.2823755e-01 ... -3.9094931e-01 -7.9097426e-01 4.2236397e-01] [-3.1037472e-02 2.3985589e-01 -2.1994336e-01 ... -1.1440081e-01 -1.2680490e+00 -1.6136405e-01] [-4.2063668e-01 5.4972923e-01 -3.2444507e-01 ... -1.8478569e-01 -1.1342961e+00 -5.8976438e-02]] [[ 6.4930725e-01 -4.3808180e-01 8.7695575e-01 ... -3.6755425e-01 1.9267297e-01 -4.2864799e-01] [-1.1248751e+00 2.9931432e-01 1.1799647e+00 ... 4.8729539e-01 5.3400397e-01 2.2836086e-01] [-2.7057484e-01 3.2353774e-02 1.0425684e+00 ... 5.8993781e-01 1.5367906e+00 5.8425695e-01] ... [-1.4762504e+00 1.8239306e-01 5.5877924e-02 ... -1.6733217e+00 -6.7398900e-01 -7.2449714e-01] [-1.5138137e+00 5.8184761e-01 1.6141929e-01 ... -1.2640836e+00 -4.0272185e-01 -9.7197187e-01] [-4.7152787e-01 2.2817361e-01 5.2776086e-01 ... -7.5483733e-01 -9.0903133e-01 -1.6954741e-01]] [[-8.6609292e-01 1.6002062e-01 6.5794230e-02 ... 
-6.2403791e-02 -1.1432397e+00 -7.9402432e-02] [ 7.7118009e-01 7.0804596e-01 1.1350013e-01 ... 7.8830987e-01 -3.1438011e-01 -9.7487241e-01] [-4.4002396e-01 -3.0059844e-01 3.5479474e-01 ... 7.9736769e-02 -4.7393358e-01 -1.1001850e+00] ... [-1.0205296e+00 2.6938295e-01 -4.7310317e-01 ... -6.6319406e-01 -1.4579906e+00 -3.4665293e-01] [-9.7003269e-01 -4.5014530e-02 -5.9779799e-01 ... -3.0526215e-01 -1.2744255e+00 -2.8051612e-01] [-7.3144299e-01 1.7699258e-01 -4.6257949e-01 ... -1.6062324e-01 -1.6346085e+00 -3.2060498e-01]] ... [[-3.7375548e-01 1.0225370e+00 1.5888736e-01 ... 4.7453445e-01 -1.3108220e+00 4.5078602e-02] [-4.1589195e-01 5.0019342e-01 -4.5844358e-01 ... 4.1482633e-01 -6.2065941e-01 -7.1554971e-01] [-1.2504396e+00 5.0936830e-01 -5.7103878e-01 ... 3.5491806e-01 2.4368122e-01 -2.0577202e+00] ... [ 1.3393565e-01 1.1859145e+00 -2.2170596e-01 ... -8.1946641e-01 -1.6737353e+00 -3.9692396e-01] [-3.3662772e-01 1.6556194e+00 -3.7813133e-01 ... -9.6745455e-01 -1.4801090e+00 -8.3330792e-01] [-2.2649661e-01 1.6178432e+00 -6.7044818e-01 ... -4.9078292e-01 -1.4535757e+00 -7.1707249e-01]] [[ 1.5320230e+00 4.4165635e-01 6.3375759e-01 ... -5.3953838e-01 4.1937724e-01 -5.0403673e-02] [ 8.9377761e-01 8.9395475e-01 3.0627429e-02 ... 5.9038877e-02 -2.0649567e-01 -8.4811318e-01] [-1.8558376e-02 1.0479058e+00 -1.3329605e+00 ... -1.3869658e-01 -3.7879506e-01 -4.9068686e-01] ... [ 1.4275625e+00 1.0696868e-01 -4.0634036e-02 ... -3.1777412e-02 -4.1459864e-01 7.0036912e-01] [ 1.1286640e+00 1.4547867e-01 -6.1372513e-01 ... 4.7491822e-01 -3.9852142e-01 4.3124473e-01] [ 1.4393290e+00 1.8030715e-01 -4.2854571e-01 ... -2.5022799e-01 -1.0000539e+00 3.5985443e-01]] [[ 1.4993387e+00 -1.5631306e-01 9.2174339e-01 ... -3.6242083e-01 5.5635023e-01 -1.9797631e-01] [ 1.1110525e+00 3.6651248e-01 3.5505861e-01 ... -5.4297489e-01 1.4471433e-01 -3.1676081e-01] [ 2.4048671e-01 3.8116074e-01 -5.9182751e-01 ... 3.7410957e-01 -5.9829539e-01 -1.0166274e+00] ... [ 1.0158602e+00 5.0260085e-01 1.0736975e-01 ... -9.5642674e-01 -4.1039643e-01 -2.6760373e-01] [ 1.1848910e+00 6.5479511e-01 1.0155141e-03 ... -8.6154616e-01 -8.8041753e-02 -3.0636895e-01] [ 1.2669089e+00 4.7767794e-01 6.6289604e-03 ... -1.1585804e+00 -7.0679039e-02 -1.8678637e-01]]], shape=(9, 128, 768), dtype=float32)
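Besides pooled_output, another common way to get one fixed-size vector per sentence is to mean-pool the sequence_output over the real tokens, using input_mask to exclude padding. A minimal sketch reusing the inputs and outputs computed above:

# Mask of shape (9, 128, 1): 1.0 for real tokens, 0.0 for padding.
mask = tf.cast(inputs["input_mask"], tf.float32)[:, :, tf.newaxis]

# Sum the per-token embeddings over real tokens, then divide by token counts.
summed = tf.reduce_sum(outputs["sequence_output"] * mask, axis=1)
counts = tf.reduce_sum(mask, axis=1)
mean_pooled = summed / counts  # shape (9, 768)
print(mean_pooled.shape)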
Semantic similarity
Now let's take a look at the pooled_output embeddings of our sentences and compare how similar they are to each other.
Helper functions
def plot_similarity(features, labels):
  """Plot a similarity matrix of the embeddings."""
  cos_sim = pairwise.cosine_similarity(features)
  sns.set(font_scale=1.2)
  cbar_kws = dict(use_gridspec=False, location="left")
  g = sns.heatmap(
      cos_sim, xticklabels=labels, yticklabels=labels,
      vmin=0, vmax=1, cmap="Blues", cbar_kws=cbar_kws)
  g.tick_params(labelright=True, labelleft=False)
  g.set_yticklabels(labels, rotation=0)
  g.set_title("Semantic Textual Similarity")
plot_similarity(outputs["pooled_output"], sentences)
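The heatmap gives a qualitative picture; for a quick numeric check, you can also print the single most similar sentence pair directly (a small sketch reusing outputs and sentences from above):

import numpy as np

cos_sim = pairwise.cosine_similarity(outputs["pooled_output"])
np.fill_diagonal(cos_sim, 0.0)  # ignore each sentence's similarity to itself
i, j = np.unravel_index(np.argmax(cos_sim), cos_sim.shape)
print(f"Most similar pair (cosine similarity {cos_sim[i, j]:.3f}):")
print(" -", sentences[i])
print(" -", sentences[j])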
Learn more
- Find more BERT models on TensorFlow Hub
- This notebook demonstrates simple inference with BERT; you can find a more advanced tutorial about fine-tuning BERT at tensorflow.org/official_models/fine_tuning_bert
- We used just one GPU chip to run the model; you can learn more about how to load models using tf.distribute at tensorflow.org/tutorials/distribute/save_and_load