使用 TensorFlow Lite Model Maker 的文本搜索器

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本 查看 TF Hub 模型

在此 CoLab 笔记本中,您可以学习如何使用 TensorFlow Lite Model Maker 库来创建 TFLite Searcher 模型。您可以使用文本 Searcher 模型为您的应用构建语义搜索或智能回复。这种类型的模型允许您进行文本查询,并在文本数据集(例如网页数据库)中搜索最相关的条目。该模型会返回数据集中最小距离得分条目的列表,包括您指定的元数据,如网址、页面标题或其他文本条目标识符。构建后,您可以使用 Task Library Searcher API 将其部署到设备(例如 Android)上,只需几行代码即可运行推断。

本教程利用 CNN/DailyMail 数据集作为实例来创建 TFLite Searcher 模型。您可以尝试使用兼容的输入逗号分隔值 (CSV) 格式的您自己的数据集。

使用可扩缩最近邻的文本搜索

本教程使用公开提供的 CNN/DailyMail 非匿名摘要数据集,该数据集从 GitHub 仓库生成。该数据集包含超过 30 万篇新闻文章,这使得它成为用于构建 Searcher 模型很好的数据集,并且会在模型推断过程中返回各种相关新闻进行文本查询。

本示例中的文本 Searcher 模型使用了一个 ScaNN(可扩缩最近邻居)索引文件,该文件可以从预定义的数据库中搜索相似的项目。ScaNN 实现了最先进的性能,实现了大规模高效的矢量相似度搜索。

此 CoLab 使用此数据集中的突出显示内容和网址来创建模型:

  1. 突出显示的是用于生成嵌入特征向量并随后用于搜索的文本。
  2. 网址是搜索相关突出显示内容后返回给用户的结果。

本教程会将这些数据保存到 CSV 文件中,然后使用 CSV 文件构建模型。以下是数据集中的几个示例。

突出显示 网址
Hawaiian Airlines again lands at No. 1 in on-time performance. The Airline Quality Rankings Report looks at the 14 largest U.S. airlines. ExpressJet
and American Airlines had the worst on-time performance. Virgin America had the best baggage handling; Southwest had lowest complaint rate.
http://www.cnn.com/2013/04/08/travel/airline-quality-report
European football's governing body reveals list of countries bidding to host 2020 finals. The 60th anniversary edition of the finals will be hosted by 13
countries. Thirty-two countries are considering bids to host 2020 matches. UEFA will announce host cities on September 25.
http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html?
Once octopus-hunter Dylan Mayer has now also signed a petition of 5,000 divers banning their hunt at Seacrest Park. Decision by Washington
Department of Fish and Wildlife could take months.
http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html?
Galaxy was observed 420 million years after the Big Bang. found by NASA’s Hubble Space Telescope, Spitzer Space Telescope, and one of nature’s
own natural 'zoom lenses' in space.
http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html

安装

首先安装所需的软件包,包括来自 GitHub 仓库的 Model Maker 软件包。

sudo apt -y install libportaudio2
pip install -q tflite-model-maker
pip install gdown

导入所需的软件包。

from tflite_model_maker import searcher

准备数据集

本教程使用来自 GitHub 仓库的数据集 CNN/Daily Mail 摘要数据集。

首先,下载 CNN 和 Daily Mail 的文本和网址并解压缩。如果无法从 Google Drive 下载,请等待几分钟重试,或手动下载,然后上传到 CoLab。

gdown https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs

wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
tar xzf cnn_stories.tgz
tar xzf dailymail_stories.tgz
Downloading...
From: https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
To: /tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/cnn_stories.tgz
100%|█████████████████████████████████████████| 159M/159M [00:00<00:00, 167MB/s]
Downloading...
From: https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs
To: /tmpfs/src/temp/site/zh-cn/lite/models/modify/model_maker/dailymail_stories.tgz
100%|█████████████████████████████████████████| 376M/376M [00:03<00:00, 121MB/s]
--2022-08-31 00:10:12--  https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46424688 (44M) [text/plain]
Saving to: ‘all_train.txt’

all_train.txt       100%[===================>]  44.27M   202MB/s    in 0.2s    

2022-08-31 00:10:13 (202 MB/s) - ‘all_train.txt’ saved [46424688/46424688]

然后,将数据保存到 CSV 文件中,该文件可以加载到 tflite_model_maker 库中。代码基于用于在 tensorflow_datasets 中加载此数据的逻辑。我们不能直接使用 tensorflow_dataset,因为它不包含此 CoLab 中使用的网址。

因为将数据处理成嵌入整个数据集的特征向量需要很长时间。默认情况下,仅选择 CNN 和 Daily Mail 数据集的前 5% 的新闻用于演示目的。您可以调整分数或尝试使用预构建的 TFLite 模型,其中包含 50% 的 CNN 和 Daily Mail 数据集以供搜索。

Save the highlights and urls to the CSV file

CNN size: 4513
Daily Mail size: 9848

构建文本 Searcher 模型

通过加载数据集、使用数据创建模型并导出 TFLite 模型来创建文本 Searcher 模型。

第 1 步:加载数据集

Model Maker 获取 CSV 格式的文本数据集和每个文本字符串的相应元数据(如本例中的网址)。它使用用户指定的嵌入器模型将文本字符串嵌入到特征向量中。

