在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 Github 上查看源代码 | {img1下载笔记本 |
word2vec 不是单一算法,而是一系列模型架构和优化,可用于从大型数据集中学习单词嵌入向量。通过 word2vec 学习到的嵌入向量已被证明在各种下游自然语言处理任务上取得了成功。
注:本教程基于 Efficient estimation of word representations in vector space 和 Distributed representations of words and phrases and their compositionality。本教程不是上述论文的精确实现,而旨在阐明关键思想。
上述论文提出了两种学习单词表示的方法:
- 连续词袋模型:根据周围的上下文单词预测中间单词。上下文由当前(中间)单词前后的几个单词组成。这种架构被称为词袋模型,因为上下文中的单词顺序并不重要。
- 连续跳字模型:预测同一句子中当前单词前后一定范围内的单词。下面给出了一个工作示例。
您将在本教程中使用跳字方式。首先,您将使用一个句子来探索跳字和其他概念。接下来,您将在一个小型数据集上训练自己的 word2vec 模型。本教程还包含用于导出经过训练的嵌入向量并在 TensorFlow Embedding Projector 中可视化它们的代码。
跳字和负采样
词袋模型能够在给定相邻上下文的情况下预测单词,而跳字模型能够在给定单词本身的情况下预测单词的上下文(或邻居)。该模型在跳字上训练,它是允许跳过词例的 n 元语法(请参阅下图的示例)。一个单词的上下文可以通过一组 (target_word, context_word)
的跳字对来表示,其中 context_word
出现在 target_word
的相邻上下文中。
考虑以下由八个单词组成的句子:
The wide road shimmered in the hot sun.
这句话的 8 个单词中,每一个单词的上下文单词由一个窗口大小定义。窗口大小决定了 target_word
可以被视为 context word
的单词跨度。下面是基于不同窗口大小的目标词的跳字表。
注:对于本教程,n
的窗口大小表示每边有 n 个单词,每个单词的总窗口跨度为 2*n+1 个单词。
跳字模型的训练目标是在给定目标词的情况下最大化预测上下文词的概率。对于单词序列 w1,w2,... wT,目标可写为平均对数概率
其中,c
是训练上下文的大小。基本的跳字公式使用 Softmax 函数定义该概率。
其中,v 和 v' 是单词的目标和上下文向量表示,W 是词汇量。
计算这个公式的分母涉及对整个词汇表执行完整的 Softmax,其通常是很大的 (105-107) 项。
噪声对比估计 (NCE) 损失函数是完整 Softmax 的有效近似。为了学习单词嵌入向量而不是对单词分布进行建模,NCE 损失可以简化为使用负采样。
目标单词的简化负采样目标是将上下文单词与从单词的噪声分布 Pn(w) 中抽取的 num_ns
负样本区分开来。更准确地说,对于一个跳字对,词汇表上的完整 Softmax 的有效近似是将目标单词的损失作为上下文单词和 num_ns
负样本之间的分类问题。
负样本定义为 (target_word, context_word)
对,这样 context_word
就不会出现在 target_word
的 window_size
邻域中。对于例句,这些是一些潜在的负样本(当 window_size
为 2
时)。
(hot, shimmered)
(wide, hot)
(wide, sun)
在下一部分中,您将为单个句子生成跳字和负样本。您还将在本教程后面学习二次采样技术并为正负训练样本训练分类模型。
设置
import io
import re
import string
import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
# Load the TensorBoard notebook extension
%load_ext tensorboard
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE
向量化一个例句
请考虑以下句子:
The wide road shimmered in the hot sun.
对句子进行分词:
sentence = "The wide road shimmered in the hot sun"
tokens = list(sentence.lower().split())
print(len(tokens))
创建一个词汇表来保存从词例到整数索引的映射:
vocab, index = {}, 1 # start indexing from 1
vocab['<pad>'] = 0 # add a padding token
for token in tokens:
if token not in vocab:
vocab[token] = index
index += 1
vocab_size = len(vocab)
print(vocab)
创建一个反向词汇表来保存从整数索引到词例的映射:
inverse_vocab = {index: token for token, index in vocab.items()}
print(inverse_vocab)
向量化您的句子:
example_sequence = [vocab[word] for word in tokens]
print(example_sequence)
从一个句子生成跳字
tf.keras.preprocessing.sequence
模块提供了有用的函数来简化 word2vec 的数据准备。您可以使用 tf.keras.preprocessing.sequence.skipgrams
从 [0, vocab_size)
范围内的词例中使用给定的 window_size
从 example_sequence
生成跳字对。
注:negative_samples
在这里设置为 0
,因为批处理此函数生成的负样本需要一些代码。在下一部分中,您将使用另一个函数来执行负采样。
window_size = 2
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
example_sequence,
vocabulary_size=vocab_size,
window_size=window_size,
negative_samples=0)
print(len(positive_skip_grams))
打印几个正跳字:
for target, context in positive_skip_grams[:5]:
print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")
对某个跳字进行负采样
skipgrams
函数通过在给定的窗口跨度上滑动来返回所有正的跳字对。要生成额外的跳字对作为训练的负样本,您需要从词汇表中随机抽取单词。使用 tf.random.log_uniform_candidate_sampler
函数在窗口中对给定目标单词采样 num_ns
个负样本。您可以在一个跳字的目标单词上调用该函数,并将上下文单词作为 true 类传递,以将其排除在采样之外。
要点:[5, 20]
范围内的 num_ns
(每个正上下文单词的负样本数)被证明最适合较小的数据集,而 [2, 5]
范围内的 num_ns
足以满足较大的数据集。
# Get target and context words for one positive skip-gram.
target_word, context_word = positive_skip_grams[0]
# Set the number of negative samples per positive context.
num_ns = 4
context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
true_classes=context_class, # class that should be sampled as 'positive'
num_true=1, # each positive skip-gram has 1 positive context class
num_sampled=num_ns, # number of negative context words to sample
unique=True, # all the negative samples should be unique
range_max=vocab_size, # pick index of the samples from [0, vocab_size]
seed=SEED, # seed for reproducibility
name="negative_sampling" # name of this operation
)
print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])
构造一个训练样本
对于给定的正 (target_word, context_word)
跳字,您现在还有 num_ns
个未出现在 target_word
的窗口大小邻域中的负采样上下文单词。将 1
个正 context_word
和 num_ns
个负上下文单词批处理到一个张量中。这会为每个目标单词生成一组正跳字(标记为 1
)和负样本(标记为0
)。
# Add a dimension so you can use concatenation (in the next step).
negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)
# Concatenate a positive context word with negative sampled words.
context = tf.concat([context_class, negative_sampling_candidates], 0)
# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64")
# Reshape the target to shape `(1,)` and context and label to `(num_ns+1,)`.
target = tf.squeeze(target_word)
context = tf.squeeze(context)
label = tf.squeeze(label)
从上面的跳字样本中查看目标单词的上下文和相应的标签:
print(f"target_index : {target}")
print(f"target_word : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label : {label}")
(target, context, label)
张量的元组构成了一个训练样本,用于训练您的跳字负采样 word2vec 模型。请注意,目标的形状为 (1,)
,而上下文和标签的形状为 (1+num_ns,)
print("target :", target)
print("context :", context)
print("label :", label)
总结
此图总结了从句子生成训练样本的过程:
请注意,temperature
和 code
这两个单词不是输入句子的一部分。它们与上图中使用的某些其他索引一样属于词汇表。
将所有步骤编译为一个函数
跳字采样表
大的数据集意味着更大的词汇表,其中包含更多更频繁出现的单词,例如停用词。从经常出现的单词(例如 the
、is
、on
)中采样获得的训练样本不会为模型添加太多可供学习的有用信息。 Mikolov 等人建议对常用词进行二次采样,以此作为提高嵌入向量质量的有用做法。
tf.keras.preprocessing.sequence.skipgrams
函数接受一个采样表参数来编码对任何词例进行采样的概率。您可以使用 tf.keras.preprocessing.sequence.make_sampling_table
生成基于词频等级的概率采样表并将其传递给 skipgrams
函数。检查 vocab_size
为 10 的采样概率。
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=10)
print(sampling_table)
sampling_table[i]
表示对数据集中第 i 个最常见单词进行采样的概率。该函数假设采样的单词频率符合 Zipf 分布。
要点:tf.random.log_uniform_candidate_sampler
已经假设词汇频率遵循对数均匀 (Zipf) 分布。使用这些分布加权采样还有助于使用更简单的损失函数来近似噪声对比估计 (NCE) 损失,以训练负采样目标。
生成训练数据
将上述所有步骤编译成一个函数,该函数可以在从任何文本数据集获得的向量化句子列表上调用。请注意,采样表在对跳字单词对进行采样之前构建。您将在后续部分中使用此函数。
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
# Elements of each training example are appended to these lists.
targets, contexts, labels = [], [], []
# Build the sampling table for `vocab_size` tokens.
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)
# Iterate over all sequences (sentences) in the dataset.
for sequence in tqdm.tqdm(sequences):
# Generate positive skip-gram pairs for a sequence (sentence).
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
sequence,
vocabulary_size=vocab_size,
sampling_table=sampling_table,
window_size=window_size,
negative_samples=0)
# Iterate over each positive skip-gram pair to produce training examples
# with a positive context word and negative samples.
for target_word, context_word in positive_skip_grams:
context_class = tf.expand_dims(
tf.constant([context_word], dtype="int64"), 1)
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
true_classes=context_class,
num_true=1,
num_sampled=num_ns,
unique=True,
range_max=vocab_size,
seed=seed,
name="negative_sampling")
# Build context and label vectors (for one target word)
negative_sampling_candidates = tf.expand_dims(
negative_sampling_candidates, 1)
context = tf.concat([context_class, negative_sampling_candidates], 0)
label = tf.constant([1] + [0]*num_ns, dtype="int64")
# Append each element from the training example to global lists.
targets.append(target_word)
contexts.append(context)
labels.append(label)
return targets, contexts, labels
为 word2vec 准备训练数据
了解如何使用一个基于跳字负采样的 word2vec 模型的句子后,您可以继续从更大的句子列表中生成训练样本!
下载文本语料库
您将在本教程中使用莎士比亚作品的文本文件。要在您自己的数据上运行此代码,请更改以下行。
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
从文件中读取文本并打印前几行:
with open(path_to_file) as f:
lines = f.read().splitlines()
for line in lines[:20]:
print(line)
使用非空行构造一个 tf.data.TextLineDataset
对象以进行后续步骤:
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))
向量化语料库中的句子
您可以使用 TextVectorization
层对语料库中的句子进行向量化。在此文本分类教程中了解有关使用该层的更多信息。请注意,从上面的前几个句子可以看出,文本需要大小写一致,并且需要去除标点符号。为此,请定义一个可在 TextVectorization 层中使用的 custom_standardization
函数。
# Now, create a custom standardization function to lowercase the text and
# remove punctuation.
def custom_standardization(input_data):
lowercase = tf.strings.lower(input_data)
return tf.strings.regex_replace(lowercase,
'[%s]' % re.escape(string.punctuation), '')
# Define the vocabulary size and the number of words in a sequence.
vocab_size = 4096
sequence_length = 10
# Use the `TextVectorization` layer to normalize, split, and map strings to
# integers. Set the `output_sequence_length` length to pad all samples to the
# same length.
vectorize_layer = layers.TextVectorization(
standardize=custom_standardization,
max_tokens=vocab_size,
output_mode='int',
output_sequence_length=sequence_length)
在文本数据集上调用 TextVectorization.adapt
以创建词汇表。
vectorize_layer.adapt(text_ds.batch(1024))
当层的状态适合表示文本语料库后,就可以使用 TextVectorization.get_vocabulary
访问词汇表。此函数返回按频率排序(降序)的所有词汇词例的列表。
# Save the created vocabulary for reference.
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:20])
现在可以使用 vectorize_layer
为 text_ds
(tf.data.Dataset
) 中的每个元素生成向量。应用 Dataset.batch
、Dataset.prefetch
、Dataset.map
和 Dataset.unbatch
。
# Vectorize the data in text_ds.
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()
从数据集中获取序列
现在,您拥有了一个包含整数编码句子的 tf.data.Dataset
。要准备用于训练 word2vec 模型的数据集,请将数据集展平为句子向量序列列表。此步骤是必需的,因为您将遍历数据集中的每个句子以生成正负样本。
注:由于之前定义的 generate_training_data()
使用非 TensorFlow Python/NumPy 函数,因此您也可以使用带有 tf.data.Dataset.map
的 tf.py_function
或 tf.numpy_function
。
sequences = list(text_vector_ds.as_numpy_iterator())
print(len(sequences))
检查来自 sequences
的一些样本:
for seq in sequences[:5]:
print(f"{seq} => {[inverse_vocab[i] for i in seq]}")
从序列生成训练样本
sequences
现在是整型编码句子的列表。只需调用前面定义的 generate_training_data
函数即可为 word2vec 模型生成训练样本。总结一下,该函数会迭代每个序列中的每个单词以收集正负上下文单词。目标、上下文和标签的长度应该相同,代表训练样本的总数。
targets, contexts, labels = generate_training_data(
sequences=sequences,
window_size=2,
num_ns=4,
vocab_size=vocab_size,
seed=SEED)
targets = np.array(targets)
contexts = np.array(contexts)[:,:,0]
labels = np.array(labels)
print('\n')
print(f"targets.shape: {targets.shape}")
print(f"contexts.shape: {contexts.shape}")
print(f"labels.shape: {labels.shape}")
配置数据集以提高性能
要对可能大量的训练样本执行有效的批处理,请使用 tf.data.Dataset
API。在这一步之后,您将拥有一个 (target_word, context_word), (label)
元素的 tf.data.Dataset
对象来训练您的 word2vec 模型!
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)
应用 Dataset.cache
和 Dataset.prefetch
来提高性能:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)
模型和训练
word2vec 模型可以实现为分类器,以区分来自跳字的 true 上下文单词和通过负采样获得的 false 上下文词。您可以在目标单词和上下文单词的嵌入向量之间执行点积乘法,以获得标签的预测,并根据数据集中的 true 标签计算损失函数。
子类化的 word2vec 模型
使用 Keras Subclassing API 定义具有以下层的 word2vec 模型:
target_embedding
:tf.keras.layers.Embedding
层,当单词作为目标单词出现时,它会查找单词的嵌入向量。该层的参数数量为(vocab_size * embedding_dim)
。context_embedding
:另一个tf.keras.layers.Embedding
层,当单词作为上下文单词出现时,它会查找单词的嵌入向量。该层的参数数量与target_embedding
相同,即(vocab_size * embedding_dim)
。dots
:tf.keras.layers.Dot
层,用于计算训练对中目标和上下文嵌入向量的点积。flatten
:tf.keras.layers.Flatten
层,用于将dots
层的结果展平为 logits。
使用子类化模型,您可以定义接受 (target, context)
对的 call()
函数,然后可以将其传递到相应的嵌入向量层。改造 context_embedding
以执行与 target_embedding
的点积并返回展平结果。
要点:target_embedding
和 context_embedding
层也可以共享。您还可以使用两个嵌入向量的串联作为最终的 word2vec 嵌入向量。
class Word2Vec(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim):
super(Word2Vec, self).__init__()
self.target_embedding = layers.Embedding(vocab_size,
embedding_dim,
input_length=1,
name="w2v_embedding")
self.context_embedding = layers.Embedding(vocab_size,
embedding_dim,
input_length=num_ns+1)
def call(self, pair):
target, context = pair
# target: (batch, dummy?) # The dummy axis doesn't exist in TF2.7+
# context: (batch, context)
if len(target.shape) == 2:
target = tf.squeeze(target, axis=1)
# target: (batch,)
word_emb = self.target_embedding(target)
# word_emb: (batch, embed)
context_emb = self.context_embedding(context)
# context_emb: (batch, context, embed)
dots = tf.einsum('be,bce->bc', word_emb, context_emb)
# dots: (batch, context)
return dots
定义损失函数并编译模型
为简单起见,您可以使用 tf.keras.losses.CategoricalCrossEntropy
作为负采样损失的替代方案。如果您想编写自己的自定义损失函数,也可以这样做:
def custom_loss(x_logit, y_true):
return tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=y_true)
是时候构建您的模型了!使用嵌入向量维度 128(您可以尝试不同的值)实例化您的 word2vec 类。使用 tf.keras.optimizers.Adam
优化器编译模型。
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
另外,定义一个回调来记录 TensorBoard 的训练统计信息:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")
在 dataset
上对模型进行几个周期的训练:
word2vec.fit(dataset, epochs=20, callbacks=[tensorboard_callback])
TensorBoard 现在会显示 word2vec 模型的准确率和损失:
#docs_infra: no_execute
%tensorboard --logdir logs
嵌入向量查找和分析
使用 Model.get_layer
和 Layer.get_weights
从模型中获取权重。TextVectorization.get_vocabulary
函数提供词汇来构建元数据文件,每行一个词例。
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()
创建并保存向量和元数据文件:
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')
for index, word in enumerate(vocab):
if index == 0:
continue # skip 0, it's padding.
vec = weights[index]
out_v.write('\t'.join([str(x) for x in vec]) + "\n")
out_m.write(word + "\n")
out_v.close()
out_m.close()
下载 vectors.tsv
和 metadata.tsv
,在 Embedding Projector 中分析得到的嵌入向量:
try:
from google.colab import files
files.download('vectors.tsv')
files.download('metadata.tsv')
except Exception:
pass
后续步骤
本教程向您展示了如何从头开始实现带有负采样的跳字 word2vec 模型并可视化获得的单词嵌入向量。
要了解有关单词向量及其数学表示的更多信息,请参阅这些注释。
要了解有关高级文本处理的更多信息,请阅读理解语言的 Transformer 模型教程。
如果您对预训练的嵌入向量模型感兴趣,您可能还会对探索 TF-Hub CORD-19 Swivel 嵌入向量或多语言通用句子编码器感兴趣。
您可能还想在新数据集上训练模型(TensorFlow Datasets 中有很多可用的数据集)。