In this colab notebook, you can learn how to use the TensorFlow Lite Model Maker library to create a TFLite Searcher model. You can use a text Searcher model to build semantic search or smart reply for your app. This type of model lets you take a text query and search for the most related entries in a text dataset, such as a database of web pages. The model returns a list of the smallest-distance-scoring entries in the dataset, including the metadata you specify, such as the URL, page title, or other text entry identifiers. After building it, you can deploy it on devices (e.g. Android) using the Task Library Searcher API to run inference with just a few lines of code.

This tutorial uses the CNN/DailyMail dataset as an instance to create the TFLite Searcher model. You can try it with your own dataset in the compatible comma-separated value (CSV) input format.
Text search using Scalable Nearest Neighbors
This tutorial uses the publicly available CNN/DailyMail non-anonymized summarization dataset, which is produced from the GitHub repo. This dataset contains over 300k news articles, which makes it a good dataset for building the Searcher model, so that it returns a variety of related news articles during model inference for a text query.

The text Searcher model in this example uses a ScaNN (Scalable Nearest Neighbors) index file that can search for similar items from a predefined database. ScaNN achieves state-of-the-art performance for efficient vector similarity search at scale.
Highlights and urls from this dataset are used in this colab to create the model:

- Highlights are the text for generating the embedding feature vectors and are then used for search.
- Urls are the returned results shown to users after searching the related highlights.

This tutorial saves the data into a CSV file and then uses the CSV file to build the model. Here are several examples from the dataset.
Highlights | Urls |
---|---|
Hawaiian Airlines again lands at No. 1 in on-time performance. The Airline Quality Rankings Report looks at the 14 largest U.S. airlines. ExpressJet and American Airlines had the worst on-time performance. Virgin America had the best baggage handling; Southwest had lowest complaint rate. | http://www.cnn.com/2013/04/08/travel/airline-quality-report |
European football's governing body reveals list of countries bidding to host 2020 finals. The 60th anniversary edition of the finals will be hosted by 13 countries. Thirty-two countries are considering bids to host 2020 matches. UEFA will announce host cities on September 25. | http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html? |
Once octopus-hunter Dylan Mayer has now also signed a petition of 5,000 divers banning their hunt at Seacrest Park. Decision by Washington Department of Fish and Wildlife could take months. | http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html? |
Galaxy was observed 420 million years after the Big Bang. Found by NASA's Hubble Space Telescope, Spitzer Space Telescope, and one of nature's own natural 'zoom lenses' in space. | http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html |
Setup

Start by installing the required packages, including the Model Maker package from the GitHub repo.
sudo apt -y install libportaudio2
pip install -q tflite-model-maker
pip install gdown
Import the required packages.
from tflite_model_maker import searcher
Prepare the dataset

This tutorial uses the CNN / Daily Mail summarization dataset from the GitHub repo.

First, download the text and urls of CNN and Daily Mail and unzip them. If it fails to download from Google Drive, please wait a few minutes and try again, or download it manually and then upload it to the colab.
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs
wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
tar xzf cnn_stories.tgz
tar xzf dailymail_stories.tgz
Downloading... From: https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ To: /tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/cnn_stories.tgz 100%|█████████████████████████████████████████| 159M/159M [00:00<00:00, 167MB/s] Downloading... From: https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs To: /tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/dailymail_stories.tgz 100%|█████████████████████████████████████████| 376M/376M [00:03<00:00, 121MB/s] --2022-08-31 00:10:12-- https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 46424688 (44M) [text/plain] Saving to: ‘all_train.txt’ all_train.txt 100%[===================>] 44.27M 202MB/s in 0.2s 2022-08-31 00:10:13 (202 MB/s) - ‘all_train.txt’ saved [46424688/46424688]
Then, save the data into a CSV file that can be loaded into the `tflite_model_maker` library. The code is based on the logic used to load this data in `tensorflow_datasets`. We can't use `tensorflow_datasets` directly since it doesn't contain the urls which are used in this colab.

Since it takes a long time to process the data into embedding feature vectors for the whole dataset, only the first 5% of the CNN and Daily Mail stories are selected by default for demo purposes. You can adjust the fraction, or try with the pre-built TFLite model, which contains 50% of the CNN and Daily Mail dataset, to search as well.
Save the highlights and urls to the CSV file
CNN_FRACTION = 0.05
DAILYMAIL_FRACTION = 0.05
import csv
import hashlib
import os
import tensorflow as tf
dm_single_close_quote = u"\u2019" # unicode
dm_double_close_quote = u"\u201d"
END_TOKENS = [
".", "!", "?", "...", "'", "`", '"', dm_single_close_quote,
dm_double_close_quote, ")"
] # acceptable ways to end a sentence
def read_file(file_path):
"""Reads lines in the file."""
lines = []
with tf.io.gfile.GFile(file_path, "r") as f:
for line in f:
lines.append(line.strip())
return lines
def url_hash(url):
"""Gets the hash value of the url."""
h = hashlib.sha1()
url = url.encode("utf-8")
h.update(url)
return h.hexdigest()
def get_url_hashes_dict(urls_path):
"""Gets hashes dict that maps the hash value to the original url in file."""
urls = read_file(urls_path)
return {url_hash(url): url[url.find("id_/") + 4:] for url in urls}
def find_files(folder, url_dict):
"""Finds files corresponding to the urls in the folder."""
all_files = tf.io.gfile.listdir(folder)
ret_files = []
for file in all_files:
# Gets the file name without extension.
filename = os.path.splitext(os.path.basename(file))[0]
if filename in url_dict:
ret_files.append(os.path.join(folder, file))
return ret_files
def fix_missing_period(line):
"""Adds a period to a line that is missing a period."""
if "@highlight" in line:
return line
if not line:
return line
if line[-1] in END_TOKENS:
return line
return line + "."
def get_highlights(story_file):
"""Gets highlights from a story file path."""
lines = read_file(story_file)
# Put periods on the ends of lines that are missing them
# (this is a problem in the dataset because many image captions don't end in
# periods; consequently they end up in the body of the article as run-on
# sentences)
lines = [fix_missing_period(line) for line in lines]
# Separate out article and abstract sentences
highlight_list = []
next_is_highlight = False
for line in lines:
if not line:
continue # empty line
elif line.startswith("@highlight"):
next_is_highlight = True
elif next_is_highlight:
highlight_list.append(line)
# Make highlights into a single string.
highlights = "\n".join(highlight_list)
return highlights
url_hashes_dict = get_url_hashes_dict("all_train.txt")
cnn_files = find_files("cnn/stories", url_hashes_dict)
dailymail_files = find_files("dailymail/stories", url_hashes_dict)
# The size to be selected.
cnn_size = int(CNN_FRACTION * len(cnn_files))
dailymail_size = int(DAILYMAIL_FRACTION * len(dailymail_files))
print("CNN size: %d"%cnn_size)
print("Daily Mail size: %d"%dailymail_size)
with open("cnn_dailymail.csv", "w") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=["highlights", "urls"])
writer.writeheader()
for file in cnn_files[:cnn_size] + dailymail_files[:dailymail_size]:
highlights = get_highlights(file)
# Gets the filename which is the hash value of the url.
filename = os.path.splitext(os.path.basename(file))[0]
url = url_hashes_dict[filename]
writer.writerow({"highlights": highlights, "urls": url})
CNN size: 4513 Daily Mail size: 9848
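As an optional sanity check, the snippet below (a minimal sketch, not part of the original pipeline) previews the first rows of the generated cnn_dailymail.csv to confirm the "highlights" and "urls" columns look as expected.

import csv

# Minimal sketch: preview the first two rows of the generated CSV to verify
# that the "highlights" and "urls" columns were written correctly.
with open("cnn_dailymail.csv") as f:
  reader = csv.DictReader(f)
  for i, row in enumerate(reader):
    if i >= 2:
      break
    print(row["urls"])
    print(row["highlights"][:100], "...")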
Build the text Searcher model

Create the text Searcher model by loading a dataset, creating a model with the data, and exporting the TFLite model.

Step 1. Load the dataset

Model Maker takes the text dataset and the corresponding metadata of each text string (such as the urls in this example) in the CSV format. It embeds the text strings into feature vectors using the user-specified embedder model.

In this demo, we build the Searcher model using the Universal Sentence Encoder, a state-of-the-art sentence embedding model which has already been retrained from this colab. The model is optimized for on-device inference performance, and only takes 6ms to embed a query string (measured on Pixel 6). Alternatively, you can use the quantized version, which is smaller but takes 38ms for each embedding.
wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite
--2022-08-31 00:10:31-- https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.146.128, 209.85.147.128, 142.250.125.128, ... Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.146.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 28691620 (27M) [application/octet-stream] Saving to: ‘universal_sentence_encoder.tflite’ universal_sentence_ 100%[===================>] 27.36M 124MB/s in 0.2s 2022-08-31 00:10:32 (124 MB/s) - ‘universal_sentence_encoder.tflite’ saved [28691620/28691620]
Create a `searcher.TextDataLoader` instance and use the `data_loader.load_from_csv` method to load the dataset. This step takes about 10 minutes since it generates the embedding feature vector for each text one by one. You can try to upload and load your own CSV file as well to build a customized model.

Specify the name of the text column and the metadata column in the CSV file.

- Text is used to generate the embedding feature vectors.
- Metadata is the content to be shown when you search for a certain text.

Here are the first 4 rows of the CNN-DailyMail CSV file generated above.
Highlights | Urls |
---|---|
Syrian official: Obama climbed to the top of the tree, doesn't know how to get down. Obama sends a letter to the heads of the House and Senate. Obama to seek congressional approval on military action against Syria. Aim is to determine whether CW were used, not by whom, says U.N. spokesman. | http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/ |
Usain Bolt wins third gold of world championship. Anchors Jamaica to 4x100m relay victory. Eighth gold at the championships for Bolt. Jamaica double up in women's 4x100m relay. | http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold |
The employee in agency's Kansas City office is among hundreds of "virtual" workers. The employee's travel to and from the mainland U.S. last year cost more than $24,000. The telecommuting program, like all GSA practices, is under review. | http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking |
NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010. NEW: Diagnosis: "autism, severe anxiety, post-traumatic stress disorder and depression" Burkhart is also suspected in a German arson probe, officials say. Prosecutors believe the German national set a string of fires in Los Angeles. | http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html? |
data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
For image use cases, you can create a `searcher.ImageDataLoader` instance and then use `data_loader.load_from_folder` to load images from a folder. The `searcher.ImageDataLoader` instance needs to be created by a TFLite embedder model because it will be leveraged to encode queries to feature vectors and will be exported together with the TFLite Searcher model. For instance:
data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
data_loader.load_from_folder("food/")
Step 2. Create the Searcher model
scann_options = searcher.ScaNNOptions(
distance_measure="dot_product",
tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
model = searcher.Searcher.create_from_data(data_loader, scann_options)
[libprotobuf WARNING external/com_google_protobuf/src/google/protobuf/text_format.cc:339] Warning parsing text-format research_scann.ScannConfig: 38:5: text format contains deprecated field "min_cluster_size"
In the above example, we define the following options:

- `distance_measure`: we use "dot_product" to measure the distance between two embedding vectors. Note that we actually compute the negative dot product value to preserve the notion that "smaller is closer".
- `tree`: the dataset is divided into 140 partitions (roughly the square root of the data size; see the sketch after this list), and 4 of them are searched during retrieval, which is roughly 3% of the dataset.
- `score_ah`: we quantize the float embeddings to int8 values with the same dimension to save space.
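As a rough illustration of how the tree option above relates to the data, the following minimal sketch (using the CNN and Daily Mail sizes printed earlier; not part of the original tutorial) derives a partition count near the square root of the number of indexed items:

import math

# Minimal sketch: a num_leaves value near the square root of the number of
# indexed items (4513 CNN + 9848 Daily Mail stories from above).
num_items = 4513 + 9848
suggested_num_leaves = round(math.sqrt(num_items))
print(suggested_num_leaves)  # ~120, in the same ballpark as the 140 used above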
Step 3. Export the TFLite model

Then you can export the TFLite Searcher model.
model.export(
export_filename="searcher.tflite",
userinfo="",
export_format=searcher.ExportFormat.TFLITE)
Test the TFLite model on your query

You can test the exported TFLite model using custom query text. To query text using the Searcher model, initialize the model and run a search with a text phrase, as follows:
from tflite_support.task import text
# Initializes a TextSearcher object.
searcher = text.TextSearcher.create_from_file("searcher.tflite")
# Searches the input query.
results = searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
print(results)
SearchResult(nearest_neighbors=[NearestNeighbor(metadata=bytearray(b'http://www.dailymail.co.uk/news/article-2599185/U-S-airlines-post-best-ratings-flights-late-bags-mishandled.html'), distance=-0.8651008605957031), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com/2014/05/05/travel/airline-baggage-change-fees/'), distance=-0.8364017009735107), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com:80/2012/10/12/travel/american-airlines-flight-cuts/?'), distance=-0.795403242111206), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com/2013/04/19/travel/faa-furloughs'), distance=-0.7902786731719971), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com:80/2014/01/09/travel/safest-airline-2013'), distance=-0.768242359161377)]) INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
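As shown in the SearchResult output above, each nearest neighbor exposes a metadata field (the url bytes) and a distance score, so you can format the results yourself. A minimal sketch, reusing the results object from the previous cell:

# Minimal sketch: print each returned url with its (negative dot product)
# distance, using the fields shown in the SearchResult output above.
for neighbor in results.nearest_neighbors:
  print("%.4f  %s" % (neighbor.distance, neighbor.metadata.decode("utf-8")))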
See the Task Library documentation for more information about how to integrate the model into various platforms.

Read more

For more information, please refer to:

- Task Library: TextSearcher for deployment.
- End-to-end reference app: Android.