TensorFlow Lite Metadata Writer API

在 TensorFlow.org 上查看 在 Google Colab 运行 在 GitHub 上查看源代码 下载笔记本

TensorFlow Lite Model Metadata 是标准的模型描述格式。它包含了丰富的通用模型信息、输入/输出和相关文件的语义,这使得模型具有自描述性和可交换性。

Model Metadata 模型元数据目前用于以下两个主要用例:

  1. 使用 TensorFlow Lite Task LibraryCodegen 工具启用简单模型推断。Model Metadata 包含推断过程中必需的信息,如图像分类中的标签文件、音频分类中音频输入的采样率以及自然语言模型中用于处理输入字符串的标记器类型。

  2. 使模型创建者能够包括文档,例如模型输入/输出的说明或模型使用方式。模型用户可以通过可视化工具(如 Netron)查看这些文档。

TensorFlow Lite Metadata Writer API 提供了一个易用的 API 来为 TFLite Task Library 支持的常用机器学习任务创建模型元数据。本笔记本展示了应如何为以下任务填充元数据的示例:

适用于 BERT 自然语言分类器和 BERT 问答器的元数据编写器即将推出。

如果要为不受支持的用例添加元数据,请使用 Flatbuffers Python API。请参阅此处的教程。

先决条件

安装 TensorFlow Lite Support Pypi 软件包。

pip install tflite-support-nightly

为 Task Library 和 Codegen 创建模型元数据

图像分类器

有关支持的模型格式的更多详细信息,请参阅图像分类器模型兼容性要求

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

第 2 步:下载图像分类器示例,mobilenet_v2_1.0_224.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt

第 3 步:创建元数据编写器并填充。

ImageClassifierWriter = image_classifier.MetadataWriter
_MODEL_PATH = "mobilenet_v2_1.0_224.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "mobilenet_labels.txt"
_SAVE_TO_PATH = "mobilenet_v2_1.0_224_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflow.google.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

目标检测器

有关支持的模型格式的更多详细信息,请参阅目标检测器模型兼容性要求

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

第 2 步:下载示例目标检测器 ssd_mobilenet_v1.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite -o ssd_mobilenet_v1.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt -o ssd_mobilenet_labels.txt

第 3 步:创建元数据编写器并填充。

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "ssd_mobilenet_labels.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflow.google.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

图像分割器

有关支持的模型格式的更多详细信息,请参阅图像分割器模型兼容性要求

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import image_segmenter
from tflite_support.metadata_writers import writer_utils

第 2 步:下载示例图像分割器 deeplabv3.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite -o deeplabv3.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt -o deeplabv3_labels.txt

第 3 步:创建元数据编写器并填充。

ImageSegmenterWriter = image_segmenter.MetadataWriter
_MODEL_PATH = "deeplabv3.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "deeplabv3_labels.txt"
_SAVE_TO_PATH = "deeplabv3_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflow.google.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageSegmenterWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

###自然语言分类器

有关支持的模型格式的更多详细信息,请参阅自然语言分类器模型兼容性要求

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import nl_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

第 2 步:下载示例自然语言分类器 movie_review.tflite标签文件词汇文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite -o movie_review.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt -o movie_review_labels.txt
curl -L https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/nl_classifier/vocab.txt -o movie_review_vocab.txt

第 3 步:创建元数据编写器并填充。

NLClassifierWriter = nl_classifier.MetadataWriter
_MODEL_PATH = "movie_review.tflite"
# Task Library expects label files and vocab files that are in the same formats
# as the ones below.
_LABEL_FILE = "movie_review_labels.txt"
_VOCAB_FILE = "movie_review_vocab.txt"
# NLClassifier supports tokenize input string using the regex tokenizer. See
# more details about how to set up RegexTokenizer below:
# https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py#L130
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_SAVE_TO_PATH = "moview_review_metadata.tflite"

# Create the metadata writer.
writer = nl_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH),
    metadata_info.RegexTokenizerMd(_DELIM_REGEX_PATTERN, _VOCAB_FILE),
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

音频分类器

有关支持的模型格式的更多详细信息,请参阅音频分类器模型兼容性要求

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

第 2 步:下载示例音频分类器 yamnet.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite -o yamnet.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt -o yamnet_labels.txt

第 3 步:创建元数据编写器并填充。

AudioClassifierWriter = audio_classifier.MetadataWriter
_MODEL_PATH = "yamnet.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "yamnet_labels.txt"
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note, Task library only
# support single channel so far.
_CHANNELS = 1
_SAVE_TO_PATH = "yamnet_metadata.tflite"

# Create the metadata writer.
writer = AudioClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

创建具有语义信息的模型元数据

您可以通过 Metadata Writer API 填写有关模型和每个张量的更多描述性信息,以帮助提高对模型的理解。它可以通过每个元数据编写器中的 'create_from_metadata_info' 方法来完成。通常,您可以通过 'create_from_metadata_info' 的参数填写数据,即 general_mdinput_mdoutput_md。请参阅下面的示例,为图像分类器创建丰富的模型元数据。

第 1 步:导入所需的软件包。

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata_schema_py_generated as _metadata_fb

第 2 步:下载图像分类器示例,mobilenet_v2_1.0_224.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt

第 3 步:创建模型和张量信息。

model_buffer = writer_utils.load_file("mobilenet_v2_1.0_224.tflite")

# Create general model information.
general_md = metadata_info.GeneralMd(
    name="ImageClassifier",
    version="v1",
    description=("Identify the most prominent object in the image from a "
                 "known set of categories."),
    author="TensorFlow Lite",
    licenses="Apache License. Version 2.0")

# Create input tensor information.
input_md = metadata_info.InputImageTensorMd(
    name="input image",
    description=("Input image to be classified. The expected image is "
                 "128 x 128, with three channels (red, blue, and green) per "
                 "pixel. Each element in the tensor is a value between min and "
                 "max, where (per-channel) min is [0] and max is [255]."),
    norm_mean=[127.5],
    norm_std=[127.5],
    color_space_type=_metadata_fb.ColorSpaceType.RGB,
    tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

# Create output tensor information.
output_md = metadata_info.ClassificationTensorMd(
    name="probability",
    description="Probabilities of the 1001 labels respectively.",
    label_files=[
        metadata_info.LabelFileMd(file_path="mobilenet_labels.txt",
                                  locale="en")
    ],
    tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

第 4 步:创建元数据编写器并填充。

ImageClassifierWriter = image_classifier.MetadataWriter
# Create the metadata writer.
writer = ImageClassifierWriter.create_from_metadata_info(
    model_buffer, general_md, input_md, output_md)

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

读取填充到模型中的元数据

您可以通过以下代码在 TFLite 模型中显示元数据和关联文件:

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("mobilenet_v2_1.0_224_metadata.tflite")
print("Metadata populated:")
print(displayer.get_metadata_json())

print("Associated file(s) populated:")
for file_name in displayer.get_packed_associated_file_list():
  print("file name: ", file_name)
  print("file content:")
  print(displayer.get_associated_file_buffer(file_name))