在本演示中,我们使用通用句子编码器构建 Searcher 模型,这是一种最先进的句子嵌入向量模型,它已经从 CoLab 重新训练。该模型针对设备上的推断性能进行了优化,嵌入一个查询字符串只需 6ms(在 Pixel 6 上测得)。或者,您也可以使用这个量化版本,该版本更小,但每次嵌入需要 38ms。

wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite
--2022-08-31 00:10:31--  https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite
Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.146.128, 209.85.147.128, 142.250.125.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.146.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 28691620 (27M) [application/octet-stream]
Saving to: ‘universal_sentence_encoder.tflite’

universal_sentence_ 100%[===================>]  27.36M   124MB/s    in 0.2s    

2022-08-31 00:10:32 (124 MB/s) - ‘universal_sentence_encoder.tflite’ saved [28691620/28691620]

创建一个 searcher.TextDataLoader 实例并使用 data_loader.load_from_csv 方法加载数据集。该步骤需要约 10 分钟,因为它会逐个为每个文本生成嵌入特征向量。您可以尝试上传并加载您自己的 CSV 文件,以构建自定义模型。

指定 CSV 文件中的文本列和元数据列的名称。

  • 利用文本生成嵌入特征向量。
  • 元数据是搜索特定文本时要显示的内容。

以下是上面生成的 CNN-DailyMail CSV 文件的前 4 行。

突出显示 网址
Syrian official: Obama climbed to the top of the tree, doesn't know how to get down. Obama sends a letter to the heads of the House and Senate. Obama
to seek congressional approval on military action against Syria. Aim is to determine whether CW were used, not by whom, says U.N. spokesman.
http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/
Usain Bolt wins third gold of world championship. Anchors Jamaica to 4x100m relay victory. Eighth gold at the championships for Bolt. Jamaica double
up in women's 4x100m relay.
http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold
The employee in agency's Kansas City office is among hundreds of "virtual" workers. The employee's travel to and from the mainland U.S. last year cost
more than $24,000. The telecommuting program, like all GSA practices, is under review.
http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking
NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010. NEW: Diagnosis: "autism, severe anxiety, post-traumatic stress
disorder and depression" Burkhart is also suspected in a German arson probe, officials say. Prosecutors believe the German national set a string of fires
in Los Angeles.
http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html?
data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.

对于图片用例,您可以创建一个 searcher.ImageDataLoader 实例,然后使用 data_loader.load_from_folder 从该文件夹加载图片。searcher.ImageDataLoader 实例需要由 TFLite 嵌入器模型创建,因为它将用于将查询编码为特征向量,并与 TFLite Searcher 模型一起导出。例如:

data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
data_loader.load_from_folder("food/")

第 2 步:创建 Searcher 模型

  • 配置扫描选项。请参阅 API 文档了解详细信息。
  • 从数据和 ScaNN 选项创建 Searcher 模型。您可以参阅深度检查,了解有关 ScaNN 算法的更多信息。
scann_options = searcher.ScaNNOptions(
      distance_measure="dot_product",
      tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
      score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
model = searcher.Searcher.create_from_data(data_loader, scann_options)
[libprotobuf WARNING external/com_google_protobuf/src/google/protobuf/text_format.cc:339] Warning parsing text-format research_scann.ScannConfig: 38:5: text format contains deprecated field "min_cluster_size"

在上面的示例中,我们定义了以下选项:

  • distance_measure:我们使用 "dot_product" 来衡量两个嵌入向量之间的距离。请注意,我们实际上计算的是点积值,以保持“越小越近”的概念。

  • tree:数据集被划分为 140 个分区(大致是数据大小的平方根),在检索过程中会搜索其中 4 个分区,约占数据集的 3%。

  • score_ah:为了节省空间,我们将浮点嵌入向量量化为相同维度的 int8 值。

第 3 步:导出 TFLite 模型

然后,您可以导出 TFLite Searcher 模型。

model.export(
      export_filename="searcher.tflite",
      userinfo="",
      export_format=searcher.ExportFormat.TFLITE)

在查询中测试 TFLite 模型

您可以使用自定义查询文本测试导出的 TFLite 模型。要使用 Searcher 模型查询文本,请初始化该模型并使用文本短语运行搜索,如下所示:

from tflite_support.task import text

# Initializes a TextSearcher object.
searcher = text.TextSearcher.create_from_file("searcher.tflite")

# Searches the input query.
results = searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
print(results)
SearchResult(nearest_neighbors=[NearestNeighbor(metadata=bytearray(b'http://www.dailymail.co.uk/news/article-2599185/U-S-airlines-post-best-ratings-flights-late-bags-mishandled.html'), distance=-0.8651008605957031), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com/2014/05/05/travel/airline-baggage-change-fees/'), distance=-0.8364017009735107), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com:80/2012/10/12/travel/american-airlines-flight-cuts/?'), distance=-0.795403242111206), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com/2013/04/19/travel/faa-furloughs'), distance=-0.7902786731719971), NearestNeighbor(metadata=bytearray(b'http://www.cnn.com:80/2014/01/09/travel/safest-airline-2013'), distance=-0.768242359161377)])
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.

有关如何将模型集成到各种平台的详细信息,请参阅 Task Library 文档

阅读更多

有关更多信息,请参阅: