Signatures in TensorFlow Lite


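TensorFlow Lite supports converting a TensorFlow model's input/output specifications to TensorFlow Lite models. These input/output specifications are called "signatures". Signatures can be specified when building a SavedModel or when creating concrete functions.

Signatures in TensorFlow Lite provide the following features:

  • They specify the inputs and outputs of the converted TensorFlow Lite model, based on the TensorFlow model's signatures.
  • They allow a single TensorFlow Lite model to support multiple entry points.

A signature is composed of three pieces:

  • Inputs: a map from input names in the signature to input tensors.
  • Outputs: a map from output names in the signature to output tensors.
  • Signature key: a name that identifies an entry point of the graph.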
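
The three pieces can be pictured as a small record. The sketch below uses plain Python data rather than any TensorFlow Lite type; the `Signature` class and the tensor names are hypothetical and exist only to illustrate the structure.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Signature:
    """Illustrative stand-in for one signature (not a TensorFlow Lite class)."""
    key: str                 # name of the graph entry point
    inputs: Dict[str, str]   # signature input name -> tensor name
    outputs: Dict[str, str]  # signature output name -> tensor name


# Hypothetical tensor names; a real model assigns its own.
encode = Signature(
    key="encode",
    inputs={"x": "encode_x:0"},
    outputs={"encoded_result": "encode_output:0"},
)
decode = Signature(
    key="decode",
    inputs={"x": "decode_x:0"},
    outputs={"decoded_result": "decode_output:0"},
)

# One model can hold several entry points, looked up by signature key.
entry_points = {sig.key: sig for sig in (encode, decode)}
print(sorted(entry_points))
```

Keying the entry points by signature key is what lets one model expose several independent computations.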

Setup

import tensorflow as tf

Example model

Suppose we have two tasks, for example encoding and decoding, implemented as a TensorFlow model:

class Model(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def encode(self, x):
    result = tf.strings.as_string(x)
    return {
         "encoded_result": result
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def decode(self, x):
    result = tf.strings.to_number(x)
    return {
         "decoded_result": result
    }

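In terms of signatures, the TensorFlow model above can be summarized as follows:

  • Signature

    • Key: encode
    • Inputs: {"x"}
    • Outputs: {"encoded_result"}
  • Signature

    • Key: decode
    • Inputs: {"x"}
    • Outputs: {"decoded_result"}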
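
A summary like the one above can be produced from the dictionary that `get_signature_list()` prints later in this guide. The following is a minimal sketch in plain Python; `summarize_signatures` is a hypothetical helper, and the dictionary is typed in by hand so the snippet runs without a model.

```python
def summarize_signatures(signatures):
    """Render a signature dict as indented text lines, sorted by key."""
    lines = []
    for key in sorted(signatures):
        spec = signatures[key]
        lines.append(f"Signature key: {key}")
        lines.append(f"  inputs:  {', '.join(spec['inputs'])}")
        lines.append(f"  outputs: {', '.join(spec['outputs'])}")
    return lines


# Dictionary in the shape printed by get_signature_list() for this model.
signatures = {
    "encode": {"inputs": ["x"], "outputs": ["encoded_result"]},
    "decode": {"inputs": ["x"], "outputs": ["decoded_result"]},
}
for line in summarize_signatures(signatures):
    print(line)
```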

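Convert a model with signatures

The TensorFlow Lite converter APIs carry the signature information above into the converted TensorFlow Lite model.

This conversion functionality is available in all converter APIs starting from TensorFlow version 2.7.0. See the example usages below.

From a SavedModel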

model = Model()

# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'

tf.saved_model.save(
    model, SAVED_MODEL_PATH,
    signatures={
      'encode': model.encode.get_concrete_function(),
      'decode': model.decode.get_concrete_function()
    })

# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: content/saved_models/coding/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
2023-11-07 21:43:45.985471: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2921] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexAsString, FlexStringToNumber
Details:
    tf.AsString(tensor<?xf32>) -> (tensor<?x!tf_type.string>) : {device = "", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}
    tf.StringToNumber(tensor<?x!tf_type.string>) -> (tensor<?xf32>) : {device = "", out_type = f32}
See instructions: https://www.tensorflow.org/lite/guide/ops_select
INFO: Created TensorFlow Lite delegate for select TF ops.
INFO: TfLiteFlexDelegate delegate: 1 nodes delegated out of 1 nodes with 1 partitions.

From a Keras model

# Generate a Keras model.
keras_model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
        tf.keras.layers.Dense(1, activation='relu', name='output'),
    ]
)

# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)

signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpsg7u7r9z/assets
{'serving_default': {'inputs': ['x_input'], 'outputs': ['output']} }

From concrete functions

model = Model()

# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.encode.get_concrete_function(),
     model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxothj3r4/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }

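Run signatures

The TensorFlow Lite inference APIs support signature-based execution:

  • Access the input/output tensors through the input and output names specified by the signature.
  • Run each entry point of the graph separately, identified by its signature key.
  • Support for the SavedModel's initialization procedure.

Java, C++, and Python language bindings are currently available. See the examples in the sections below.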

Java

import java.util.HashMap;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

// `file_of_tensorflowlite_model`, `input`, `encoded_result`, and
// `decoded_result` are placeholders that the caller must define.
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  // Run encoding signature.
  Map<String, Object> encodeInputs = new HashMap<>();
  encodeInputs.put("x", input);
  Map<String, Object> encodeOutputs = new HashMap<>();
  encodeOutputs.put("encoded_result", encoded_result);
  interpreter.runSignature(encodeInputs, encodeOutputs, "encode");

  // Run decoding signature, feeding it the output of the encoding step.
  Map<String, Object> decodeInputs = new HashMap<>();
  decodeInputs.put("x", encoded_result);
  Map<String, Object> decodeOutputs = new HashMap<>();
  decodeOutputs.put("decoded_result", decoded_result);
  interpreter.runSignature(decodeInputs, decodeOutputs, "decode");
}

C++

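// Assumes `using namespace tflite;` and an already constructed `interpreter`.
SignatureRunner* encode_runner =
    interpreter->GetSignatureRunner("encode");
encode_runner->ResizeInputTensor("x", {100});
encode_runner->AllocateTensors();

TfLiteTensor* input_tensor = encode_runner->input_tensor("x");
float* input = GetTensorData<float>(input_tensor);
// Fill `input`.

encode_runner->Invoke();

const TfLiteTensor* output_tensor = encode_runner->output_tensor(
    "encoded_result");
// `encoded_result` is a string tensor in the example model, so read it with
// the helpers in tensorflow/lite/string_util.h rather than as float data.
const int count = GetStringCount(output_tensor);
for (int i = 0; i < count; ++i) {
  StringRef encoded = GetString(output_tensor, i);
  // Access `encoded.str` and `encoded.len`.
}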

Python

# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)

# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')

# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)
Signature: {'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
Input: tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
Encoded result: {'encoded_result': array([b'1.000000', b'2.000000', b'3.000000'], dtype=object)}
Decoded result: {'decoded_result': array([1., 2., 3.], dtype=float32)}
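
Because a signature runner is called with keyword arguments named after the signature inputs, it can help to check the names before invoking it. The sketch below does that against the dictionary shape printed by `get_signature_list()`; `check_signature_call` is a hypothetical helper, not part of the TensorFlow Lite API, and the dictionary is typed in by hand so the snippet runs without a model.

```python
def check_signature_call(signatures, key, /, **inputs):
    """Raise if `key` is unknown or the input names do not match the signature."""
    if key not in signatures:
        raise KeyError(
            f"unknown signature key {key!r}; available: {sorted(signatures)}")
    expected = set(signatures[key]["inputs"])
    given = set(inputs)
    if given != expected:
        raise ValueError(
            f"signature {key!r} expects inputs {sorted(expected)}, "
            f"got {sorted(given)}")
    return True


# Dictionary in the shape printed by get_signature_list() above.
signatures = {
    "decode": {"inputs": ["x"], "outputs": ["decoded_result"]},
    "encode": {"inputs": ["x"], "outputs": ["encoded_result"]},
}

# Passes: "encode" exists and takes exactly one input named "x".
check_signature_call(signatures, "encode", x=[1.0, 2.0, 3.0])
```

With a real interpreter, the same check could run just before `encode(x=...)` to turn a misspelled input name into an early, readable error.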

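Known issues/limitations

  • Because the TFLite interpreter does not guarantee thread safety, signature runners from the same interpreter will not be executed concurrently.
  • C/iOS/Swift is not supported yet.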
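
If an application calls runners from several threads, one possible way to respect the first limitation is to serialize the calls with a lock shared by all runners of one interpreter. This is a sketch under that assumption, not an official recommendation: `SerializedRunner` and `fake_encode` are hypothetical names, and `fake_encode` stands in for a real signature runner so the snippet runs without a model.

```python
import threading


class SerializedRunner:
    """Wrap a callable so that only one thread invokes it at a time."""

    def __init__(self, runner, lock):
        self._runner = runner
        self._lock = lock

    def __call__(self, **inputs):
        with self._lock:
            return self._runner(**inputs)


# Stand-in for a signature runner so the sketch runs without a model.
def fake_encode(x):
    return {"encoded_result": [f"{v:.6f}" for v in x]}


# Share one lock among all runners that come from the same interpreter.
interpreter_lock = threading.Lock()
encode = SerializedRunner(fake_encode, interpreter_lock)

results = []
threads = [
    threading.Thread(target=lambda i=i: results.append(encode(x=[float(i)])))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(results))
```

Whether a single lock is sufficient for a given deployment should be checked against the TensorFlow Lite documentation for the version in use.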

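Updates

  • Version 2.7
    • The multiple signature feature is implemented.
    • All converter APIs from version two generate signature-enabled TensorFlow Lite models.
  • Version 2.5
    • The signature feature is available through the from_saved_model converter API.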
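
The version notes above can be turned into a small lookup. The helper below is a hypothetical sketch based only on those notes; for 2.7 and later it lists the three converter entry points demonstrated in this guide.

```python
def signature_converters(tf_version):
    """Return the converter APIs that carry signatures for a TensorFlow version.

    Based only on the update notes in this guide, not on an API query.
    """
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    if (major, minor) >= (2, 7):
        return ["from_saved_model", "from_keras_model", "from_concrete_functions"]
    if (major, minor) >= (2, 5):
        return ["from_saved_model"]
    return []


print(signature_converters("2.5.0"))
print(signature_converters("2.7.0"))
```

In a real script the argument would typically be `tf.__version__`.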